-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfunctional.ts
More file actions
315 lines (254 loc) · 7.25 KB
/
functional.ts
File metadata and controls
315 lines (254 loc) · 7.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import { Tensor } from '../tensor';
import { createOperation } from './registry';
import { ArgumentType } from './base';
/**
 * Builds a variadic wrapper for the registered operation `opname`.
 *
 * Every call obtains an operation via `createOperation(opname)` and passes
 * all arguments through to its `forward` method untouched (no
 * number-to-Tensor conversion is performed here).
 *
 * @param opname - Registry name handed to `createOperation`.
 * @returns A function forwarding its arguments to the operation's `forward`.
 */
function generate_function(opname: string) {
  const forwardAll = (...args: ArgumentType[]) => createOperation(opname).forward(...args);
  return forwardAll;
}
/**
 * Builds a single-operand wrapper for the registered operation `opname`.
 *
 * A plain number argument is wrapped in a `Tensor` before being passed on,
 * so callers may write `exp(1)` as well as `exp(t)`. `createOperation` is
 * invoked on every call of the returned function.
 *
 * @param opname - Registry name handed to `createOperation`.
 * @returns A function taking a `Tensor` or number and returning the result
 *   of the operation's `forward`.
 */
function generate_unary_function(opname: string) {
  return (a: Tensor | number) => {
    // Normalize into a fresh binding rather than reassigning the parameter.
    const input = typeof a === 'number' ? new Tensor(a) : a;
    const operation = createOperation(opname);
    return operation.forward(input);
  };
}
/**
 * Builds a two-operand wrapper for the registered operation `opname`.
 *
 * Either operand may be a plain number; it is wrapped in a `Tensor` before
 * being passed on, so `add(t, 1)` and `add(1, t)` both work.
 * `createOperation` is invoked on every call of the returned function.
 *
 * @param opname - Registry name handed to `createOperation`.
 * @returns A function taking two `Tensor | number` operands and returning
 *   the result of the operation's `forward`.
 */
function generate_binary_function(opname: string) {
  return (a: Tensor | number, b: Tensor | number) => {
    // Normalize into fresh bindings rather than reassigning the parameters.
    const lhs = typeof a === 'number' ? new Tensor(a) : a;
    const rhs = typeof b === 'number' ? new Tensor(b) : b;
    const operation = createOperation(opname);
    return operation.forward(lhs, rhs);
  };
}
// debug operations
/**
 * Debug helper that runs the registered `__left_index__` operation on two
 * operands. Presumably yields, per output element, the index read from the
 * left operand — TODO confirm against the operation's implementation.
 * Either operand may be a plain number (it is wrapped in a `Tensor`).
 *
 * @ignore
 */
export const __left_index__ = generate_binary_function('__left_index__');
/**
 * Debug helper that runs the registered `__right_index__` operation on two
 * operands. Presumably yields, per output element, the index read from the
 * right operand — TODO confirm against the operation's implementation.
 * Either operand may be a plain number (it is wrapped in a `Tensor`).
 *
 * @ignore
 */
export const __right_index__ = generate_binary_function('__right_index__');
// binary pointwise
// Every function in this section accepts `Tensor | number` for both operands;
// a plain number is wrapped in a `Tensor` before the operation runs.
/**
 * Adds two tensors element-wise. Either operand may be a plain number.
 */
export const add = generate_binary_function('add');
/**
 * Subtracts the second tensor from the first tensor element-wise.
 * Either operand may be a plain number.
 */
export const sub = generate_binary_function('sub');
/**
 * Multiplies two tensors element-wise. Either operand may be a plain number.
 */
export const mul = generate_binary_function('mul');
/**
 * Divides the first tensor by the second tensor element-wise.
 * Either operand may be a plain number.
 */
export const div = generate_binary_function('div');
/**
 * Raises the first tensor to the power of the second tensor element-wise.
 * Either operand may be a plain number.
 */
export const pow = generate_binary_function('pow');
/**
 * Computes the element-wise remainder of dividing the first tensor by the
 * second tensor. Either operand may be a plain number.
 */
export const fmod = generate_binary_function('fmod');
/**
 * Returns the element-wise maximum of two tensors. Either operand may be a
 * plain number. For the reduction over a single tensor see {@link max}.
 */
export const maximum = generate_binary_function('maximum');
/**
 * Returns the element-wise minimum of two tensors. Either operand may be a
 * plain number. For the reduction over a single tensor see {@link min}.
 */
