ps

package
v0.0.1 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Sep 23, 2022 License: AGPL-3.0 Imports: 8 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Adam

type Adam struct {
	// contains filtered or unexported fields
}

Adam is a Solver that implements the Adam update rule.

func NewAdam

func NewAdam(lr, beta, beta2, epsilon float64) *Adam

NewAdam returns a new Adam solver.

func (*Adam) Init

func (o *Adam) Init(size int)

Init initializes the solver's internal vectors using size, the number of weights in the network.

func (*Adam) Update

func (o *Adam) Update(value, gradient float64, t, idx int) float64

Update returns the update for a given weight.

type BatchTrainer

type BatchTrainer struct {
	// contains filtered or unexported fields
}

BatchTrainer implements parallelized batch training

func NewBatchTrainer

func NewBatchTrainer(solver Solver, verbosity, batchSize, parallelism int) *BatchTrainer

NewBatchTrainer returns a BatchTrainer

func (*BatchTrainer) Train

func (t *BatchTrainer) Train(n *nn.Neural, examples, validation Samples, iterations int, shuffle bool)

Train trains the network n.

type OnlineTrainer

type OnlineTrainer struct {
	// contains filtered or unexported fields
}

OnlineTrainer is a basic, online network trainer

func NewTrainer

func NewTrainer(solver Solver, verbosity int) *OnlineTrainer

NewTrainer creates a new OnlineTrainer.

func (*OnlineTrainer) Predict

func (t *OnlineTrainer) Predict(n *nn.Neural, input []float64) []float64

func (*OnlineTrainer) Train

func (t *OnlineTrainer) Train(n *nn.Neural, examples, validation Samples, iterations int, shuffle bool)

Train trains the network n.

type SGD

type SGD struct {
	// contains filtered or unexported fields
}

SGD is a stochastic gradient descent solver with momentum and optional Nesterov momentum.

func NewSGD

func NewSGD(lr, momentum, decay float64, nesterov bool) *SGD

NewSGD returns a new SGD solver

func (*SGD) Init

func (o *SGD) Init(size int)

Init initializes the solver's internal vectors using size, the number of weights in the network.

func (*SGD) Update

func (o *SGD) Update(value, gradient float64, iteration, idx int) float64

Update returns the update for a given weight

type Sample

type Sample struct {
	Input    []float64
	Response []float64
}

Sample is an input-target pair

type Samples

type Samples []Sample

Samples is a set of input-output pairs

func (Samples) Shuffle

func (e Samples) Shuffle()

Shuffle shuffles the slice in place.

func (Samples) Split

func (e Samples) Split(p float64) (first, second Samples)

Split assigns each element to one of two new slices according to probability p.

func (Samples) SplitN

func (e Samples) SplitN(n int) []Samples

SplitN splits the slice into n parts.

func (Samples) SplitSize

func (e Samples) SplitSize(size int) []Samples

SplitSize splits the slice into parts of the given size.

type Solver

type Solver interface {
	Init(size int)
	Update(value, gradient float64, iteration, idx int) float64
}

Solver implements an update rule for training a neural network.

type StatsPrinter

type StatsPrinter struct {
	// contains filtered or unexported fields
}

StatsPrinter prints training progress

func NewStatsPrinter

func NewStatsPrinter() *StatsPrinter

NewStatsPrinter creates a StatsPrinter

func (*StatsPrinter) Init

func (p *StatsPrinter) Init(n *nn.Neural)

Init initializes the printer for the network n.

func (*StatsPrinter) PrintProgress

func (p *StatsPrinter) PrintProgress(n *nn.Neural, validation Samples, elapsed time.Duration, iteration int)

PrintProgress prints the current state of training

type Trainer

type Trainer interface {
	Train(n *nn.Neural, examples, validation Samples, iterations int, shuffle bool)
}

Trainer is a neural network trainer

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL