forked from OmerShubi/Reuters_1987_Classification
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
67 lines (49 loc) · 1.89 KB
/
main.py
File metadata and controls
67 lines (49 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import logging.config
import sklearn.metrics
from sklearn.preprocessing import MultiLabelBinarizer
import model
import pickleHelper
def main(*, train_on_full_dataset: bool = False, is_submission: bool = False) -> None:
    """Train the model, predict on the test data and report the macro F1 score.

    Configures logging from ``logging.conf``, builds ``model.Model`` from the
    selected training directory, predicts on the test directory and saves the
    predictions through ``pickleHelper.save_to_pickle``. On a non-submission
    run the predictions are then scored against the reference labels that
    ``predict`` returns alongside them.

    :param train_on_full_dataset: when true, train from ``Data/train_data``;
        otherwise train from ``Data/reuters_train_data``.
    :param is_submission: when true, ``predict`` is called with
        ``is_submission=True`` and its whole result is taken as the
        predictions; no reference labels are available, so scoring is skipped.
    :return: None
    """
    # Gets or creates a logger
    logging.config.fileConfig('logging.conf')
    logger = logging.getLogger(__name__)
    logger.info("********** NEW RUN **********")

    # NOTE(review): the previous comment here said to set this flag to True
    # for the *small* dataset, which contradicts the flag's name — confirm
    # which of the two directories actually holds the full training set.
    if train_on_full_dataset:
        train_data_dir = "Data/train_data"
    else:
        train_data_dir = "Data/reuters_train_data"
    # Both configurations evaluate against the same test directory.
    test_data_dir = "Data/reuters_test_data"

    logger.info("Initiating training with data from '%s' directory", train_data_dir)
    knn_model = model.Model(train_data_dir)

    logger.info("Predicting testing with data with countries from '%s' directory", test_data_dir)
    if is_submission:
        predictions = knn_model.predict(test_data_dir, is_submission=True)
    else:
        predictions, reference = knn_model.predict(test_data_dir)
    logger.info("Prediction complete")

    pickleHelper.save_to_pickle("predictions", predictions)
    # NOTE: previously saved predictions can be reloaded instead of recomputed
    # with pickleHelper.retrieve_from_pickle(<path>, "predictions").

    if is_submission:
        # No reference labels were produced on this path; scoring would
        # otherwise fail on the unbound ``reference`` name.
        return

    # Binarize the label sets; the binarizer is fitted on the reference labels
    # and then applied to the predictions so both share the same columns.
    mlb = MultiLabelBinarizer()
    r = mlb.fit_transform(reference)
    p = mlb.transform(predictions)

    try:
        score = sklearn.metrics.f1_score(y_true=r, y_pred=p, average='macro')
    except ValueError as ex:
        logger.error("result value is invalid: %s", ex)
    else:
        print(score)
        logger.info("The f1 score is: %s", score)
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()