Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
# No easy way to test these, so only kept as "manual inspection" tests locally
evaluation_test.go
search_test.go

# various .nnue models (alternative models or models for testing)
models/

bin/
todo.md
cmp.py

# Perft output comparison files
perft1
perft2
cpu.prof

# Profiling
*.prof

# SPRT + opening book
*.pgn

# some opening books?
*.epd

# fastchess
config.json

__pycache__/
.DS_Store

# ---------------------

# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
Expand Down
10 changes: 9 additions & 1 deletion engine/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (pos *Position) EndgameMaterial(color uint8) int32 {
return ans
}

func Evaluate(pos *Position) int32 {
func EvaluateHCE(pos *Position) int32 {
us := pos.Turn
them := pos.Turn ^ 1

Expand Down Expand Up @@ -201,3 +201,11 @@ func Evaluate(pos *Position) int32 {

return eval
}

func EvaluateNNUE(pos *Position) int32 {
return int32(pos.Nnue.Evaluate(pos.Turn) * 1000)
}

func Evaluate(pos *Position) int32 {
return EvaluateNNUE(pos)
}
242 changes: 242 additions & 0 deletions engine/nnue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package engine

import (
"encoding/binary"
"errors"
"fmt"
"io"
"os"
)

const (
Magic = "NNUE"
Version = 1
)

type NNUE struct {
NumInputs int
L1 int

WInput []float32 // [L1 * NumInputs]
BInput []float32 // [L1]
WOutput []float32 // [2 * L1]
BOutput float32

FeatureCols [][]float32

Acc Accumulator
}

type Accumulator struct {
Values [2][]float32
}

// TODO: Accumulator stack

// for black's perspective, the board is flipped for evaluation purposes
// this way the first layer parameters only has to "learn" how to play one perspective, which helps with generalization (?)
func FeatureIndex(perspective uint8, pieceColor uint8, pieceType uint8, sq Square) uint16 {
friendly := perspective == pieceColor
pieceIdx := pieceType
if !friendly {
pieceIdx += 6
}
if perspective == Black {
sq ^= FlipVertical
}
return 64*uint16(pieceIdx) + uint16(sq)
}

func LoadNNUE(path string) (*NNUE, error) {
// resolve path (either user-provided or embedded default written to a temp file)
resolvedPath, cleanup, err := resolveNNUEPath(path)
if err != nil {
return nil, err
}
// remove temp file (if any) when done
if cleanup != nil {
defer cleanup()
}

f, err := os.Open(resolvedPath)
if err != nil {
return nil, err
}
defer f.Close()

// header
magic := make([]byte, 4)
if _, err := io.ReadFull(f, magic); err != nil {
return nil, err
}
if string(magic) != Magic {
return nil, errors.New("invalid NNUE magic")
}

var version uint32
if err := binary.Read(f, binary.LittleEndian, &version); err != nil {
return nil, err
}
if version != Version {
return nil, fmt.Errorf("unsupported NNUE version %d", version)
}

// read as uint32 since python wrote 32 bit ints
var numInputs32, l132 uint32
if err := binary.Read(f, binary.LittleEndian, &numInputs32); err != nil {
return nil, err
}
if err := binary.Read(f, binary.LittleEndian, &l132); err != nil {
return nil, err
}

numInputs := int(numInputs32)
l1 := int(l132)

nnue := &NNUE{
NumInputs: numInputs,
L1: l1,
}

// read network parameters
nnue.WInput = make([]float32, l1*numInputs)
nnue.BInput = make([]float32, l1)
nnue.WOutput = make([]float32, 2*l1)

if err := binary.Read(f, binary.LittleEndian, &nnue.WInput); err != nil {
return nil, err
}
if err := binary.Read(f, binary.LittleEndian, &nnue.BInput); err != nil {
return nil, err
}
if err := binary.Read(f, binary.LittleEndian, &nnue.WOutput); err != nil {
return nil, err
}
if err := binary.Read(f, binary.LittleEndian, &nnue.BOutput); err != nil {
return nil, err
}

// build helper representation of first layer's weights
nnue.BuildFeatureCols()

// allocate accumulators for both sides
nnue.Acc.Values[0] = make([]float32, l1)
nnue.Acc.Values[1] = make([]float32, l1)

return nnue, nil
}

// must be called first before using anything
func (nnue *NNUE) BuildFeatureCols() {
numInputs := int(nnue.NumInputs)
l1 := int(nnue.L1)
cols := make([][]float32, numInputs)
for f := 0; f < numInputs; f++ {
col := make([]float32, l1)
for o := 0; o < l1; o++ {
col[o] = nnue.WInput[o*numInputs+f]
}
cols[f] = col
}
nnue.FeatureCols = cols
}

// add a whole set of initial features (overwriting existing features).
// this is the only place where input bias is added
func (nnue *NNUE) Refresh(features []uint16, perspective uint8) {
acc := nnue.Acc.Values[perspective]
copy(acc, nnue.BInput)
for _, f := range features {
col := nnue.FeatureCols[f]
for o := 0; o < len(acc); o += 16 {
acc[o] += col[o]
acc[o+1] += col[o+1]
acc[o+2] += col[o+2]
acc[o+3] += col[o+3]
acc[o+4] += col[o+4]
acc[o+5] += col[o+5]
acc[o+6] += col[o+6]
acc[o+7] += col[o+7]
acc[o+8] += col[o+8]
acc[o+9] += col[o+9]
acc[o+10] += col[o+10]
acc[o+11] += col[o+11]
acc[o+12] += col[o+12]
acc[o+13] += col[o+13]
acc[o+14] += col[o+14]
acc[o+15] += col[o+15]
}
}
}

// refresh both perspectives
func (nnue *NNUE) RefreshAll(featuresW, featuresB []uint16) {
nnue.Refresh(featuresW, White)
nnue.Refresh(featuresB, Black)
}

// incrementally add a feature
// NOTE: only using Add() is incorrect, since no bias is added
// remember to also to perform an empty refresh? (ex: in FromFEN)
func (nnue *NNUE) Add(feature uint16, perspective uint8) {
col := nnue.FeatureCols[feature]
acc := nnue.Acc.Values[perspective]
for o := 0; o < len(acc); o += 16 {
acc[o] += col[o]
acc[o+1] += col[o+1]
acc[o+2] += col[o+2]
acc[o+3] += col[o+3]
acc[o+4] += col[o+4]
acc[o+5] += col[o+5]
acc[o+6] += col[o+6]
acc[o+7] += col[o+7]
acc[o+8] += col[o+8]
acc[o+9] += col[o+9]
acc[o+10] += col[o+10]
acc[o+11] += col[o+11]
acc[o+12] += col[o+12]
acc[o+13] += col[o+13]
acc[o+14] += col[o+14]
acc[o+15] += col[o+15]
}
}

// incrementally remove a feature
func (nnue *NNUE) Remove(feature uint16, perspective uint8) {
col := nnue.FeatureCols[feature]
acc := nnue.Acc.Values[perspective]
for o := 0; o < len(acc); o += 16 {
acc[o] -= col[o]
acc[o+1] -= col[o+1]
acc[o+2] -= col[o+2]
acc[o+3] -= col[o+3]
acc[o+4] -= col[o+4]
acc[o+5] -= col[o+5]
acc[o+6] -= col[o+6]
acc[o+7] -= col[o+7]
acc[o+8] -= col[o+8]
acc[o+9] -= col[o+9]
acc[o+10] -= col[o+10]
acc[o+11] -= col[o+11]
acc[o+12] -= col[o+12]
acc[o+13] -= col[o+13]
acc[o+14] -= col[o+14]
acc[o+15] -= col[o+15]
}
}

func (nnue *NNUE) Evaluate(side uint8) float32 {
ourAcc := nnue.Acc.Values[side]
theirAcc := nnue.Acc.Values[1-side]
var result float32 = 0.0

// ReLU before output layer
for i := 0; i < nnue.L1; i++ {
result += nnue.WOutput[i] * max(ourAcc[i], 0.0)
}
for j := 0; j < nnue.L1; j++ {
result += nnue.WOutput[nnue.L1+j] * max(theirAcc[j], 0.0)
}
result += nnue.BOutput
return result
}
Binary file added engine/nnue/256.nnue
Binary file not shown.
71 changes: 71 additions & 0 deletions engine/nnue_loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package engine

import (
"embed"
"fmt"
"os"
)

//go:embed nnue/*.nnue
var nnueFS embed.FS

// resolveNNUEPath returns a filesystem path to open for the NNUE data
//
// If `path` is empty: it writes the embedded default NNUE to a temp file and
// returns that temp path and a cleanup func that should be deferred by caller
//
// If `path` is non-empty it is returned as-is (no cleanup function)
//
// error - failure if neither the provided filepath nor the embedded NNUE is available
func resolveNNUEPath(path string) (string, func(), error) {
// caller provides path
if path != "" {
// caller is responsible for not deleting this file
return path, func() {}, nil
}
// embedded default
return writeEmbeddedToTemp("nnue/256.nnue")
}

// write the embedded file `embeddedName` to a temporary
// file and returns the temp path and a cleanup func that removes it
func writeEmbeddedToTemp(embeddedName string) (string, func(), error) {
data, err := nnueFS.ReadFile(embeddedName)
if err != nil {
return "", nil, fmt.Errorf("embedded default nnue not found (%s): %w", embeddedName, err)
}

// tmp file at /tmp/nnue-xxxxxx.nnue
tmp, err := os.CreateTemp("", "nnue-*.nnue")
if err != nil {
return "", nil, err
}
// close tmp file at the end
tmpName := tmp.Name()
defer func() {
_ = tmp.Close()
}()

// write NNUE data to tmp file
if _, err := tmp.Write(data); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpName)
return "", nil, err
}
// force data to disk
if err := tmp.Sync(); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpName)
return "", nil, err
}
// close tmp file
if err := tmp.Close(); err != nil {
_ = os.Remove(tmpName)
return "", nil, err
}
// cleanup function -- deleting tmp file
cleanup := func() {
_ = os.Remove(tmpName)
}
return tmpName, cleanup, nil
}
Loading
Loading