Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion pkg/autograd/grad_example_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package autograd

import (
"math"
"testing"

"github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph"
"github.com/Hirogava/Go-NN-Learn/pkg/tensor"
"github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph"
)

// TestExampleDotSumGradient строит примитивный граф y = sum(a * b) и проверяет его градиента.
Expand Down Expand Up @@ -44,3 +45,76 @@ func TestExampleDotSumGradient(t *testing.T) {
}
}
}

// numericalGrad estimates d(sum(f(x)))/dx element-wise using central
// differences with step eps. x is perturbed in place but restored to
// its original values before returning.
func numericalGrad(f func(x *graph.Node) *graph.Node, x *tensor.Tensor, eps float64) *tensor.Tensor {
	// evalSum runs a fresh forward pass and reduces the output to a scalar.
	evalSum := func() float64 {
		return SumTensor(f(&graph.Node{Value: x}).Value)
	}

	result := tensor.Zeros(x.Shape...)
	for i := range x.Data {
		saved := x.Data[i]

		x.Data[i] = saved + eps
		plus := evalSum()

		x.Data[i] = saved - eps
		minus := evalSum()

		result.Data[i] = (plus - minus) / (2 * eps)
		x.Data[i] = saved
	}
	return result
}

// SumTensor returns the sum of every element stored in t.Data.
func SumTensor(t *tensor.Tensor) float64 {
	var total float64
	for i := range t.Data {
		total += t.Data[i]
	}
	return total
}

// gradCheck verifies the analytic gradient produced by the autograd
// engine against a central-difference numerical gradient, failing the
// test if any element differs by more than tol.
//
// f must be a pure function of its input node so that repeated forward
// passes (used by numericalGrad) are consistent with the recorded graph.
func gradCheck(t *testing.T, f func(x *graph.Node) *graph.Node, x *tensor.Tensor, eps, tol float64) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	// A freshly constructed node already has a nil Grad (zero value),
	// so no explicit reset is needed.
	node := &graph.Node{Value: x}

	// Forward pass through f, then backward to populate node.Grad.
	eng := NewEngine()
	y := f(node)
	eng.Backward(y)

	// Numerical reference gradient (numericalGrad restores x in place).
	numGrad := numericalGrad(f, x, eps)

	// Element-wise comparison of analytic vs. numeric gradients.
	for i := range x.Data {
		a := node.Grad.Data[i]
		n := numGrad.Data[i]
		if math.Abs(a-n) > tol {
			t.Fatalf("grad check failed at index %d: analytic=%v, numeric=%v", i, a, n)
		}
	}
}

// TestReshapeGradCheck runs a gradient check on the Reshape op.
func TestReshapeGradCheck(t *testing.T) {
	// Fixed seed keeps the random input deterministic across runs.
	input := tensor.Randn([]int{2, 3}, 42)

	reshape := func(xn *graph.Node) *graph.Node {
		return NewEngine().Reshape(xn, []int{3, 2})
	}

	gradCheck(t, reshape, input, 1e-6, 1e-4)
}

// TestTransposeGradCheck runs a gradient check on the Transpose op.
func TestTransposeGradCheck(t *testing.T) {
	// Fixed seed keeps the random input deterministic across runs.
	input := tensor.Randn([]int{2, 3}, 42)

	transposeSum := func(xn *graph.Node) *graph.Node {
		eng := NewEngine()
		y := eng.Transpose(xn)
		return eng.Sum(y) // reduce to a scalar output for backward
	}

	gradCheck(t, transposeSum, input, 1e-6, 1e-4)
}
47 changes: 29 additions & 18 deletions pkg/autograd/ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package autograd

import (
"github.com/Hirogava/Go-NN-Learn/pkg/matrix"
"github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph"
"github.com/Hirogava/Go-NN-Learn/pkg/tensor"
"github.com/Hirogava/Go-NN-Learn/pkg/tensor/graph"
)

// Add
Expand Down Expand Up @@ -72,8 +72,10 @@ type MatMul struct {
B *tensor.Tensor
}

type Transpose struct {
// TransposeOp is the autograd operation recorded for a transpose node;
// its Backward transposes the incoming gradient back into the parent's
// layout and accumulates it into Parents[0].Grad.
// NOTE(review): Perm and InvPerm are never assigned in the visible code
// — confirm they are set (or needed) for N-dimensional transposes.
type TransposeOp struct {
Parents []*graph.Node
Perm []int
InvPerm []int
}

func (op *MatMul) Backward(grad *tensor.Tensor) {
Expand Down Expand Up @@ -104,16 +106,24 @@ func (op *MatMul) Backward(grad *tensor.Tensor) {
op.Parents[1].Grad = gB
}

func (op *Transpose) Backward(grad *tensor.Tensor) {
if op.Parents[0].Grad == nil {
op.Parents[0].Grad = tensor.Zeros(op.Parents[0].Value.Shape...)
func (op *TransposeOp) Backward(grad *tensor.Tensor) {

p := op.Parents[0]

gradIn, err := tensor.Transpose(grad)
if err != nil {
panic(err)
}

gradM := matrix.TensorToMatrix(grad)
gradTransposed, _ := matrix.Transposition(gradM)
gradTransposedT := matrix.MatrixToTensor(gradTransposed)
g, _ := tensor.Add(op.Parents[0].Grad, gradTransposedT)
op.Parents[0].Grad = g
if p.Grad == nil {
p.Grad = gradIn
} else {
g, err := tensor.Add(p.Grad, gradIn)
if err != nil {
panic(err)
}
p.Grad = g
}
}

func (e *Engine) MatMul(a, b *graph.Node) *graph.Node {
Expand All @@ -137,7 +147,7 @@ func (e *Engine) Transpose(a *graph.Node) *graph.Node {
return nil
}
val := matrix.MatrixToTensor(valM)
op := &Transpose{Parents: []*graph.Node{a}}
op := &TransposeOp{Parents: []*graph.Node{a}}
n := graph.NewNode(val, []*graph.Node{a}, op)
e.Nodes = append(e.Nodes, n)
return n
Expand Down Expand Up @@ -248,17 +258,18 @@ type ReshapeOp struct {

func (op *ReshapeOp) Backward(grad *tensor.Tensor) {
p := op.Parents[0]
if p.Grad == nil {
p.Grad = tensor.Zeros(op.InShape...)
}

gradIn, err := tensor.Reshape(grad, op.InShape)
if err != nil {
return
panic(err)
}
g, _ := tensor.Add(p.Grad, gradIn)
p.Grad = g
}

if p.Grad == nil {
p.Grad = gradIn
} else {
p.Grad, _ = tensor.Add(p.Grad, gradIn)
}
}
func (e *Engine) Reshape(a *graph.Node, newShape []int) *graph.Node {
val, err := tensor.Reshape(a.Value, newShape)
if err != nil {
Expand Down