|
| 1 | +# For a more complete script, see: |
| 2 | +# https://github.com/IntelCompH2020/EWB/blob/main/ewb-inferencer/src/core/inferencer/base/inferencer.py |
| 3 | + |
1 | 4 | import argparse |
2 | 5 | import json |
3 | 6 | import os |
@@ -108,10 +111,10 @@ def apply_model_editions(self, thetas32): |
108 | 111 | thetas32 = np.delete(thetas32, tpcs[1:], 1) |
109 | 112 | thetas32 = normalize(thetas32, axis=1, norm='l1') |
110 | 113 | self._logger.info(thetas32.shape) # nodcs*ntopics |
111 | | - doc_topics_file_npy = infer_path.joinpath("doc-topics.npy") |
| 114 | + doc_topics_file_npy = infer_path.joinpath("thetas_infer.npy") |
112 | 115 | np.save(doc_topics_file_npy, thetas32) |
113 | 116 |
|
114 | | - return |
| 117 | + return thetas32 |
115 | 118 |
|
116 | 119 | @abstractmethod |
117 | 120 | def predict(self): |
@@ -209,10 +212,10 @@ def predict(self): |
209 | 212 | thetas32 = np.loadtxt(doc_topics_file, delimiter='\t', |
210 | 213 | dtype=np.float32, usecols=cols) |
211 | 214 |
|
212 | | - super().apply_model_editions(thetas32) |
213 | | - super().transform_inference_output(thetas32, 100) |
| 215 | + thetas32 = super().apply_model_editions(thetas32) |
| 216 | + thetas32_rpr = super().transform_inference_output(thetas32, 1000) |
214 | 217 |
|
215 | | - return |
| 218 | + return thetas32_rpr |
216 | 219 |
|
217 | 220 |
|
218 | 221 | class SparkLDAInferencer(Inferencer): |
@@ -279,10 +282,10 @@ def predict(self): |
279 | 282 | thetas32 = np.asarray( |
280 | 283 | avitm.get_doc_topic_distribution(ho_data)) |
281 | 284 |
|
282 | | - super().apply_model_editions(thetas32) |
283 | | - super().transform_inference_output(thetas32, 100) |
| 285 | + thetas32 = super().apply_model_editions(thetas32) |
| 286 | + thetas32_rpr = super().transform_inference_output(thetas32, 1000) |
284 | 287 |
|
285 | | - return |
| 288 | + return thetas32_rpr |
286 | 289 |
|
287 | 290 |
|
288 | 291 | class CTMInferencer(Inferencer): |
@@ -349,10 +352,10 @@ def predict(self): |
349 | 352 | thetas32 = np.asarray( |
350 | 353 | ctm.get_doc_topic_distribution(ho_data)) |
351 | 354 |
|
352 | | - super().apply_model_editions(thetas32) |
353 | | - super().transform_inference_output(thetas32, 100) |
| 355 | + thetas32 = super().apply_model_editions(thetas32) |
| 356 | + thetas32_rpr = super().transform_inference_output(thetas32, 1000) |
354 | 357 |
|
355 | | - return |
| 358 | + return thetas32_rpr |
356 | 359 |
|
357 | 360 |
|
358 | 361 | ############################################################################## |
|
0 commit comments