tnn

module
v0.0.1 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: May 9, 2023 License: MIT

README

tnn

go版本神经网络框架,支持模型训练和预估

工具

  • minfo: 这是tnn框架中的一个工具,用于查看保存模型的定义信息

示例

  • mlp: 该示例是一个四层的神经网络,用于训练异或运算(回归)
  • cnn: 该示例是著名的手写数字识别示例,通过CNN网络进行数字识别(分类)
  • sin: 该示例使用rnn网络来进行sin函数的时序任务训练(回归)

构造网络

首先定义网络的每个层

initializer := initializer.NewXavierUniform(1)
var net net.Net
net.Set(
    layer.NewDense(16, initializer),
    activation.NewSigmoid(),
    layer.NewDense(8, initializer),
    activation.NewSigmoid(),
    layer.NewDense(4, initializer),
    activation.NewSigmoid(),
    layer.NewDense(2, initializer),
    activation.NewSigmoid(),
    layer.NewDense(1, initializer),
)

选定一个loss函数和优化器

loss := loss.NewMSE()
optimizer := optimizer.NewAdam(lr, 0, 0.9, 0.999, 1e-8)

最后构造出模型并进行模型训练

m := model.New(&net, loss, optimizer)
for i := 0; i < 10; i++ {
    m.Train(input, output)
    loss := m.Loss(input, output)
    fmt.Printf("Epoch: %d, Loss: %.05f\n", i, loss)
}

模型预测方法如下

pred := model.Predict(input)

感谢

Directories

Path Synopsis
cmd
minfo command
example
mnist command
sin command
xor command
internal
pb
nn
net

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL