-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor.rs
More file actions
191 lines (163 loc) · 7.97 KB
/
tensor.rs
File metadata and controls
191 lines (163 loc) · 7.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
//! Модуль, определяющий `Tensor` — основную структуру данных в библиотеке.
use crate::core::autograd::{self, BackwardContext};
use crate::error::Result;
use ndarray::{ArrayD, IxDyn};
use std::cell::RefCell;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::rc::Rc;
/// Основная многомерная структура данных для всех операций.
///
/// `Tensor` является оберткой над `ndarray::ArrayD<f32>`, добавляющей
/// возможность автоматического вычисления градиентов (autograd).
///
/// Внутренние данные (`data`) и градиент (`grad`) обернуты в `Rc<RefCell<...>>`,
/// что позволяет иметь несколько "владельцев" одного тензора и изменять его
/// содержимое, даже если на него есть только иммутабельные ссылки. Это ключевой
/// механизм для построения динамического графа вычислений.
#[derive(Clone)]
pub struct Tensor {
/// Внутренние данные тензора. Доступны для чтения и записи через `RefCell`.
pub data: Rc<RefCell<ArrayD<f32>>>,
/// Градиент этого тензора. `None`, если `requires_grad` было `false`.
pub grad: Option<Rc<RefCell<ArrayD<f32>>>>,
/// Контекст для обратного распространения ошибки. `None` для "листовых" тензоров.
pub ctx: Option<Rc<BackwardContext>>,
}
impl Hash for Tensor {
fn hash<H: Hasher>(&self, state: &mut H) {
// Хэшируем по указателю на данные, чтобы уникально идентифицировать узел в графе.
Rc::as_ptr(&self.data).hash(state);
}
}
impl PartialEq for Tensor {
fn eq(&self, other: &Self) -> bool {
// Сравниваем указатели, а не значения. Два тензора равны, только если
// они являются одним и тем же узлом в графе вычислений.
Rc::ptr_eq(&self.data, &other.data)
}
}
impl Eq for Tensor {}
impl Tensor {
/// Создает новый `Tensor`.
///
/// # Аргументы
///
/// * `data` - `ndarray::ArrayD<f32>`, который будет храниться в тензоре.
/// * `requires_grad` - Если `true`, для этого тензора будет создан и
/// будет накапливаться градиент при обратном распространении ошибки.
pub fn new(data: ArrayD<f32>, requires_grad: bool) -> Self {
let grad = if requires_grad {
let shape = data.shape();
let grad_data = ArrayD::zeros(IxDyn(shape));
Some(Rc::new(RefCell::new(grad_data)))
} else {
None
};
Self {
data: Rc::new(RefCell::new(data)),
grad,
ctx: None,
}
}
/// Создает новый `Tensor`, заполненный нулями.
pub fn zeros(shape: &[usize], requires_grad: bool) -> Self {
let data = ArrayD::zeros(IxDyn(shape));
Self::new(data, requires_grad)
}
/// Создает новый `Tensor`, заполненный единицами.
pub fn ones(shape: &[usize], requires_grad: bool) -> Self {
let data = ArrayD::ones(IxDyn(shape));
Self::new(data, requires_grad)
}
/// Выполняет сложение. См. `ops::basic_ops`.
pub fn add(&self, other: &Tensor) -> Tensor {
self + other
}
/// Выполняет вычитание. См. `ops::basic_ops`.
pub fn sub(&self, other: &Tensor) -> Tensor {
self - other
}
/// Выполняет поэлементное умножение. См. `ops::basic_ops`.
pub fn mul(&self, other: &Tensor) -> Tensor {
self * other
}
/// Выполняет матричное умножение. См. `ops::matmul::dot_op`.
pub fn dot(&self, other: &Tensor) -> Result<Tensor> {
crate::ops::matmul::dot_op(self, other)
}
/// Возводит каждый элемент тензора в степень. См. `ops::elementwise::powf_op`.
pub fn powf(&self, power: f32) -> Tensor {
crate::ops::elementwise::powf_op(self, power)
}
/// Суммирует все элементы тензора, возвращая скалярный тензор. См. `ops::reduction::sum_op`.
pub fn sum(&self) -> Tensor {
crate::ops::reduction::sum_op(self)
}
/// Применяет активацию ReLU. См. `ops::elementwise::relu_op`.
pub fn relu(&self) -> Tensor {
crate::ops::elementwise::relu_op(self)
}
/// Применяет активацию Sigmoid. См. `ops::elementwise::sigmoid_op`.
pub fn sigmoid(&self) -> Tensor {
crate::ops::elementwise::sigmoid_op(self)
}
/// Вычисляет натуральный логарифм каждого элемента. См. `ops::elementwise::log_op`.
pub fn log(&self) -> Tensor {
crate::ops::elementwise::log_op(self)
}
/// Применяет поэлементную экспоненту. См. `ops::elementwise::exp_op`.
pub fn exp(&self) -> Tensor {
crate::ops::elementwise::exp_op(self)
}
/// Выполняет операцию встраивания. См. `ops::embedding::embedding_op`.
pub fn embedding(&self, weights: &Tensor) -> Result<Tensor> {
crate::ops::embedding::embedding_op(self, weights)
}
/// Применяет Layer Normalization. См. `ops::norm::layernorm_op`.
pub fn layer_norm(&self, gamma: &Tensor, beta: &Tensor, epsilon: f32) -> Result<Tensor> {
crate::ops::norm::layernorm_op(self, gamma, beta, epsilon)
}
/// Применяет Softmax по последней оси. См. `ops::elementwise::softmax_op`.
pub fn softmax(&self) -> Tensor {
crate::ops::elementwise::softmax_op(self)
}
/// Транспонирует тензор. См. `ops::transform::transpose_op`.
pub fn transpose(&self, axis1: usize, axis2: usize) -> Result<Tensor> {
crate::ops::transform::transpose_op(self, axis1, axis2)
}
/// Изменяет форму тензора. См. `ops::transform::reshape_op`.
pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Tensor> {
crate::ops::transform::reshape_op(self, new_shape)
}
/// Вычисляет Cross-Entropy Loss. См. `ops::selection::sparse_cross_entropy_op`.
pub fn sparse_cross_entropy(&self, targets: &Tensor) -> Tensor {
crate::ops::selection::sparse_cross_entropy_op(self, targets)
}
/// Запускает обратное распространение ошибки, начиная с этого тензора.
///
/// Градиент самого тензора (`self.grad`) будет инициализирован единицами,
/// после чего градиенты будут рекурсивно вычислены для всех его "предков"
/// в графе вычислений, у которых `requires_grad` было `true`.
pub fn backward(&self) {
autograd::backward(self);
}
}
impl fmt::Debug for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let data = self.data.borrow();
let grad_str = if let Some(grad) = &self.grad {
format!("\n grad: \n{}", grad.borrow())
} else {
" grad: None".to_string()
};
write!(
f,
"Tensor {{\n shape: {:?},\n data: \n{},\n{}\n ctx: {:?}\n}}",
data.shape(),
data,
grad_str,
self.ctx
)
}
}