-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbenchmark.js
More file actions
135 lines (109 loc) · 4.85 KB
/
benchmark.js
File metadata and controls
135 lines (109 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// reasoningHeuristicAI/benchmark.js
// --- Версия, использующая правильный датасет для оценки ---
import fs from 'fs/promises';
import path from 'path';
import slmnet from './core/slmnet/slmnet.js';
import { HeuristicEngine } from './core/heuristic-engine.js';
import workerPool from './core/worker-pool.js';
/**
 * Benchmark hyperparameters and paths.
 * embedding_dim / hidden_dim are forwarded to HeuristicEngine through modelConfig in run();
 * learning_rate is handed to the Adam optimizer; data_path is the directory that
 * loadBenchmarkData() reads coder_finetune.jsonl from.
 */
const config = {
embedding_dim: 64, // Model embedding size (modelConfig.embedding_dim).
hidden_dim: 128, // Model hidden size (modelConfig.hidden_dim).
data_path: './data', // Directory containing coder_finetune.jsonl.
benchmark_steps: 10, // Max optimizer steps; kept small because the dataset is small.
learning_rate: 0.001, // Adam learning rate.
grad_clip: 1.0, // Upper bound on the L2 norm of each parameter's gradient (clipped per parameter).
accumulation_steps: 1, // Micro-batches per optimizer step; 1 = no gradient accumulation, since there are few examples.
};
/**
 * Load and tokenize the benchmark dataset.
 *
 * Reads `coder_finetune.jsonl` from `config.data_path` (the same file used for
 * fine-tuning, so the benchmark evaluates the same data). Each non-blank line
 * must be a JSON object with string `prompt` and `completion` fields. A
 * character-level tokenizer is built over the concatenation of every prompt and
 * completion, and each record is then encoded with it.
 *
 * @returns {Promise<{tokenizer: object, data: Array<{prompt_ids: *, completion_ids: *}>}>}
 *   The tokenizer and the encoded records. `data` is empty when the file
 *   contains only blank lines.
 * @throws {Error} When the file cannot be read, a line is not valid JSON, or a
 *   record lacks string `prompt`/`completion` fields. The file path is logged
 *   before the error is rethrown.
 */
async function loadBenchmarkData() {
  const filePath = path.join(config.data_path, 'coder_finetune.jsonl');
  try {
    const fileContent = await fs.readFile(filePath, 'utf-8');

    const jsonData = [];
    for (const [index, rawLine] of fileContent.split('\n').entries()) {
      // trim() also strips '\r' (CRLF files) and a leading BOM.
      const line = rawLine.trim();
      // Skip blank lines (trailing newlines, separators, an empty file). Feeding ''
      // to JSON.parse would throw and make the caller's empty-data check unreachable.
      if (line === '') continue;
      const lineNo = index + 1;
      let item;
      try {
        item = JSON.parse(line);
      } catch (parseError) {
        throw new Error(`Некорректный JSON в строке ${lineNo} файла '${filePath}'.`, { cause: parseError });
      }
      if (item === null || typeof item.prompt !== 'string' || typeof item.completion !== 'string') {
        throw new Error(`Строка ${lineNo} файла '${filePath}': ожидаются строковые поля "prompt" и "completion".`);
      }
      jsonData.push(item);
    }

    // The tokenizer vocabulary is derived from all text in the dataset.
    const fullTextForTokenizer = jsonData.map((item) => item.prompt + item.completion).join('');
    const tokenizer = new slmnet.tokenizers.CharacterTokenizer(fullTextForTokenizer);
    const data = jsonData.map((item) => ({
      prompt_ids: tokenizer.encode(item.prompt),
      completion_ids: tokenizer.encode(item.completion),
    }));
    return { tokenizer, data };
  } catch (error) {
    console.error(`Ошибка: Не удалось загрузить данные для бенчмарка из '${filePath}'.`);
    throw error;
  }
}
/**
 * Clip each parameter's gradient in place so its L2 norm is at most `maxNorm`.
 *
 * The norm is computed separately for every parameter tensor (not one global
 * norm across all parameters). Parameters without a `grad` are skipped.
 *
 * @param {Iterable<{grad?: {data: ArrayLike<number>}}>} parameters
 * @param {number} maxNorm
 */
function clipGradientsPerParameter(parameters, maxNorm) {
  for (const p of parameters) {
    if (!p.grad) continue;
    const g = p.grad.data;
    let sumSq = 0;
    for (const v of g) sumSq += v * v;
    const norm = Math.sqrt(sumSq);
    if (norm > maxNorm) {
      const scale = maxNorm / norm;
      for (let j = 0; j < g.length; j++) g[j] *= scale;
    }
  }
}

/**
 * Run a short training benchmark and print the average loss and wall time.
 *
 * Loads the dataset, builds a HeuristicEngine and an Adam optimizer, then
 * trains on samples in order until `config.benchmark_steps` optimizer steps
 * have been taken or the data runs out. Each sample is prompt+completion ids
 * trained with next-token targets (the input shifted by one). Samples shorter
 * than two tokens or with a non-finite loss are skipped. Samples whose step
 * throws are skipped too, but they are counted and reported on stderr.
 *
 * "Final Loss" is the mean loss over samples whose finite loss was counted.
 * "Time" includes data loading and model construction.
 *
 * @throws {Error} If the dataset is empty, or if no sample produced a finite
 *   loss (the reported average would otherwise be NaN).
 */
async function run() {
  const startTime = Date.now();
  const { tokenizer, data } = await loadBenchmarkData();
  if (data.length === 0) {
    throw new Error("Нет данных для бенчмарка.");
  }
  const modelConfig = {
    vocab_size: tokenizer.vocab_size,
    embedding_dim: config.embedding_dim,
    hidden_dim: config.hidden_dim,
  };
  const model = new HeuristicEngine(modelConfig);
  const optimizer = new slmnet.optimizers.Adam(model.parameters(), config.learning_rate);

  let total_loss = 0;     // sum of raw (unscaled) losses of counted samples
  let samples_done = 0;   // samples whose finite loss went into total_loss
  let micro_steps = 0;    // samples whose backward pass completed
  let steps_done = 0;     // optimizer updates performed
  let failed_samples = 0; // samples whose step threw
  let firstError = null;

  optimizer.zero_grad();
  for (let i = 0; i < data.length && steps_done < config.benchmark_steps; i++) {
    const sample = data[i];
    const input_ids = sample.prompt_ids.concat(sample.completion_ids);
    // At least two tokens are needed to form one (input, next-token target) pair.
    if (input_ids.length <= 1) continue;
    const model_input_ids = input_ids.slice(0, -1);
    const target_ids = input_ids.slice(1);
    try {
      const x = new slmnet.Tensor(model_input_ids, [1, model_input_ids.length]);
      const logits = await model.forward(x);
      // NOTE(review): assumes logits is [1, seq_len, vocab]. Wrapping logits.data in a
      // fresh Tensor may detach it from the autograd graph, depending on slmnet's
      // Tensor. Confirm that gradients actually reach the model parameters.
      const reshaped_logits = new slmnet.Tensor(logits.data, [logits.shape[1], logits.shape[2]]);
      const reshaped_targets = new slmnet.Tensor(target_ids, [target_ids.length]);
      const loss = await slmnet.losses.cross_entropy_loss(reshaped_logits, reshaped_targets);
      const lossValue = loss.data[0];
      // Skip NaN/±Infinity: a single such value would poison the reported average.
      if (!Number.isFinite(lossValue)) continue;
      total_loss += lossValue;
      samples_done++;

      // Backpropagate a copy of the loss scaled by 1/accumulation_steps, reusing the
      // original loss's backward_fn. NOTE(review): whether the scaling takes effect
      // depends on how slmnet's backward_fn uses this tensor. Confirm.
      const loss_for_backward = new slmnet.Tensor([lossValue / config.accumulation_steps], [1]);
      loss_for_backward.backward_fn = loss.backward_fn;
      await loss_for_backward.backward();
      micro_steps++;

      // Count completed micro-batches rather than the data index, so skipped or
      // failed samples don't shift the accumulation boundary.
      if (micro_steps % config.accumulation_steps === 0) {
        clipGradientsPerParameter(model.parameters(), config.grad_clip);
        optimizer.step();
        optimizer.zero_grad();
        steps_done++;
      }
    } catch (error) {
      // Best-effort: a failing sample is skipped, but counted and reported below.
      failed_samples++;
      if (firstError === null) firstError = error;
    }
  }
  const endTime = Date.now();

  if (failed_samples > 0) {
    console.warn(`Предупреждение: ${failed_samples} пример(ов) завершились с ошибкой и были пропущены. Первая ошибка:`, firstError);
  }
  if (samples_done === 0) {
    throw new Error(
      "Ни один пример не дал конечного значения loss; результат бенчмарка не определён.",
      firstError ? { cause: firstError } : undefined,
    );
  }

  const finalLoss = total_loss / samples_done;
  const totalTime = endTime - startTime;
  console.log(`--- Benchmark Result ---`);
  console.log(`Final Loss: ${finalLoss.toFixed(4)}`);
  console.log(`Time: ${totalTime}ms`);
  console.log(`------------------------`);
  await workerPool.destroy();
}
// Entry point: run the benchmark. On any failure, log it, shut down the worker
// pool (waiting a bounded time), and exit with status 1.
run().catch(async (error) => {
  console.error("Критическая ошибка в бенчмарке:", error);
  if (workerPool) {
    try {
      // Let the pool finish shutting down before exiting, but don't let a stuck
      // shutdown hang the process forever; the unref'd timer won't keep it alive.
      await Promise.race([
        workerPool.destroy(),
        new Promise((resolve) => setTimeout(resolve, 5000).unref()),
      ]);
    } catch (destroyError) {
      console.error("Не удалось корректно остановить пул воркеров:", destroyError);
    }
  }
  process.exit(1);
});