diff --git a/pkg/autograd/grad_example_test.go b/pkg/autograd/grad_example_test.go index b85ab1a..c9148f6 100644 --- a/pkg/autograd/grad_example_test.go +++ b/pkg/autograd/grad_example_test.go @@ -1,10 +1,11 @@ package autograd import ( + "math" "testing" - "github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph" "github.com/Hirogava/Go-NN-Learn/pkg/tensor" + "github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph" ) // TestExampleDotSumGradient строит примитивный граф y = sum(a * b) и проверяет его градиента. @@ -44,3 +45,76 @@ func TestExampleDotSumGradient(t *testing.T) { } } } + +// numerical gradient of sum(f(x)) via central differences +func numericalGrad(f func(x *graph.Node) *graph.Node, x *tensor.Tensor, eps float64) *tensor.Tensor { + grad := tensor.Zeros(x.Shape...) + for i := range x.Data { + orig := x.Data[i] + + x.Data[i] = orig + eps + y1 := SumTensor(f(&graph.Node{Value: x}).Value) + + x.Data[i] = orig - eps + y2 := SumTensor(f(&graph.Node{Value: x}).Value) + + grad.Data[i] = (y1 - y2) / (2 * eps) + x.Data[i] = orig + } + return grad +} + +func SumTensor(t *tensor.Tensor) float64 { + s := 0.0 + for _, v := range t.Data { + s += v + } + return s +} + +// gradient check: compares the analytic (backward) gradient against the numerical one +func gradCheck(t *testing.T, f func(x *graph.Node) *graph.Node, x *tensor.Tensor, eps, tol float64) { + node := &graph.Node{Value: x} + node.Grad = nil + + // forward + backward + eng := NewEngine() + y := f(node) + eng.Backward(y) + + // numerical gradient + numGrad := numericalGrad(f, x, eps) + + // comparison + for i := range x.Data { + a := node.Grad.Data[i] + n := numGrad.Data[i] + if math.Abs(a-n) > tol { + t.Fatalf("grad check failed at index %d: analytic=%v, numeric=%v", i, a, n) + } + } +} + +// test for Reshape +func TestReshapeGradCheck(t *testing.T) { + x := tensor.Randn([]int{2, 3}, 42) // fixed seed + f := func(xn *graph.Node) *graph.Node { + eng := NewEngine() + return eng.Reshape(xn, []int{3, 2}) + } + + gradCheck(t, f, x, 1e-6, 1e-4) +} + +// test for Transpose +func TestTransposeGradCheck(t *testing.T) { + x := 
tensor.Randn([]int{2, 3}, 42) + f := func(xn *graph.Node) *graph.Node { + + eng := NewEngine() + y := eng.Transpose(xn) + return eng.Sum(y) // scalar output + } + + gradCheck(t, f, x, 1e-6, 1e-4) +} diff --git a/pkg/autograd/ops.go b/pkg/autograd/ops.go index d6d4d50..5f460ef 100644 --- a/pkg/autograd/ops.go +++ b/pkg/autograd/ops.go @@ -2,8 +2,8 @@ package autograd import ( "github.com/Hirogava/Go-NN-Learn/pkg/matrix" - "github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph" "github.com/Hirogava/Go-NN-Learn/pkg/tensor" + "github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph" ) // Add @@ -72,8 +72,10 @@ type MatMul struct { B *tensor.Tensor } -type Transpose struct { +type TransposeOp struct { Parents []*graph.Node + Perm []int + InvPerm []int } func (op *MatMul) Backward(grad *tensor.Tensor) { @@ -104,16 +106,24 @@ func (op *MatMul) Backward(grad *tensor.Tensor) { op.Parents[1].Grad = gB } -func (op *Transpose) Backward(grad *tensor.Tensor) { - if op.Parents[0].Grad == nil { - op.Parents[0].Grad = tensor.Zeros(op.Parents[0].Value.Shape...) 
+func (op *TransposeOp) Backward(grad *tensor.Tensor) { + + p := op.Parents[0] + + gradIn, err := tensor.Transpose(grad) + if err != nil { + panic(err) } - gradM := matrix.TensorToMatrix(grad) - gradTransposed, _ := matrix.Transposition(gradM) - gradTransposedT := matrix.MatrixToTensor(gradTransposed) - g, _ := tensor.Add(op.Parents[0].Grad, gradTransposedT) - op.Parents[0].Grad = g + if p.Grad == nil { + p.Grad = gradIn + } else { + g, err := tensor.Add(p.Grad, gradIn) + if err != nil { + panic(err) + } + p.Grad = g + } } func (e *Engine) MatMul(a, b *graph.Node) *graph.Node { @@ -137,7 +147,7 @@ func (e *Engine) Transpose(a *graph.Node) *graph.Node { return nil } val := matrix.MatrixToTensor(valM) - op := &Transpose{Parents: []*graph.Node{a}} + op := &TransposeOp{Parents: []*graph.Node{a}} n := graph.NewNode(val, []*graph.Node{a}, op) e.Nodes = append(e.Nodes, n) return n @@ -248,17 +258,18 @@ type ReshapeOp struct { func (op *ReshapeOp) Backward(grad *tensor.Tensor) { p := op.Parents[0] - if p.Grad == nil { - p.Grad = tensor.Zeros(op.InShape...) - } + gradIn, err := tensor.Reshape(grad, op.InShape) if err != nil { - return + panic(err) } - g, _ := tensor.Add(p.Grad, gradIn) - p.Grad = g -} + if p.Grad == nil { + p.Grad = gradIn + } else { + p.Grad, _ = tensor.Add(p.Grad, gradIn) + } +} func (e *Engine) Reshape(a *graph.Node, newShape []int) *graph.Node { val, err := tensor.Reshape(a.Value, newShape) if err != nil {