-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
149 lines (119 loc) · 5.68 KB
/
trainer.py
File metadata and controls
149 lines (119 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import pickle
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from bot import MODEL_CONFIG, BOT_PATHS
import logging
class ModelTrainer:
    """Trains, persists, and loads per-language intent-classification models.

    Each model is a scikit-learn Pipeline (TF-IDF vectorizer followed by a
    LogisticRegression classifier), cached in ``self.models`` keyed by
    language code and pickled to ``BOT_PATHS['models_dir']/model_<lang>.pkl``.
    """

    def __init__(self, text_processor, intent_manager):
        # text_processor must provide set_language() and preprocess();
        # intent_manager must provide get_training_data() and
        # get_available_languages() -- TODO confirm against their definitions.
        self.text_processor = text_processor
        self.intent_manager = intent_manager
        self.models = {}  # language code -> fitted sklearn Pipeline

    def prepare_training_data(self, language='en'):
        """Return (patterns, labels) preprocessed for *language*.

        Returns (None, None) when no usable training data exists. Patterns
        that become empty after preprocessing are dropped together with
        their labels so the two lists stay aligned.
        """
        patterns, labels = self.intent_manager.get_training_data(language)
        if not patterns or not labels:
            logging.error(f"No training data available for {language}")
            return None, None
        self.text_processor.set_language(language)
        processed_patterns = [self.text_processor.preprocess(pattern) for pattern in patterns]
        # Drop samples whose preprocessed text is empty or whitespace-only.
        valid_data = [(pattern, label) for pattern, label in zip(processed_patterns, labels) if pattern.strip()]
        if not valid_data:
            logging.error(f"No valid training data after preprocessing for {language}")
            return None, None
        processed_patterns, labels = zip(*valid_data)
        logging.info(f"Prepared {len(processed_patterns)} training samples for {language}")
        return list(processed_patterns), list(labels)

    def train_model(self, language='en'):
        """Train, evaluate, cache, and persist a model for *language*.

        Returns True on success; False when the data is insufficient or
        training/saving fails. Evaluation failures only log a warning.
        """
        logging.info(f"Training model for language: {language}")
        X, y = self.prepare_training_data(language)
        if X is None or len(X) < 2:
            logging.error(f"Insufficient training data for {language}")
            return False
        if len(X) > 4:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y,
                test_size=MODEL_CONFIG['test_size'],
                random_state=MODEL_CONFIG['random_state'],
                # Stratification requires at least two distinct classes.
                stratify=y if len(set(y)) > 1 else None
            )
        else:
            # Too few samples to hold out a test set; the reported
            # accuracy will be optimistic (train == test).
            X_train, X_test, y_train, y_test = X, X, y, y
            logging.warning(f"Using same data for train and test due to small dataset size for {language}")
        pipeline = Pipeline([
            ('tfidf', TfidfVectorizer(
                max_features=MODEL_CONFIG['max_features'],
                ngram_range=MODEL_CONFIG['ngram_range'],
                lowercase=True,
                stop_words=None
            )),
            ('classifier', LogisticRegression(
                random_state=MODEL_CONFIG['random_state'],
                max_iter=MODEL_CONFIG['max_iter'],
                class_weight='balanced'  # compensate for imbalanced intents
            ))
        ])
        try:
            pipeline.fit(X_train, y_train)
        except Exception as e:
            logging.error(f"Failed to train model for {language}: {str(e)}")
            return False
        # Evaluation is best-effort: a failed report must not discard the model.
        try:
            y_pred = pipeline.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            logging.info(f"Model accuracy for {language}: {accuracy:.3f}")
            logging.info(f"Classification Report for {language}:")
            logging.info(f"\n{classification_report(y_test, y_pred, zero_division=0)}")
        except Exception as e:
            logging.warning(f"Could not evaluate model for {language}: {str(e)}")
        self.models[language] = pipeline
        model_file = os.path.join(BOT_PATHS['models_dir'], f'model_{language}.pkl')
        try:
            # Ensure the target directory exists before the first save.
            os.makedirs(BOT_PATHS['models_dir'], exist_ok=True)
            with open(model_file, 'wb') as file:
                pickle.dump(pipeline, file)
            logging.info(f"Model saved for {language}: {model_file}")
        except Exception as e:
            logging.error(f"Failed to save model for {language}: {str(e)}")
            return False
        return True

    def train_all_languages(self):
        """Train a model for every available language.

        Returns True when at least one language trained successfully.
        """
        available_languages = self.intent_manager.get_available_languages()
        if not available_languages:
            logging.error("No languages available for training")
            return False
        success_count = 0
        for language in available_languages:
            if self.train_model(language):
                success_count += 1
        logging.info(f"Successfully trained models for {success_count}/{len(available_languages)} languages")
        return success_count > 0

    def load_model(self, language='en'):
        """Load the pickled model for *language* into the cache.

        Returns True on success, False when the file is missing or
        unpicklable. NOTE: pickle.load can execute arbitrary code --
        only load model files produced by this application.
        """
        model_file = os.path.join(BOT_PATHS['models_dir'], f'model_{language}.pkl')
        try:
            with open(model_file, 'rb') as file:
                self.models[language] = pickle.load(file)
            logging.info(f"Model loaded for {language}")
            return True
        except FileNotFoundError:
            logging.warning(f"Model file not found for {language}: {model_file}")
            return False
        except Exception as e:
            logging.error(f"Failed to load model for {language}: {str(e)}")
            return False

    def load_all_models(self):
        """Load every ``model_<lang>.pkl`` found in the models directory.

        Returns True when at least one model was loaded.
        """
        models_dir = BOT_PATHS['models_dir']
        # Guard against a missing directory: os.listdir would raise
        # FileNotFoundError and crash the caller.
        if not os.path.isdir(models_dir):
            logging.warning(f"Models directory not found: {models_dir}")
            return False
        success_count = 0
        for filename in os.listdir(models_dir):
            if filename.startswith('model_') and filename.endswith('.pkl'):
                # Slice exact prefix/suffix instead of str.replace, which
                # would corrupt a code containing 'model_' or '.pkl'.
                lang_code = filename[len('model_'):-len('.pkl')]
                if self.load_model(lang_code):
                    success_count += 1
        logging.info(f"Loaded {success_count} models")
        return success_count > 0

    def get_model(self, language='en'):
        """Return the cached model for *language*, or None if absent."""
        return self.models.get(language)

    def is_model_available(self, language='en'):
        """Return True when a model for *language* is in the cache."""
        return language in self.models

    def get_available_models(self):
        """Return the language codes of all cached models."""
        return list(self.models.keys())