training

package v0.3.0

Published: Aug 25, 2025 License: Apache-2.0 Imports: 9 Imported by: 0

Documentation

Overview

Package training provides core components for training neural networks. It defines the V2 trainer API, which operates on a Batch and delegates gradient computation to a pluggable GradientStrategy; the package ships a default backpropagation strategy and a one-step gradient approximation strategy for recurrent models.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Batch added in v0.3.0

type Batch[T tensor.Numeric] struct {
	Inputs  map[graph.Node[T]]*tensor.TensorNumeric[T]
	Targets *tensor.TensorNumeric[T]
}

Batch groups the stable inputs for a single training step.

Inputs are provided as a map keyed by the graph's input nodes. Targets are provided as a single tensor; strategies may interpret targets appropriately for the chosen loss.
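
A minimal sketch of assembling a Batch for a single-input graph. The import paths, the input node, and the tensors are assumptions for illustration, not part of this package's documented API:

import (
	"example.com/yourmodule/graph"    // hypothetical import path
	"example.com/yourmodule/tensor"   // hypothetical import path
	"example.com/yourmodule/training" // hypothetical import path
)

// makeBatch pairs one input tensor with its graph input node and attaches
// the target tensor. inputNode, x, and y are assumed to be built elsewhere.
func makeBatch(
	inputNode graph.Node[float32],
	x, y *tensor.TensorNumeric[float32],
) training.Batch[float32] {
	return training.Batch[float32]{
		Inputs:  map[graph.Node[float32]]*tensor.TensorNumeric[float32]{inputNode: x},
		Targets: y,
	}
}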

type DefaultBackpropStrategy added in v0.3.0

type DefaultBackpropStrategy[T tensor.Numeric] struct{}

DefaultBackpropStrategy performs standard backpropagation through the loss and model graph.

func NewDefaultBackpropStrategy added in v0.3.0

func NewDefaultBackpropStrategy[T tensor.Numeric]() *DefaultBackpropStrategy[T]

NewDefaultBackpropStrategy constructs a DefaultBackpropStrategy.

func (*DefaultBackpropStrategy[T]) ComputeGradients added in v0.3.0

func (s *DefaultBackpropStrategy[T]) ComputeGradients(
	ctx context.Context,
	g *graph.Graph[T],
	loss graph.Node[T],
	batch Batch[T],
) (T, error)

ComputeGradients runs the forward pass, computes the loss, runs the backward pass, and leaves parameter gradients populated on the graph's parameters.

type DefaultTrainer added in v0.3.0

type DefaultTrainer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

DefaultTrainer encapsulates stable training components and delegates gradient computation to a strategy.

func NewDefaultTrainer added in v0.3.0

func NewDefaultTrainer[T tensor.Numeric](
	g *graph.Graph[T],
	loss graph.Node[T],
	optimizer opt.Optimizer[T],
	strategy GradientStrategy[T],
) *DefaultTrainer[T]

NewDefaultTrainer constructs a new DefaultTrainer. If strategy is nil, DefaultBackpropStrategy is used.

func (*DefaultTrainer[T]) TrainStep added in v0.3.0

func (t *DefaultTrainer[T]) TrainStep(
	ctx context.Context,
	g *graph.Graph[T],
	optimizer opt.Optimizer[T],
	inputs map[graph.Node[T]]*tensor.TensorNumeric[T],
	targets *tensor.TensorNumeric[T],
) (T, error)

TrainStep performs a single training step using the configured strategy and optimizer.
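
A sketch of a minimal training loop wiring the trainer together. The import paths and the surrounding values (g, lossNode, optimizer, inputNode, and the pre-batched data) are assumptions for illustration:

import (
	"context"

	"example.com/yourmodule/graph"              // hypothetical import path
	opt "example.com/yourmodule/training/optimizer" // hypothetical import path
	"example.com/yourmodule/tensor"             // hypothetical import path
	"example.com/yourmodule/training"           // hypothetical import path
)

// train runs one pass over pre-batched data. Passing a nil strategy makes
// NewDefaultTrainer fall back to DefaultBackpropStrategy.
func train(
	ctx context.Context,
	g *graph.Graph[float32],
	lossNode graph.Node[float32],
	optimizer opt.Optimizer[float32],
	inputNode graph.Node[float32],
	batches []struct{ X, Y *tensor.TensorNumeric[float32] },
) error {
	trainer := training.NewDefaultTrainer(g, lossNode, optimizer, nil)
	for _, b := range batches {
		inputs := map[graph.Node[float32]]*tensor.TensorNumeric[float32]{inputNode: b.X}
		loss, err := trainer.TrainStep(ctx, g, optimizer, inputs, b.Y)
		if err != nil {
			return err
		}
		_ = loss // e.g., log the per-step loss here
	}
	return nil
}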

type EraSequencer added in v0.3.0

type EraSequencer struct {
	// contains filtered or unexported fields
}

EraSequencer generates sequences of consecutive eras for curriculum learning.

func NewEraSequencer added in v0.3.0

func NewEraSequencer(maxSeqLen int) *EraSequencer

NewEraSequencer creates a new era sequencer with the given maximum sequence length.

func (*EraSequencer) GenerateSequences added in v0.3.0

func (s *EraSequencer) GenerateSequences(dataset *data.Dataset, numSequences int) []*data.Dataset

GenerateSequences generates random sequences of consecutive eras from the dataset. Each sequence contains between 1 and maxSeqLen consecutive eras.

func (*EraSequencer) GenerateTrainValidationSplit added in v0.3.0

func (s *EraSequencer) GenerateTrainValidationSplit(dataset *data.Dataset, validationEras int) (*data.Dataset, *data.Dataset)

GenerateTrainValidationSplit splits the dataset into training and validation sets. The validation set contains the last 'validationEras' eras chronologically.

func (*EraSequencer) SetSeed added in v0.3.0

func (s *EraSequencer) SetSeed(seed1, seed2 uint64)

SetSeed sets the random seed for reproducible sequence generation.
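
A sketch of how these methods compose for era-based curriculum training; the dataset value and import paths are assumptions for illustration:

import (
	"example.com/yourmodule/data"     // hypothetical import path
	"example.com/yourmodule/training" // hypothetical import path
)

// buildCurriculum holds out the last eras for validation and samples
// random era sequences from the remaining training split.
func buildCurriculum(ds *data.Dataset) ([]*data.Dataset, *data.Dataset) {
	seq := training.NewEraSequencer(8) // at most 8 consecutive eras per sequence
	seq.SetSeed(42, 1337)              // fixed seeds for reproducibility

	// Reserve the last 4 eras chronologically for validation.
	train, val := seq.GenerateTrainValidationSplit(ds, 4)

	// Draw 100 random sequences of consecutive eras from the training split.
	return seq.GenerateSequences(train, 100), val
}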

type GradientStrategy added in v0.3.0

type GradientStrategy[T tensor.Numeric] interface {
	ComputeGradients(
		ctx context.Context,
		g *graph.Graph[T],
		loss graph.Node[T],
		batch Batch[T],
	) (lossValue T, err error)
}

GradientStrategy encapsulates how to compute gradients for a training step.

Implementations may perform standard backprop through the loss, use approximations, or incorporate auxiliary losses (e.g., deep supervision). The strategy must leave parameter gradients populated on the graph's parameters so that the optimizer can apply updates afterwards.
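
As a sketch of this contract, a hypothetical custom strategy can wrap the default one, recording each loss value while delegating the gradient work; the type and its naming are illustrative only:

import (
	"context"

	"example.com/yourmodule/graph"    // hypothetical import path
	"example.com/yourmodule/tensor"   // hypothetical import path
	"example.com/yourmodule/training" // hypothetical import path
)

// loggingStrategy delegates to DefaultBackpropStrategy and keeps a history
// of loss values; the inner strategy leaves gradients populated on the
// graph's parameters, satisfying the interface contract.
type loggingStrategy[T tensor.Numeric] struct {
	inner  training.GradientStrategy[T]
	losses []T
}

func newLoggingStrategy[T tensor.Numeric]() *loggingStrategy[T] {
	return &loggingStrategy[T]{inner: training.NewDefaultBackpropStrategy[T]()}
}

func (s *loggingStrategy[T]) ComputeGradients(
	ctx context.Context,
	g *graph.Graph[T],
	loss graph.Node[T],
	batch training.Batch[T],
) (T, error) {
	v, err := s.inner.ComputeGradients(ctx, g, loss, batch)
	if err == nil {
		s.losses = append(s.losses, v)
	}
	return v, err
}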

type Model

type Model[T tensor.Numeric] interface {
	// Forward performs the forward pass of the model.
	Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
	// Backward performs the backward pass of the model.
	Backward(ctx context.Context, grad *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
	// Parameters returns the parameters of the model.
	Parameters() []*graph.Parameter[T]
}

Model defines the interface for a trainable model.
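
A skeletal sketch of a type satisfying Model. Only the interface shape comes from the documentation; the field, the naming, and the elided math are placeholders:

import (
	"context"

	"example.com/yourmodule/graph"    // hypothetical import path
	"example.com/yourmodule/tensor"   // hypothetical import path
	"example.com/yourmodule/training" // hypothetical import path
)

// linearModel is a hypothetical single-parameter model skeleton.
type linearModel[T tensor.Numeric] struct {
	weight *graph.Parameter[T]
}

// Compile-time check that linearModel implements Model.
var _ training.Model[float32] = (*linearModel[float32])(nil)

func (m *linearModel[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
	// Apply the weight to inputs[0] and return the result (math elided).
	panic("not implemented")
}

func (m *linearModel[T]) Backward(ctx context.Context, grad *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error) {
	// Accumulate the weight gradient and return input gradients (math elided).
	panic("not implemented")
}

func (m *linearModel[T]) Parameters() []*graph.Parameter[T] {
	return []*graph.Parameter[T]{m.weight}
}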

type OneStepApproximationStrategy added in v0.3.0

type OneStepApproximationStrategy[T tensor.Numeric] struct{}

OneStepApproximationStrategy performs a one-step gradient approximation. It is designed for training recurrent models without full backpropagation through time (BPTT).

func NewOneStepApproximationStrategy added in v0.3.0

func NewOneStepApproximationStrategy[T tensor.Numeric]() *OneStepApproximationStrategy[T]

NewOneStepApproximationStrategy constructs a OneStepApproximationStrategy.

func (*OneStepApproximationStrategy[T]) ComputeGradients added in v0.3.0

func (s *OneStepApproximationStrategy[T]) ComputeGradients(
	ctx context.Context,
	g *graph.Graph[T],
	loss graph.Node[T],
	batch Batch[T],
) (T, error)

ComputeGradients performs a forward pass and a one-step backward pass.
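
Swapping this strategy into a trainer is a one-argument change relative to the loop sketched earlier; g, lossNode, and optimizer are the same assumed values:

	// Use the one-step approximation instead of default backprop.
	strategy := training.NewOneStepApproximationStrategy[float32]()
	trainer := training.NewDefaultTrainer(g, lossNode, optimizer, strategy)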

type Trainer

type Trainer[T tensor.Numeric] interface {
	// TrainStep performs a single training step for a model.
	// It takes the model graph, the optimizer, and the input/target data.
	// It is responsible for computing the loss, gradients, and updating the parameters.
	TrainStep(
		ctx context.Context,
		modelGraph *graph.Graph[T],
		optimizer optimizer.Optimizer[T],
		inputs map[graph.Node[T]]*tensor.TensorNumeric[T],
		targets *tensor.TensorNumeric[T],
	) (loss T, err error)
}

Trainer is an interface for model-specific training orchestration.
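
Judging by the signatures above, DefaultTrainer satisfies this interface; an illustrative compile-time assertion (not part of the package) would be:

	var _ training.Trainer[float32] = (*training.DefaultTrainer[float32])(nil)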

Directories

Path       Synopsis
loss       Package loss provides various loss functions for neural networks.
optimizer  Package optimizer provides various optimization algorithms for neural networks.