export const minimum = generate_binary_function('minimum');
// unary pointwise
// Every function in this section accepts `Tensor | number`; a plain number is
// wrapped in a `Tensor` before the operation runs.
/**
 * Computes the natural logarithm of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const log = generate_unary_function('log');
/**
 * Computes the square root of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const sqrt = generate_unary_function('sqrt');
/**
 * Computes the exponential of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const exp = generate_unary_function('exp');
/**
 * Computes the square of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const square = generate_unary_function('square');
/**
 * Computes the absolute value of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const abs = generate_unary_function('abs');
/**
 * Computes the sign of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const sign = generate_unary_function('sign');
/**
 * Negates the input tensor element-wise. A plain number is also accepted.
 */
export const neg = generate_unary_function('neg');
/**
 * Computes the reciprocal of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const reciprocal = generate_unary_function('reciprocal');
/**
 * Replaces NaN values in the input tensor with 0, positive infinity with a
 * large finite number, and negative infinity with a small finite number.
 * A plain number is also accepted.
 */
export const nan_to_num = generate_unary_function('nan_to_num');
// shape manipulation
// These wrappers forward every argument unchanged to the registered
// operation; unlike the pointwise wrappers they do not convert numbers.
/**
 * Reshapes the input tensor to the given shape.
 * Arguments are forwarded unchanged to the `reshape` operation.
 */
export const reshape = generate_function('reshape');
/**
 * Removes all dimensions of size 1 from the input tensor.
 * Arguments are forwarded unchanged to the `squeeze` operation (whether it
 * also accepts a specific dimension is defined by that operation).
 */
export const squeeze = generate_function('squeeze');
/**
 * Adds a dimension of size 1 to the input tensor at the given position.
 * Arguments are forwarded unchanged to the `unsqueeze` operation.
 */
export const unsqueeze = generate_function('unsqueeze');
/**
 * Expands the input tensor to the given shape.
 * Arguments are forwarded unchanged to the `expand` operation.
 */
export const expand = generate_function('expand');
// trigonometric
// Every function in this section accepts `Tensor | number`; a plain number is
// wrapped in a `Tensor` before the operation runs.
/**
 * Computes the sine of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const sin = generate_unary_function('sin');
/**
 * Computes the cosine of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const cos = generate_unary_function('cos');
/**
 * Computes the tangent of the input tensor element-wise.
 * A plain number is also accepted.
 */
export const tan = generate_unary_function('tan');
// reduction
// Arguments are forwarded unchanged to the registered operation (no
// number-to-Tensor conversion); any extra arguments such as a reduction
// dimension are interpreted by the operation itself.
/**
 * Computes the sum of the elements of the input tensor.
 */
export const sum = generate_function('sum');
/**
 * Computes the mean of the elements of the input tensor.
 */
export const mean = generate_function('mean');
/**
 * Computes the minimum of the elements of the input tensor.
 * For the element-wise minimum of two tensors see {@link minimum}.
 */
export const min = generate_function('min');
/**
 * Computes the maximum of the elements of the input tensor.
 * For the element-wise maximum of two tensors see {@link maximum}.
 */
export const max = generate_function('max');
// linalg
/**
 * Transposes the input tensor.
 * Arguments (e.g. the dimensions to swap, if the operation takes them) are
 * forwarded unchanged to the `transpose` operation.
 */
export const transpose = generate_function('transpose');
/**
 * Computes the matrix product of the two input tensors.
 * A plain number operand is wrapped in a `Tensor`; whether the `matmul`
 * operation accepts such an operand is up to that operation.
 */
export const matmul = generate_binary_function('matmul');
// comparison
// Every function in this section accepts `Tensor | number` for both operands;
// a plain number is wrapped in a `Tensor` before the operation runs.
/**
 * Checks if the first tensor is less than the second tensor element-wise.
 * Either operand may be a plain number.
 */
export const lt = generate_binary_function('lt');
/**
 * Checks if the first tensor is greater than the second tensor element-wise.
 * Either operand may be a plain number.
 */
export const gt = generate_binary_function('gt');
/**
 * Checks if the first tensor is less than or equal to the second tensor
 * element-wise. Either operand may be a plain number.
 */
export const le = generate_binary_function('le');
/**
 * Checks if the first tensor is greater than or equal to the second tensor
 * element-wise. Either operand may be a plain number.
 */
