tnn
go版本神经网络框架,支持模型训练和预估
工具
- minfo: 这是tnn框架中的一个工具,用于查看保存模型的定义信息
示例
- mlp: 该示例是一个四层的神经网络,用于训练异或运算(回归)
- cnn: 该示例是著名的手写数字识别示例,通过CNN网络进行数字识别(分类)
- sin: 该示例使用rnn网络来进行sin函数的时序任务训练(回归)
构造网络
首先定义网络的每个层
initializer := initializer.NewXavierUniform(1)
var net net.Net
net.Set(
layer.NewDense(16, initializer),
activation.NewSigmoid(),
layer.NewDense(8, initializer),
activation.NewSigmoid(),
layer.NewDense(4, initializer),
activation.NewSigmoid(),
layer.NewDense(2, initializer),
activation.NewSigmoid(),
layer.NewDense(1, initializer),
)
选定一个loss函数和优化器
loss := loss.NewMSE()
optimizer := optimizer.NewAdam(lr, 0, 0.9, 0.999, 1e-8)
最后构造出模型并进行模型训练
m := model.New(&net, loss, optimizer)
for i := 0; i < 10; i++ {
m.Train(input, output)
loss := m.Loss(input, output)
fmt.Printf("Epoch: %d, Loss: %.05f\n", i, loss)
}
模型预测方法如下
pred := model.Predict(input)
感谢