-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.js
More file actions
238 lines (201 loc) · 10.8 KB
/
train.js
File metadata and controls
238 lines (201 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
// reasoningHeuristicAI/train.js
// --- Финальная, универсальная версия для обучения на датасетах любого размера ---
import fs from 'fs/promises';
import path from 'path';
import slmnet from './core/slmnet/slmnet.js';
import { HeuristicEngine } from './core/heuristic-engine.js';
import { saveModel } from './core/model-io.js';
import workerPool from './core/worker-pool.js';
const config = {
embedding_dim: 128,
hidden_dim: 256,
data_path: './data',
final_model_path: './models/rHAI_v1_final.json',
last_model_path: './models/rHAI_v1_last.json',
num_epochs: 15,
accumulation_steps: 4,
max_learning_rate: 0.001,
min_learning_rate: 0.0001,
grad_clip: 1.0,
validation_split: 0.1,
patience: 3,
validation_samples: 200,
};
function clipGrads(parameters, max_norm) {
let total_norm = 0;
for (const p of parameters) {
if (p.grad) {
for (const grad_val of p.grad.data) {
total_norm += grad_val * grad_val;
}
}
}
total_norm = Math.sqrt(total_norm);
if (total_norm > max_norm) {
const scale = max_norm / (total_norm + 1e-6);
for (const p of parameters) {
if (p.grad) {
for (let i = 0; i < p.grad.data.length; i++) {
p.grad.data[i] *= scale;
}
}
}
}
}
async function loadTrainingData() {
console.log(`[Trainer] Загрузка и подготовка данных...`);
const allFilesInDir = await fs.readdir(config.data_path);
const trainingFiles = allFilesInDir.filter(file => file.endsWith('.jsonl'));
if (trainingFiles.length === 0) throw new Error(`В директории ${config.data_path} не найдено .jsonl файлов.`);
let allJsonData = [];
let fullTextForTokenizer = '';
for (const file of trainingFiles) {
const filePath = path.join(config.data_path, file);
const fileContent = await fs.readFile(filePath, 'utf-8');
const lines = fileContent.trim().split('\n');
for (const line of lines) {
if (line) {
const jsonData = JSON.parse(line);
allJsonData.push(jsonData);
fullTextForTokenizer += (jsonData.prompt || '') + (jsonData.completion || '');
}
}
}
const tokenizer = new slmnet.tokenizers.CharacterTokenizer(fullTextForTokenizer);
const data = allJsonData.map(item => ({
prompt_ids: tokenizer.encode(item.prompt),
completion_ids: tokenizer.encode(item.completion)
}));
return { tokenizer, data };
}
async function evaluate(model, validation_data) {
if (validation_data.length === 0) return Infinity;
const num_samples = Math.min(config.validation_samples, validation_data.length);
let total_loss = 0;
let count = 0;
for (let i = 0; i < num_samples; i++) {
const sample = validation_data[Math.floor(Math.random() * validation_data.length)];
const input_ids = sample.prompt_ids.concat(sample.completion_ids);
if (input_ids.length <= 1) continue;
const model_input_ids = input_ids.slice(0, -1);
const target_ids = input_ids.slice(1);
const x = new slmnet.Tensor(model_input_ids, [1, model_input_ids.length]);
const logits = await model.forward(x);
const reshaped_logits = new slmnet.Tensor(logits.data, [logits.shape[1], logits.shape[2]]);
const reshaped_targets = new slmnet.Tensor(target_ids, [target_ids.length]);
const loss = await slmnet.losses.cross_entropy_loss(reshaped_logits, reshaped_targets);
if (!isNaN(loss.data[0])) {
total_loss += loss.data[0];
count++;
}
}
return count > 0 ? total_loss / count : Infinity;
}
function shuffleArray(array) {
for (let i = array.length - 1; i > 0; i--) {
const j = Math.floor(Math.random() * (i + 1));
[array[i], array[j]] = [array[j], array[i]];
}
}
async function main() {
console.log(`--- ЗАПУСК ОБУЧЕНИЯ rHAI с Наставником ---`);
const { tokenizer, data } = await loadTrainingData();
const split_index = Math.floor(data.length * (1 - config.validation_split));
const trainingData = data.slice(0, split_index);
const validationData = data.slice(split_index);
console.log(`[Наставник] Данные разделены: ${trainingData.length} для тренировки, ${validationData.length} для экзамена.`);
const modelConfig = {
vocab_size: tokenizer.vocab_size,
embedding_dim: config.embedding_dim,
hidden_dim: config.hidden_dim
};
const model = new HeuristicEngine(modelConfig);
const optimizer = new slmnet.optimizers.Adam(model.parameters(), config.max_learning_rate);
console.log(`[Trainer] Модель и данные готовы. Начинаю обучение...`);
const startTime = Date.now();
let best_validation_loss = Infinity;
let patience_counter = 0;
let total_steps = 0;
const steps_per_epoch = Math.ceil(trainingData.length / config.accumulation_steps);
const total_train_steps_for_lr = steps_per_epoch * config.num_epochs;
const validate_every = Math.max(50, Math.floor(steps_per_epoch / 4)); // Валидация 4 раза за эпоху, но не чаще чем раз в 50 шагов
for (let epoch = 1; epoch <= config.num_epochs; epoch++) {
console.log(`\n--- ЭПОХА ${epoch}/${config.num_epochs} ---`);
shuffleArray(trainingData);
let accumulated_loss = 0;
optimizer.zero_grad();
for (let i = 0; i < trainingData.length; i++) {
const sample = trainingData[i];
const input_ids = sample.prompt_ids.concat(sample.completion_ids);
if (input_ids.length <= 1) continue;
const model_input_ids = input_ids.slice(0, -1);
const target_ids = input_ids.slice(1);
try {
const x = new slmnet.Tensor(model_input_ids, [1, model_input_ids.length]);
const logits = await model.forward(x);
const reshaped_logits = new slmnet.Tensor(logits.data, [logits.shape[1], logits.shape[2]]);
const reshaped_targets = new slmnet.Tensor(target_ids, [target_ids.length]);
const loss = await slmnet.losses.cross_entropy_loss(reshaped_logits, reshaped_targets);
if (isNaN(loss.data[0])) continue;
const current_loss_val = loss.data[0] / config.accumulation_steps;
accumulated_loss += loss.data[0];
const loss_for_backward = new slmnet.Tensor([current_loss_val], [1]);
loss_for_backward.backward_fn = loss.backward_fn;
await loss_for_backward.backward();
if ((i + 1) % config.accumulation_steps === 0 || i === trainingData.length - 1) {
total_steps++;
const decay_ratio = Math.min(total_steps / (total_train_steps_for_lr * 0.8));
optimizer.lr = config.max_learning_rate - decay_ratio * (config.max_learning_rate - config.min_learning_rate);
optimizer.lr = Math.max(optimizer.lr, config.min_learning_rate);
clipGrads(model.parameters(), config.grad_clip);
optimizer.step();
optimizer.zero_grad();
if (total_steps % 50 === 0) {
const duration_ms = Date.now() - startTime;
const steps_per_sec = total_steps / (duration_ms / 1000);
console.log(`Эпоха ${epoch} | Шаг ${total_steps} | Прогресс: ${((i / trainingData.length) * 100).toFixed(1)}% | Ошибка: ${(accumulated_loss / config.accumulation_steps).toFixed(4)} | LR: ${optimizer.lr.toExponential(2)} | ${steps_per_sec.toFixed(2)} шаг/сек`);
}
accumulated_loss = 0;
if (total_steps > 0 && total_steps % validate_every === 0) {
const validation_loss = await evaluate(model, validationData);
console.log(`\n----------------------------------------------------`);
console.log(`[Наставник] Экзамен на шаге ${total_steps}: Валидационная Ошибка = ${validation_loss.toFixed(4)}`);
if (validation_loss < best_validation_loss) {
console.log(`[Наставник] ✅ Новый лучший результат! (было ${best_validation_loss === Infinity ? 'Infinity' : best_validation_loss.toFixed(4)}). Сохраняю модель...`);
best_validation_loss = validation_loss;
await saveModel(model, tokenizer, modelConfig, config.final_model_path);
patience_counter = 0;
} else {
patience_counter++;
console.log(`[Наставник] ⚠️ Результат не улучшился. Терпение: ${patience_counter}/${config.patience}`);
}
console.log(`----------------------------------------------------\n`);
if (patience_counter >= config.patience) {
console.log(`[Наставник] 🛑 Терпение исчерпано. Ранняя остановка обучения.`);
epoch = config.num_epochs + 1;
break;
}
}
}
} catch (error) {
console.error(`[Trainer] Критическая ошибка на шаге ${i}. Пропускаю. Ошибка: ${error.message}`);
optimizer.zero_grad();
accumulated_loss = 0;
}
}
}
console.log(`\n[Trainer] Обучение завершено. Сохраняю финальное состояние модели...`);
await saveModel(model, tokenizer, modelConfig, config.last_model_path);
if (best_validation_loss === Infinity && data.length > 0) {
console.log('[Trainer] Валидация не проводилась. Сохраняю последнюю модель как лучшую.');
await saveModel(model, tokenizer, modelConfig, config.final_model_path);
}
console.log(`[Trainer] Лучшая модель сохранена в: ${config.final_model_path}`);
await workerPool.destroy();
}
main().catch(error => {
console.error("Произошла критическая ошибка:", error);
if (workerPool) {
workerPool.destroy();
}
});