export const ge = generate_binary_function('ge');
/**
 * Checks if the first tensor is equal to the second tensor element-wise.
 * Either operand may be a plain number.
 */
export const eq = generate_binary_function('eq');
/**
 * Checks if the first tensor is not equal to the second tensor element-wise.
 * Either operand may be a plain number.
 */
export const ne = generate_binary_function('ne');
/**
 * Tests whether `a` and `b` are equal element-wise within a tolerance.
 *
 * This is a thin wrapper over `Tensor.allclose`; the exact comparison rule is
 * defined by that method.
 *
 * @param a - First tensor.
 * @param b - Second tensor.
 * @param rtol - Relative tolerance, default `1e-5`.
 * @param atol - Absolute tolerance, default `1e-8`.
 * @param equal_nan - Forwarded as-is to `Tensor.allclose`, default `false`
 *   (presumably makes NaN compare equal to NaN when `true` — confirm there).
 * @returns The boolean produced by `Tensor.allclose`.
 */
export function allclose(
  a: Tensor,
  b: Tensor,
  rtol = 1e-5,
  atol = 1e-8,
  equal_nan = false
): boolean {
  const isClose = a.allclose(b, rtol, atol, equal_nan);
  return isClose;
}
/**
 * Returns how many elements `a` holds, as reported by `Tensor.dataLength()`.
 *
 * @param a - Tensor to inspect.
 * @returns The element count.
 */
export function numel(a: Tensor): number {
  const elementCount = a.dataLength();
  return elementCount;
}
/**
 * Flattens the dimensions `start_dim`..`end_dim` of `input` by delegating to
 * `Tensor.flatten`.
 *
 * @param input - Tensor to flatten.
 * @param start_dim - First dimension to flatten, default `0`.
 * @param end_dim - Last dimension to flatten, default `-1`.
 * @returns The tensor returned by `Tensor.flatten`.
 */
export function flatten(
  input: Tensor,
  start_dim = 0,
  end_dim = -1
): Tensor {
  const flattened = input.flatten(start_dim, end_dim);
  return flattened;
}
/**
 * Concatenates `tensors` along dimension `dim` using the registered `cat`
 * operation.
 *
 * @param tensors - Tensors to join, passed to the operation as one array.
 * @param dim - Dimension to concatenate along, default `0`.
 * @returns The result of the `cat` operation's `forward`.
 */
export function cat(tensors: Tensor[], dim = 0): Tensor {
  return createOperation('cat').forward(tensors, dim);
}
/**
 * Alias for {@link cat}.
 */
export const concatenate = cat;
/**
 * Alias for {@link cat}.
 */
export const concat = cat;
/**
 * Computes the softmax of `input` along dimension `dim` using the registered
 * `softmax` operation.
 *
 * @param input - Tensor to normalize.
 * @param dim - Dimension along which the softmax is taken (no default).
 * @returns The result of the `softmax` operation's `forward`.
 */
export function softmax(input: Tensor, dim: number): Tensor {
  return createOperation('softmax').forward(input, dim);
}
/**
 * Clamps all elements of `input` to the range [`min`, `max`] using the
 * registered `clamp` operation.
 *
 * @param input - Tensor to clamp.
 * @param min - Lower bound (a plain number).
 * @param max - Upper bound (a plain number).
 * @returns The result of the `clamp` operation's `forward`.
 */
export function clamp(input: Tensor, min: number, max: number): Tensor {
  // NOTE: inside this function `min`/`max` are the numeric bounds; they shadow
  // the module-level `min`/`max` reduction exports.
  const clampOp = createOperation('clamp');
  const clamped = clampOp.forward(input, min, max);
  return clamped;
}
/**
 * Alias for {@link clamp}.
 */
export const clip = clamp;
/**
 * Stacks `tensors` along a new dimension at position `dim`.
 *
 * Each input gets a size-1 axis inserted at `dim` via `unsqueeze`, and the
 * results are concatenated along that same axis with {@link cat}. Shape
 * compatibility between inputs is left to the `cat` operation to enforce.
 *
 * @param tensors - Tensors to stack.
 * @param dim - Position of the new dimension, default `0`.
 * @returns The stacked tensor.
 */
export function stack(tensors: Tensor[], dim = 0): Tensor {
  const withNewAxis = tensors.map((tensor) => tensor.unsqueeze(dim));
  return cat(withNewAxis, dim);
}