Skip to content

Commit 159b892

Browse files
Updated inferencer script to match EWB's
1 parent 78aa1c2 commit 159b892

1 file changed

Lines changed: 14 additions & 11 deletions

File tree

src/topicmodeling/inferencer.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# For a more complete script, see:
2+
# https://github.com/IntelCompH2020/EWB/blob/main/ewb-inferencer/src/core/inferencer/base/inferencer.py
3+
14
import argparse
25
import json
36
import os
@@ -108,10 +111,10 @@ def apply_model_editions(self, thetas32):
108111
thetas32 = np.delete(thetas32, tpcs[1:], 1)
109112
thetas32 = normalize(thetas32, axis=1, norm='l1')
110113
self._logger.info(thetas32.shape) # ndocs*ntopics
111-
doc_topics_file_npy = infer_path.joinpath("doc-topics.npy")
114+
doc_topics_file_npy = infer_path.joinpath("thetas_infer.npy")
112115
np.save(doc_topics_file_npy, thetas32)
113116

114-
return
117+
return thetas32
115118

116119
@abstractmethod
117120
def predict(self):
@@ -209,10 +212,10 @@ def predict(self):
209212
thetas32 = np.loadtxt(doc_topics_file, delimiter='\t',
210213
dtype=np.float32, usecols=cols)
211214

212-
super().apply_model_editions(thetas32)
213-
super().transform_inference_output(thetas32, 100)
215+
thetas32 = super().apply_model_editions(thetas32)
216+
thetas32_rpr = super().transform_inference_output(thetas32, 1000)
214217

215-
return
218+
return thetas32_rpr
216219

217220

218221
class SparkLDAInferencer(Inferencer):
@@ -279,10 +282,10 @@ def predict(self):
279282
thetas32 = np.asarray(
280283
avitm.get_doc_topic_distribution(ho_data))
281284

282-
super().apply_model_editions(thetas32)
283-
super().transform_inference_output(thetas32, 100)
285+
thetas32 = super().apply_model_editions(thetas32)
286+
thetas32_rpr = super().transform_inference_output(thetas32, 1000)
284287

285-
return
288+
return thetas32_rpr
286289

287290

288291
class CTMInferencer(Inferencer):
@@ -349,10 +352,10 @@ def predict(self):
349352
thetas32 = np.asarray(
350353
ctm.get_doc_topic_distribution(ho_data))
351354

352-
super().apply_model_editions(thetas32)
353-
super().transform_inference_output(thetas32, 100)
355+
thetas32 = super().apply_model_editions(thetas32)
356+
thetas32_rpr = super().transform_inference_output(thetas32, 1000)
354357

355-
return
358+
return thetas32_rpr
356359

357360

358361
##############################################################################

0 commit comments

Comments
 (0)