optim

package
v0.7.7
Published: Jan 6, 2026 License: Apache-2.0 Imports: 4 Imported by: 0

Documentation

Overview

Package optim implements optimization algorithms for training neural networks.

This package provides:

  • Optimizer interface: Base interface for all optimizers
  • SGD: Stochastic Gradient Descent with momentum
  • Adam: Adaptive Moment Estimation

Design inspired by PyTorch's torch.optim but adapted for Go with type safety.

Example usage:

// Create optimizer
optimizer := optim.NewAdam(model.Parameters(), optim.AdamConfig{
    LR: 0.001,
}, backend)

// Training loop
for range epochs {
    // Compute gradients
    backend.Tape().StartRecording()
    output := model.Forward(input)
    loss := lossFunc.Forward(output, targets)
    grads := autodiff.Backward(loss, backend)

    // Update parameters
    optimizer.Step(grads)
    optimizer.ZeroGrad()
}

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Adam

type Adam[B tensor.Backend] struct {
	// contains filtered or unexported fields
}

Adam implements the Adam (Adaptive Moment Estimation) optimizer.

Adam combines ideas from RMSprop and momentum:

  • Maintains exponential moving averages of gradients (first moment)
  • Maintains exponential moving averages of squared gradients (second moment)
  • Applies bias correction to compensate for initialization at zero

Update rule:

m_t = beta1 * m_{t-1} + (1-beta1) * gradient       // First moment
v_t = beta2 * v_{t-1} + (1-beta2) * gradient²      // Second moment
m_hat = m_t / (1 - beta1^t)                        // Bias correction
v_hat = v_t / (1 - beta2^t)                        // Bias correction
param = param - lr * m_hat / (sqrt(v_hat) + eps)   // Parameter update
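
To make the arithmetic concrete, here is a minimal, self-contained sketch that applies the rule above to a single scalar parameter with plain float64 values. It only illustrates the math; it is not how this package applies updates to tensors.

package main

import (
    "fmt"
    "math"
)

// adamScalarStep applies one Adam update to a single scalar parameter.
// m and v are the running first and second moment estimates; t is the
// 1-based timestep. It returns the updated parameter and moments.
func adamScalarStep(param, grad, m, v float64, t int, lr, beta1, beta2, eps float64) (float64, float64, float64) {
    m = beta1*m + (1-beta1)*grad                  // first moment
    v = beta2*v + (1-beta2)*grad*grad             // second moment
    mHat := m / (1 - math.Pow(beta1, float64(t))) // bias correction
    vHat := v / (1 - math.Pow(beta2, float64(t))) // bias correction
    param -= lr * mHat / (math.Sqrt(vHat) + eps)  // parameter update
    return param, m, v
}

func main() {
    param, m, v := 1.0, 0.0, 0.0
    for t := 1; t <= 3; t++ {
        grad := 2 * param // gradient of f(x) = x^2
        param, m, v = adamScalarStep(param, grad, m, v, t, 0.001, 0.9, 0.999, 1e-8)
        fmt.Printf("t=%d param=%.6f\n", t, param)
    }
}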

Adam is particularly well-suited for:

  • Large datasets and high-dimensional parameter spaces
  • Non-stationary objectives
  • Problems with noisy or sparse gradients

Reference: "Adam: A Method for Stochastic Optimization" (Kingma & Ba, 2014)

Example:

optimizer := optim.NewAdam(model.Parameters(), optim.AdamConfig{
    LR:    0.001,
    Betas: [2]float32{0.9, 0.999},
    Eps:   1e-8,
}, backend)

for range epochs {
    loss := trainStep(model, batch)
    grads := autodiff.Backward(loss, backend)
    optimizer.Step(grads)
    optimizer.ZeroGrad()
}

func NewAdam

func NewAdam[B tensor.Backend](params []*nn.Parameter[B], config AdamConfig, backend B) *Adam[B]

NewAdam creates a new Adam optimizer.

Parameters:

  • params: Model parameters to optimize
  • config: Adam configuration (LR, Betas, Eps)
  • backend: Tensor backend used for parameter updates

Returns a new Adam optimizer. Hyperparameters left unset in config fall back to the defaults below.

Default hyperparameters:

  • LR: 0.001
  • Beta1: 0.9
  • Beta2: 0.999
  • Eps: 1e-8
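
For illustration, assuming a backend value named backend is already in scope, a zero-value config picks up these defaults, while any field you set overrides only that default:

// Rely entirely on defaults: LR=0.001, Betas=[0.9, 0.999], Eps=1e-8.
optimizer := optim.NewAdam(model.Parameters(), optim.AdamConfig{}, backend)

// Override only the learning rate; Betas and Eps keep their defaults.
optimizer = optim.NewAdam(model.Parameters(), optim.AdamConfig{LR: 0.01}, backend)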

func (*Adam[B]) GetLR

func (a *Adam[B]) GetLR() float32

GetLR returns the current learning rate.

func (*Adam[B]) GetTimestep

func (a *Adam[B]) GetTimestep() int

GetTimestep returns the current timestep.

Useful for monitoring optimizer state.

func (*Adam[B]) LoadStateDict added in v0.5.4

func (a *Adam[B]) LoadStateDict(stateDict map[string]*tensor.RawTensor) error

LoadStateDict loads optimizer state from serialization.

Restores the first and second moment estimates for the Adam optimizer.

Parameters:

  • stateDict: Map from state name to RawTensor

Returns an error if moment shapes don't match parameter shapes.

func (*Adam[B]) SetLR

func (a *Adam[B]) SetLR(lr float32)

SetLR updates the learning rate.

Useful for learning rate scheduling during training.
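
For example, a simple step-decay schedule can be built on top of SetLR and GetLR; the decay factor and interval below are illustrative only:

// Halve the learning rate every 10 epochs.
for epoch := 0; epoch < epochs; epoch++ {
    if epoch > 0 && epoch%10 == 0 {
        optimizer.SetLR(optimizer.GetLR() * 0.5)
    }
    // ... forward pass, autodiff.Backward, optimizer.Step, optimizer.ZeroGrad ...
}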

func (*Adam[B]) StateDict added in v0.5.4

func (a *Adam[B]) StateDict() map[string]*tensor.RawTensor

StateDict returns the optimizer state for serialization.

For Adam, this exports:

  • First moment estimates (m) for each parameter: "m.{param_index}"
  • Second moment estimates (v) for each parameter: "v.{param_index}"

Returns a map from state name to RawTensor.
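
A sketch of a checkpoint round-trip using StateDict and LoadStateDict; how the map is persisted to disk is left to the caller and not shown here:

// Save: capture optimizer state alongside the model weights.
state := optimizer.StateDict() // keys like "m.0", "v.0", "m.1", ...

// Restore: build a fresh optimizer over the same parameters, then load
// the saved moment estimates back into it.
restored := optim.NewAdam(model.Parameters(), optim.AdamConfig{LR: 0.001}, backend)
if err := restored.LoadStateDict(state); err != nil {
    log.Fatalf("restoring optimizer state: %v", err)
}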

func (*Adam[B]) Step

func (a *Adam[B]) Step(grads map[*tensor.RawTensor]*tensor.RawTensor)

Step performs a single optimization step using the Adam algorithm.

Applies Adam update to all parameters:

  1. Update biased first moment estimate
  2. Update biased second moment estimate
  3. Compute bias-corrected moment estimates
  4. Update parameters

Parameters with no gradient are skipped.

func (*Adam[B]) ZeroGrad

func (a *Adam[B]) ZeroGrad()

ZeroGrad clears gradients for all parameters.

type AdamConfig

type AdamConfig struct {
	LR    float32    // Learning rate (default: 0.001)
	Betas [2]float32 // Coefficients for computing running averages (default: [0.9, 0.999])
	Eps   float32    // Term for numerical stability (default: 1e-8)
}

AdamConfig holds configuration for Adam optimizer.

type Config

type Config struct {
	LR float32 // Learning rate
}

Config is the base configuration for all optimizers.

type Optimizer

type Optimizer interface {
	// Step applies gradient updates to all parameters.
	//
	// Takes a gradient map from Backward() and updates parameters in-place.
	// The gradient map should contain RawTensor -> gradient mapping.
	//
	// Example:
	//   grads := autodiff.Backward(loss, backend)
	//   optimizer.Step(grads)
	Step(grads map[*tensor.RawTensor]*tensor.RawTensor)

	// ZeroGrad clears all parameter gradients.
	//
	// This should be called before each backward pass to prevent
	// gradient accumulation from previous iterations.
	//
	// Example:
	//   optimizer.ZeroGrad()
	//   loss := model.Forward(...)
	//   grads := autodiff.Backward(loss, backend)
	ZeroGrad()

	// GetLR returns the current learning rate.
	//
	// Useful for monitoring and learning rate scheduling.
	GetLR() float32

	// StateDict returns the optimizer state for serialization.
	//
	// This includes optimizer-specific state like momentum buffers (SGD)
	// or moment estimates (Adam). Used for checkpoint saving.
	//
	// Returns a map from state name to RawTensor.
	// State names follow the pattern: "{state_type}.{param_index}"
	// For example: "velocity.0", "m.0", "v.0"
	StateDict() map[string]*tensor.RawTensor

	// LoadStateDict loads optimizer state from serialization.
	//
	// Restores optimizer-specific state from a checkpoint. The state
	// dictionary should match the structure returned by StateDict().
	//
	// Parameters:
	//   - stateDict: Map from state name to RawTensor
	//
	// Returns an error if the state dictionary is invalid.
	LoadStateDict(stateDict map[string]*tensor.RawTensor) error
}

Optimizer is the base interface for all optimization algorithms.

Optimizers update model parameters based on computed gradients to minimize the loss function during training.

All optimizers must implement:

  • Step: Apply gradient updates to parameters
  • ZeroGrad: Clear gradients before next iteration
  • GetLR: Get current learning rate (for monitoring/scheduling)
  • StateDict: Export optimizer state for checkpoints
  • LoadStateDict: Import optimizer state from checkpoints

type SGD

type SGD[B tensor.Backend] struct {
	// contains filtered or unexported fields
}

SGD implements Stochastic Gradient Descent optimizer with optional momentum.

Update rule without momentum:

param = param - lr * gradient

Update rule with momentum:

velocity = momentum * velocity + gradient
param = param - lr * velocity

Momentum helps accelerate SGD in relevant directions and dampens oscillations.
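
To see the effect numerically, the sketch below applies the momentum rule to a single scalar with a constant gradient; the velocity grows geometrically toward gradient / (1 - momentum):

velocity, param := 0.0, 1.0
lr, momentum, grad := 0.1, 0.9, 1.0
for i := 0; i < 5; i++ {
    velocity = momentum*velocity + grad // 1.0, 1.9, 2.71, 3.44, 4.10, ...
    param -= lr * velocity
}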

Example:

optimizer := optim.NewSGD(model.Parameters(), optim.SGDConfig{
    LR:       0.01,
    Momentum: 0.9,
}, backend)

for range epochs {
    loss := trainStep(model, batch)
    grads := autodiff.Backward(loss, backend)
    optimizer.Step(grads)
    optimizer.ZeroGrad()
}

func NewSGD

func NewSGD[B tensor.Backend](params []*nn.Parameter[B], config SGDConfig, backend B) *SGD[B]

NewSGD creates a new SGD optimizer.

Parameters:

  • params: Model parameters to optimize
  • config: SGD configuration (LR, Momentum)
  • backend: Tensor backend used for parameter updates

Returns a new SGD optimizer.

Example:

sgd := optim.NewSGD(model.Parameters(), optim.SGDConfig{
    LR:       0.01,
    Momentum: 0.9,
}, backend)

func (*SGD[B]) GetLR

func (s *SGD[B]) GetLR() float32

GetLR returns the current learning rate.

func (*SGD[B]) LoadStateDict added in v0.5.4

func (s *SGD[B]) LoadStateDict(stateDict map[string]*tensor.RawTensor) error

LoadStateDict loads optimizer state from serialization.

Restores velocity buffers for SGD with momentum. If momentum is 0, the provided state is ignored (no velocity buffers are needed).

Parameters:

  • stateDict: Map from state name to RawTensor

Returns an error if velocity shapes don't match parameter shapes.

func (*SGD[B]) SetLR

func (s *SGD[B]) SetLR(lr float32)

SetLR updates the learning rate.

Useful for learning rate scheduling during training.

func (*SGD[B]) StateDict added in v0.5.4

func (s *SGD[B]) StateDict() map[string]*tensor.RawTensor

StateDict returns the optimizer state for serialization.

For SGD with momentum, this exports velocity buffers for each parameter. Without momentum, returns an empty map.

State keys: "velocity.{param_index}" -> velocity tensor.

func (*SGD[B]) Step

func (s *SGD[B]) Step(grads map[*tensor.RawTensor]*tensor.RawTensor)

Step performs a single optimization step.

Applies gradient descent update to all parameters:

  • Without momentum: param -= lr * grad
  • With momentum: velocity = momentum * velocity + grad, param -= lr * velocity

Parameters with no gradient (not in computational graph) are skipped.

func (*SGD[B]) ZeroGrad

func (s *SGD[B]) ZeroGrad()

ZeroGrad clears gradients for all parameters.

type SGDConfig

type SGDConfig struct {
	LR       float32 // Learning rate (default: 0.01)
	Momentum float32 // Momentum factor (default: 0.0, range: [0, 1))
}

SGDConfig holds configuration for SGD optimizer.
