graph

package
v0.1.0
Warning

This package is not in the latest version of its module.

Published: Aug 4, 2025 | License: Apache-2.0 | Imports: 5 | Imported by: 1

Documentation

Overview

Package graph provides a computational graph abstraction.

Constants

This section is empty.

Variables

var ErrInvalidInputCount = errors.New("invalid number of input tensors")

ErrInvalidInputCount is returned when the number of input tensors is invalid.
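Because ErrInvalidInputCount is a sentinel error value, callers can presumably detect it with errors.Is even after it has been wrapped. The sketch below declares a local copy of the sentinel so that it compiles without the real module; checkInputs and isInvalidInputCount are hypothetical helpers, not part of the package.

```go
package main

import (
	"errors"
	"fmt"
)

// ErrInvalidInputCount is a local stand-in for graph.ErrInvalidInputCount,
// declared here only so the sketch compiles on its own.
var ErrInvalidInputCount = errors.New("invalid number of input tensors")

// checkInputs is a hypothetical helper: it wraps the sentinel with context
// when the number of inputs differs from what a node expects.
func checkInputs(want, got int) error {
	if got != want {
		return fmt.Errorf("want %d inputs, got %d: %w", want, got, ErrInvalidInputCount)
	}
	return nil
}

// isInvalidInputCount reports whether err is, or wraps, the sentinel.
func isInvalidInputCount(err error) bool {
	return errors.Is(err, ErrInvalidInputCount)
}

func main() {
	err := checkInputs(2, 1)
	// errors.Is sees through the %w wrapping and matches the sentinel.
	if isInvalidInputCount(err) {
		fmt.Println("wrong input count:", err)
	}
}
```

Comparing with errors.Is rather than == keeps the check working if an intermediate layer adds context to the error.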

Functions

This section is empty.

Types

type Builder

type Builder[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Builder provides a fluent API for constructing a computation graph.

func NewBuilder

func NewBuilder[T tensor.Numeric](engine compute.Engine[T]) *Builder[T]

NewBuilder creates a new graph builder.

func (*Builder[T]) AddNode

func (b *Builder[T]) AddNode(node Node[T], inputs ...Node[T]) Node[T]

AddNode adds a new node to the graph with the given inputs.

func (*Builder[T]) Build

func (b *Builder[T]) Build(outputNode Node[T]) (func(inputs ...*tensor.Tensor[T]) (*tensor.Tensor[T], error), func(initialGradient *tensor.Tensor[T]) error, error)

Build constructs the final graph ending at outputNode and returns a forward function, a backward function, and an error.

func (*Builder[T]) Input

func (b *Builder[T]) Input(shape []int) Node[T]

Input creates a new input node.

func (*Builder[T]) Parameters

func (b *Builder[T]) Parameters() []*Parameter[T]

Parameters returns all the trainable parameters in the graph.
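The method set suggests a call sequence of Input, then AddNode for each operation, then Build. The sketch below mirrors that sequence with simplified stand-in types: float64 scalars replace *tensor.Tensor[T], there is no compute.Engine, and per-node errors are dropped. The names miniBuilder, inputNode, and mulNode are hypothetical, and this is not the package's implementation. It assumes nodes can run in insertion order, which holds here because a node can only be added after its inputs exist.

```go
package main

import (
	"errors"
	"fmt"
)

// node is a scalar stand-in for graph.Node[T]; per-node errors are omitted for brevity.
type node interface {
	Forward(inputs ...float64) float64
	Backward(outGrad float64) []float64
}

// inputNode passes the fed value through and records the gradient it receives.
type inputNode struct{ value, grad float64 }

func (n *inputNode) Forward(in ...float64) float64 { n.value = in[0]; return n.value }
func (n *inputNode) Backward(g float64) []float64 { n.grad = g; return nil }

// mulNode multiplies two inputs and caches them for the backward pass.
type mulNode struct{ a, b float64 }

func (m *mulNode) Forward(in ...float64) float64 { m.a, m.b = in[0], in[1]; return m.a * m.b }
func (m *mulNode) Backward(g float64) []float64 { return []float64{g * m.b, g * m.a} }

// miniBuilder records nodes in insertion order together with their inputs.
type miniBuilder struct {
	order   []node
	deps    map[node][]node
	nInputs int
}

func newMiniBuilder() *miniBuilder { return &miniBuilder{deps: map[node][]node{}} }

// Input registers a node that will receive one fed value per forward call.
func (b *miniBuilder) Input() node {
	b.nInputs++
	return b.AddNode(&inputNode{})
}

// AddNode appends n after its inputs, so insertion order is a valid execution order.
func (b *miniBuilder) AddNode(n node, inputs ...node) node {
	b.deps[n] = inputs
	b.order = append(b.order, n)
	return n
}

// Build returns forward and backward closures over the recorded nodes.
func (b *miniBuilder) Build(out node) (func(...float64) (float64, error), func(float64) error) {
	vals := map[node]float64{}
	forward := func(fed ...float64) (float64, error) {
		if len(fed) != b.nInputs {
			return 0, errors.New("invalid number of input tensors")
		}
		next := 0 // index of the next fed value to hand to an input node
		for _, n := range b.order {
			var args []float64
			for _, d := range b.deps[n] {
				args = append(args, vals[d])
			}
			if _, isInput := n.(*inputNode); isInput {
				args = append(args, fed[next])
				next++
			}
			vals[n] = n.Forward(args...)
		}
		return vals[out], nil
	}
	backward := func(initial float64) error {
		grads := map[node]float64{out: initial}
		for i := len(b.order) - 1; i >= 0; i-- {
			n := b.order[i]
			// Route each input gradient back to the node that produced that input.
			for j, g := range n.Backward(grads[n]) {
				grads[b.deps[n][j]] += g
			}
		}
		return nil
	}
	return forward, backward
}

func main() {
	b := newMiniBuilder()
	x, y := b.Input(), b.Input()
	out := b.AddNode(&mulNode{}, x, y)
	forward, backward := b.Build(out)
	v, err := forward(3, 4)
	if err != nil {
		panic(err)
	}
	_ = backward(1)
	fmt.Println(v, x.(*inputNode).grad, y.(*inputNode).grad) // prints "12 4 3"
}
```

Returning closures from Build, as the documented signature does, lets the caller run the graph without holding a reference to the builder itself.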

type Graph

type Graph[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Graph represents a computation graph with a defined execution order.
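The documentation does not say how the execution order is derived. One common approach for an acyclic graph is a depth-first topological sort, sketched below with string labels standing in for nodes; topoOrder and the example graph are hypothetical and are not taken from the package.

```go
package main

import "fmt"

// topoOrder returns the nodes reachable from out, ordered so that every node
// appears after all of its inputs (depth-first post-order). It assumes the
// graph has no cycles.
func topoOrder(out string, inputsOf map[string][]string) []string {
	var order []string
	visited := map[string]bool{}
	var visit func(n string)
	visit = func(n string) {
		if visited[n] {
			return
		}
		visited[n] = true
		for _, in := range inputsOf[n] {
			visit(in)
		}
		order = append(order, n) // all inputs of n are already in order
	}
	visit(out)
	return order
}

func main() {
	// Hypothetical graph: loss = mul(add(x, w), x)
	inputsOf := map[string][]string{
		"add":  {"x", "w"},
		"loss": {"add", "x"},
	}
	fmt.Println(topoOrder("loss", inputsOf)) // prints "[x w add loss]"
}
```

A forward pass can walk this order front to back, and a backward pass can walk it in reverse so that each node sees its complete output gradient before propagating it.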

func (*Graph[T]) Backward

func (g *Graph[T]) Backward(initialGradient *tensor.Tensor[T]) error

Backward executes the backward pass of the entire graph.

func (*Graph[T]) Forward

func (g *Graph[T]) Forward(inputs ...*tensor.Tensor[T]) (*tensor.Tensor[T], error)

Forward executes the forward pass of the entire graph.

func (*Graph[T]) Parameters

func (g *Graph[T]) Parameters() []*Parameter[T]

Parameters returns all the trainable parameters in the graph.

type NoParameters

type NoParameters[T tensor.Numeric] struct{}

NoParameters is a helper struct for layers that have no parameters. It provides a default implementation of the Parameters() method.

func (*NoParameters[T]) Parameters

func (np *NoParameters[T]) Parameters() []*Parameter[T]

Parameters returns an empty slice of parameters.
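A helper like this is typically used through struct embedding, which promotes Parameters() onto the embedding type. The sketch below uses local stand-ins for Parameter and NoParameters (with an `any` constraint instead of tensor.Numeric) so that it compiles alone; ReLU and hasParameters are hypothetical names.

```go
package main

import "fmt"

// Parameter and NoParameters are local stand-ins shaped like the documented types.
type Parameter[T any] struct {
	Name string
}

type NoParameters[T any] struct{}

// Parameters returns an empty slice, as the documentation describes.
func (np *NoParameters[T]) Parameters() []*Parameter[T] { return []*Parameter[T]{} }

// ReLU is a hypothetical parameter-free layer. Embedding NoParameters promotes
// Parameters() onto *ReLU, so the layer does not need to define it itself.
type ReLU[T any] struct {
	NoParameters[T]
}

// hasParameters is the part of the Node interface this sketch cares about.
type hasParameters[T any] interface {
	Parameters() []*Parameter[T]
}

func main() {
	// The method has a pointer receiver, so *ReLU (not ReLU) satisfies the interface.
	var layer hasParameters[float32] = &ReLU[float32]{}
	fmt.Println(len(layer.Parameters())) // prints "0"
}
```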

type Node

type Node[T tensor.Numeric] interface {
	// OutputShape returns the shape of the output tensor.
	OutputShape() []int
	// Forward computes the output of the node given the inputs.
	Forward(inputs ...*tensor.Tensor[T]) (*tensor.Tensor[T], error)
	// Backward computes the gradients of the loss with respect to the inputs and parameters.
	Backward(outputGradient *tensor.Tensor[T]) ([]*tensor.Tensor[T], error)
	// Parameters returns the parameters of the node.
	Parameters() []*Parameter[T]
}

Node represents a node in the computation graph.
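To make the interface concrete, the sketch below implements its four methods for a node that multiplies its input by a constant. Numeric, Tensor, and Parameter are simplified local stand-ins for the real tensor and graph types, and Scale and errInputCount are hypothetical; the real tensor type very likely exposes a different API.

```go
package main

import (
	"errors"
	"fmt"
)

// Numeric, Tensor, and Parameter are simplified local stand-ins for the
// tensor and graph types; the real ones carry more behavior.
type Numeric interface{ ~float32 | ~float64 }

type Tensor[T Numeric] struct {
	Shape []int
	Data  []T
}

type Parameter[T Numeric] struct {
	Name            string
	Value, Gradient *Tensor[T]
}

// Node repeats the documented method set against the stand-in types.
type Node[T Numeric] interface {
	OutputShape() []int
	Forward(inputs ...*Tensor[T]) (*Tensor[T], error)
	Backward(outputGradient *Tensor[T]) ([]*Tensor[T], error)
	Parameters() []*Parameter[T]
}

var errInputCount = errors.New("invalid number of input tensors")

// Scale is a hypothetical node computing y = factor * x element-wise.
type Scale[T Numeric] struct {
	factor T
	shape  []int
}

func (s *Scale[T]) OutputShape() []int { return s.shape }

func (s *Scale[T]) Forward(inputs ...*Tensor[T]) (*Tensor[T], error) {
	if len(inputs) != 1 {
		return nil, errInputCount
	}
	return s.scaled(inputs[0]), nil
}

// Backward applies the chain rule: dL/dx = factor * dL/dy.
func (s *Scale[T]) Backward(outputGradient *Tensor[T]) ([]*Tensor[T], error) {
	return []*Tensor[T]{s.scaled(outputGradient)}, nil
}

// Parameters is empty because factor is a fixed constant, not a trained value.
func (s *Scale[T]) Parameters() []*Parameter[T] { return nil }

// scaled returns a new tensor holding factor * t, leaving t untouched.
func (s *Scale[T]) scaled(t *Tensor[T]) *Tensor[T] {
	out := &Tensor[T]{Shape: t.Shape, Data: make([]T, len(t.Data))}
	for i, v := range t.Data {
		out.Data[i] = s.factor * v
	}
	return out
}

func main() {
	var n Node[float64] = &Scale[float64]{factor: 2, shape: []int{3}}
	y, _ := n.Forward(&Tensor[float64]{Shape: []int{3}, Data: []float64{1, 2, 3}})
	grads, _ := n.Backward(&Tensor[float64]{Shape: []int{3}, Data: []float64{1, 1, 1}})
	fmt.Println(y.Data, grads[0].Data) // prints "[2 4 6] [2 2 2]"
}
```

Backward returns one gradient per input, in the same order as the inputs passed to Forward, which is what lets a graph route each gradient to the node that produced the corresponding input.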

type Parameter

type Parameter[T tensor.Numeric] struct {
	Name     string
	Value    *tensor.Tensor[T]
	Gradient *tensor.Tensor[T]
}

Parameter is a container for a trainable tensor (e.g., weights or biases). It holds both the value tensor and a gradient tensor of the same shape.

func NewParameter

func NewParameter[T tensor.Numeric](name string, value *tensor.Tensor[T], newTensorFn func(shape []int, data []T) (*tensor.Tensor[T], error)) (*Parameter[T], error)

NewParameter creates a new parameter, initializing its gradient tensor with the same shape. It takes a tensor creation function to allow for mocking in tests.
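Passing the tensor constructor as an argument lets a test substitute one that fails, so the error path can be exercised without a real allocation failure. The sketch below illustrates the pattern with stand-in types; newParameter, realTensor, and failingTensor are hypothetical, and the assumption that the gradient is requested as a zero-filled slice is mine, not something the documentation states.

```go
package main

import (
	"errors"
	"fmt"
)

// Tensor and Parameter are simplified stand-ins for the real types.
type Tensor struct {
	Shape []int
	Data  []float64
}

type Parameter struct {
	Name            string
	Value, Gradient *Tensor
}

// tensorFn matches the shape of the documented newTensorFn argument.
type tensorFn func(shape []int, data []float64) (*Tensor, error)

// newParameter is a hypothetical reimplementation: it asks the injected
// constructor for a zero-filled gradient tensor shaped like value.
func newParameter(name string, value *Tensor, newTensor tensorFn) (*Parameter, error) {
	grad, err := newTensor(value.Shape, make([]float64, len(value.Data)))
	if err != nil {
		return nil, fmt.Errorf("parameter %q: %w", name, err)
	}
	return &Parameter{Name: name, Value: value, Gradient: grad}, nil
}

// realTensor is the constructor a non-test caller would pass.
func realTensor(shape []int, data []float64) (*Tensor, error) {
	return &Tensor{Shape: shape, Data: data}, nil
}

// failingTensor is a test double that forces the error branch.
func failingTensor([]int, []float64) (*Tensor, error) {
	return nil, errors.New("allocation failed")
}

func main() {
	w := &Tensor{Shape: []int{2}, Data: []float64{0.5, -0.5}}
	p, err := newParameter("w", w, realTensor)
	fmt.Println(p.Gradient.Data, err)
	_, err = newParameter("w", w, failingTensor)
	fmt.Println(err)
}
```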

func (*Parameter[T]) AddGradient

func (p *Parameter[T]) AddGradient(grad *tensor.Tensor[T]) error

AddGradient adds the given gradient to the parameter's accumulated gradient.

func (*Parameter[T]) ClearGradient

func (p *Parameter[T]) ClearGradient()

ClearGradient sets the parameter's gradient to zero.
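Together, AddGradient and ClearGradient support the usual accumulate-then-reset cycle: several backward passes add into the same gradient, and the gradient is zeroed before the next round. The sketch below shows that cycle with a stand-in Parameter backed by flat slices; returning an error on a length mismatch is an assumption about what AddGradient's error covers.

```go
package main

import (
	"errors"
	"fmt"
)

// Parameter is a stand-in that stores value and gradient as flat slices.
type Parameter struct {
	Name     string
	Value    []float64
	Gradient []float64
}

// AddGradient adds grad element-wise into the running gradient.
// The mismatch check is an assumed failure mode, not documented behavior.
func (p *Parameter) AddGradient(grad []float64) error {
	if len(grad) != len(p.Gradient) {
		return errors.New("gradient shape mismatch")
	}
	for i, g := range grad {
		p.Gradient[i] += g
	}
	return nil
}

// ClearGradient resets every gradient element to zero.
func (p *Parameter) ClearGradient() {
	for i := range p.Gradient {
		p.Gradient[i] = 0
	}
}

func main() {
	p := &Parameter{Name: "w", Value: []float64{1, 2}, Gradient: make([]float64, 2)}
	// Two backward passes contribute to the same parameter before an update.
	_ = p.AddGradient([]float64{1, 2})
	_ = p.AddGradient([]float64{3, 4})
	fmt.Println(p.Gradient) // prints "[4 6]"
	p.ClearGradient()
	fmt.Println(p.Gradient) // prints "[0 0]"
}
```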
