simlpe stands for SImple Multi Layer PErceptron.
It is a very simple implementation, freely inspired by this post.
Check the provided examples: the XOR function and the MNIST database of handwritten digits.
A network is instantiated using NewNetwork, which takes the following parameters:
- learning rate
- number of input neurons
Then, add layers using AddLayer.
A new layer is created using:
- a builder (like a LinearBuilder to initialize a linear layer)
- the number of neurons of this new layer
- the activation function (Sigmoid, Htan, ReLU...)
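To make these three ingredients concrete, here is a minimal standalone sketch of what a fully connected layer computes: one weighted sum per neuron, followed by an activation function. It does not use simlpe; the names layer, forward, sigmoid and relu are hypothetical, and the library's internals may well differ.

```go
package main

import (
	"fmt"
	"math"
)

// sigmoid squashes any real value into the open interval (0, 1).
func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// relu keeps positive values and clamps negative ones to 0.
func relu(x float64) float64 { return math.Max(0, x) }

// layer is a hypothetical fully connected layer, not simlpe's own type.
type layer struct {
	weights    [][]float64 // one row of input weights per neuron
	biases     []float64   // one bias per neuron
	activation func(float64) float64
}

// forward computes activation(w·x + b) for every neuron of the layer.
func (l layer) forward(x []float64) []float64 {
	out := make([]float64, len(l.weights))
	for i, w := range l.weights {
		sum := l.biases[i]
		for j, wj := range w {
			sum += wj * x[j]
		}
		out[i] = l.activation(sum)
	}
	return out
}

func main() {
	// 2 inputs -> 3 neurons, the same shape as the hidden layer of the XOR example.
	hidden := layer{
		weights:    [][]float64{{0.5, -0.5}, {1, 1}, {-1, 0.25}},
		biases:     []float64{0, -1, 0.5},
		activation: sigmoid,
	}
	fmt.Println(hidden.forward([]float64{1, 0}))

	// Swapping the activation only changes the last step of each neuron.
	hidden.activation = relu
	fmt.Println(hidden.forward([]float64{1, 0}))
}
```

The weights here are arbitrary constants chosen for illustration; in a real network the builder initializes them and training adjusts them.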
// Build a basic network
net := mlp.NewNetwork(0.3, 2) // input layer (2 neurons)
net.AddLayer(mlp.LinearBuilder{}, 3, mlp.Sigmoid{}) // hidden layer (3 neurons)
net.AddLayer(mlp.LinearBuilder{}, 1, mlp.Sigmoid{}) // output layer (1 neuron)
In order to train the xor function, set the following data:
- Input data shall be a matrix of size k (number of samples) x n (number of input neurons).
- Output data shall be a matrix of size k x m (number of output neurons).
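A quick shape check before training can catch mismatched data early. The sketch below assumes both matrices have k rows, with n values per input row and m values per output row. checkShapes is a hypothetical helper written for illustration, not a function provided by simlpe.

```go
package main

import (
	"errors"
	"fmt"
)

// checkShapes is a hypothetical helper (not part of simlpe). It verifies that
// x has n values per row, y has m values per row, and both have the same
// number of rows.
func checkShapes(x, y [][]float64, n, m int) error {
	if len(x) == 0 {
		return errors.New("no training samples")
	}
	if len(x) != len(y) {
		return fmt.Errorf("got %d input rows but %d output rows", len(x), len(y))
	}
	for i := range x {
		if len(x[i]) != n {
			return fmt.Errorf("row %d: want %d inputs, got %d", i, n, len(x[i]))
		}
		if len(y[i]) != m {
			return fmt.Errorf("row %d: want %d outputs, got %d", i, m, len(y[i]))
		}
	}
	return nil
}

func main() {
	// XOR: 4 rows, 2 input neurons, 1 output neuron.
	xData := [][]float64{{0, 0}, {1, 0}, {0, 1}, {1, 1}}
	yData := [][]float64{{0}, {1}, {1}, {0}}
	if err := checkShapes(xData, yData, 2, 1); err != nil {
		fmt.Println("invalid data:", err)
		return
	}
	fmt.Println("shapes are consistent")
}
```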
xData := [][]float64{
{0, 0},
{1, 0},
{0, 1},
{1, 1},
}
yData := [][]float64{{0}, {1}, {1}, {0}}
Optionally set conditions to stop the training. By default, only one epoch is processed. If one or more criteria are set, the first one to be met stops the training.
net.Stop.OnEpoch(10000) // max epoch = 10 000
net.Stop.OnDuration(time.Second) // max duration = 1 second
net.Stop.OnMeanSquaredError(0.0001) // min mean squared error is 1e-4
Launch training using the Train function.
It returns the ending condition reached or an error if any.
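The stopping rule (one epoch by default, otherwise the first criterion met ends the training) can be sketched independently of the library. Everything below is an assumption made for illustration: stopConfig, reason and meanSquaredError are hypothetical names, and the loop body is a placeholder rather than real training.

```go
package main

import (
	"fmt"
	"time"
)

// stopConfig is a hypothetical stand-in for stop settings; a zero field means "not set".
type stopConfig struct {
	maxEpoch    int
	maxDuration time.Duration
	minMSE      float64
}

// reason returns the name of the first criterion that is met, or "" if
// training should continue. With no criterion set, it stops after one epoch.
func (s stopConfig) reason(epoch int, elapsed time.Duration, mse float64) string {
	switch {
	case s == (stopConfig{}) && epoch >= 1:
		return "epoch"
	case s.maxEpoch > 0 && epoch >= s.maxEpoch:
		return "epoch"
	case s.maxDuration > 0 && elapsed >= s.maxDuration:
		return "duration"
	case s.minMSE > 0 && mse <= s.minMSE:
		return "mean squared error"
	}
	return ""
}

// meanSquaredError averages the squared differences between predictions and targets.
func meanSquaredError(pred, target []float64) float64 {
	var sum float64
	for i := range pred {
		d := pred[i] - target[i]
		sum += d * d
	}
	return sum / float64(len(pred))
}

func main() {
	stop := stopConfig{maxEpoch: 10000, maxDuration: time.Second, minMSE: 1e-4}
	start := time.Now()
	mse := 1.0
	for epoch := 1; ; epoch++ {
		mse *= 0.9 // placeholder for one real training epoch
		if why := stop.reason(epoch, time.Since(start), mse); why != "" {
			fmt.Printf("stopped on %s after %d epochs\n", why, epoch)
			break
		}
	}
}
```

When several criteria are met during the same epoch, this sketch reports them in the order of the switch cases; how the library breaks such ties is not stated here.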
term, err := net.Train(ctx, xData, yData)
Use the Predict function with one value per input neuron to produce one value per output neuron.
net.Predict([]float64{0, 0}) // should be ~0
net.Predict([]float64{1, 0}) // should be ~1
net.Predict([]float64{0, 1}) // should be ~1
net.Predict([]float64{1, 1}) // should be ~0
Use the JSON marshaler to read or write a network.
// Export the network
js, err := net.MarshalJSON()
// if err != nil ...
// Import the same network
net2 := mlp.Network{}
err2 := net2.UnmarshalJSON(js)
// if err2 != nil ...
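Because the network exposes MarshalJSON and UnmarshalJSON, it presumably satisfies the standard json.Marshaler and json.Unmarshaler interfaces. The sketch below shows how that pattern works on a hypothetical model type with unexported fields. It says nothing about the JSON layout simlpe actually produces.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// model is a hypothetical type with unexported state, standing in for a network.
type model struct {
	rate    float64
	weights [][]float64
}

// modelJSON is an exported mirror of model, used only for serialization.
type modelJSON struct {
	Rate    float64     `json:"rate"`
	Weights [][]float64 `json:"weights"`
}

// MarshalJSON implements json.Marshaler.
func (m model) MarshalJSON() ([]byte, error) {
	return json.Marshal(modelJSON{Rate: m.rate, Weights: m.weights})
}

// UnmarshalJSON implements json.Unmarshaler; the pointer receiver lets it modify m.
func (m *model) UnmarshalJSON(data []byte) error {
	var aux modelJSON
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	m.rate, m.weights = aux.Rate, aux.Weights
	return nil
}

func main() {
	src := model{rate: 0.3, weights: [][]float64{{0.1, -0.2}, {0.4, 0.5}}}

	js, err := json.Marshal(src) // calls src.MarshalJSON
	if err != nil {
		panic(err)
	}

	var dst model
	if err := json.Unmarshal(js, &dst); err != nil { // calls dst.UnmarshalJSON
		panic(err)
	}
	fmt.Println(string(js), dst.rate == src.rate)
}
```

Going through json.Marshal and json.Unmarshal rather than calling the methods directly also lets such a type be embedded in a larger JSON document.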