package graph
v0.3.0 (Latest)
Warning

This package is not in the latest version of its module.

Published: Aug 25, 2025 | License: Apache-2.0 | Imports: 6 | Imported by: 1

Documentation

Overview

Package graph provides a computational graph abstraction.

Index

Constants

This section is empty.

Variables

var ErrInvalidInputCount = errors.New("invalid number of inputs")

ErrInvalidInputCount is returned when the number of inputs to a node is incorrect.

Functions

This section is empty.

Types

type Builder

type Builder[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Builder provides a fluent API for constructing a computation graph.

func NewBuilder

func NewBuilder[T tensor.Numeric](engine compute.Engine[T]) *Builder[T]

NewBuilder creates a new graph builder backed by the given compute engine.

func (*Builder[T]) AddNode

func (b *Builder[T]) AddNode(node Node[T], inputs ...Node[T]) Node[T]

AddNode adds a new node to the graph with the given inputs.

func (*Builder[T]) Build

func (b *Builder[T]) Build(outputNode Node[T]) (*Graph[T], error)

Build constructs the final graph, with outputNode as its output.

func (*Builder[T]) Input

func (b *Builder[T]) Input(shape []int) Node[T]

Input creates a new input node with the given shape.

func (*Builder[T]) Parameters

func (b *Builder[T]) Parameters() []*Parameter[T]

Parameters returns all the trainable parameters in the graph.
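One plausible way a builder can produce the defined execution order that Graph is documented to have is a depth-first post-order walk back from the output node. The sketch below assumes nodes are only wired to nodes created earlier. `Node`, `opNode`, and this `Builder` are hypothetical simplified stand-ins, and `Build` returns the ordered node slice instead of a `*Graph[T]`.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// Node is a hypothetical, trimmed-down stand-in for graph.Node[T].
type Node interface{ OpType() string }

// opNode is a hypothetical concrete node used only for illustration.
type opNode struct{ op string }

func (n *opNode) OpType() string { return n.op }

// Builder records every added node together with the inputs it consumes.
type Builder struct {
	deps map[Node][]Node
}

func NewBuilder() *Builder { return &Builder{deps: map[Node][]Node{}} }

// Input registers a source node with no inputs (shape handling omitted).
func (b *Builder) Input(shape []int) Node { return b.AddNode(&opNode{op: "Input"}) }

// AddNode stores node with its inputs and returns it so it can feed later nodes.
func (b *Builder) AddNode(node Node, inputs ...Node) Node {
	b.deps[node] = inputs
	return node
}

// Build derives an execution order by depth-first post-order traversal
// from the output, so each node comes after all of its inputs.
// Cycle detection is omitted in this sketch.
func (b *Builder) Build(output Node) ([]Node, error) {
	if _, ok := b.deps[output]; !ok {
		return nil, errors.New("output node was not added to the builder")
	}
	var order []Node
	visited := map[Node]bool{}
	var visit func(n Node)
	visit = func(n Node) {
		if visited[n] {
			return
		}
		visited[n] = true
		for _, in := range b.deps[n] {
			visit(in)
		}
		order = append(order, n)
	}
	visit(output)
	return order, nil
}

func main() {
	b := NewBuilder()
	x := b.Input([]int{1, 4})
	h := b.AddNode(&opNode{op: "Dense"}, x)
	y := b.AddNode(&opNode{op: "ReLU"}, h)
	order, err := b.Build(y)
	if err != nil {
		panic(err)
	}
	ops := make([]string, len(order))
	for i, n := range order {
		ops[i] = n.OpType()
	}
	fmt.Println(strings.Join(ops, " -> ")) // Input -> Dense -> ReLU
}
```

In this sketch, nodes that the output does not depend on are left out of the order; the real package may handle unreachable nodes differently.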

type Graph

type Graph[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Graph represents a computation graph with a defined execution order.

func (*Graph[T]) Backward

func (g *Graph[T]) Backward(ctx context.Context, mode types.BackwardMode, initialGradient *tensor.TensorNumeric[T]) error

Backward executes the backward pass of the entire graph.

func (*Graph[T]) Forward

func (g *Graph[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward executes the forward pass of the entire graph.

func (*Graph[T]) Inputs added in v0.3.0

func (g *Graph[T]) Inputs() []Node[T]

Inputs returns the input nodes of the graph.

func (*Graph[T]) Output added in v0.3.0

func (g *Graph[T]) Output() Node[T]

Output returns the output node of the graph.

func (*Graph[T]) Parameters

func (g *Graph[T]) Parameters() []*Parameter[T]

Parameters returns all the trainable parameters in the graph.

type NoParameters

type NoParameters[T tensor.Numeric] struct{}

NoParameters is a utility type for nodes that have no trainable parameters.

func (*NoParameters[T]) Parameters

func (n *NoParameters[T]) Parameters() []*Parameter[T]

Parameters returns an empty slice of parameters.
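A minimal sketch of how a stateless node can embed NoParameters to satisfy the Parameters part of the Node contract. `Numeric`, `Parameter`, `parameterized`, and `ReLU` here are hypothetical simplified stand-ins for the package's types.

```go
package main

import "fmt"

// Numeric is a hypothetical, narrowed stand-in for tensor.Numeric.
type Numeric interface{ ~float32 | ~float64 }

// Parameter is a trimmed-down stand-in for graph.Parameter[T].
type Parameter[T Numeric] struct {
	Name  string
	Value []T
}

// NoParameters mirrors graph.NoParameters: a zero-size helper whose
// Parameters method reports that there is nothing to train.
type NoParameters[T Numeric] struct{}

func (n *NoParameters[T]) Parameters() []*Parameter[T] { return []*Parameter[T]{} }

// parameterized is the slice of the Node contract this sketch checks.
type parameterized[T Numeric] interface {
	Parameters() []*Parameter[T]
}

// ReLU is a hypothetical stateless node; embedding NoParameters supplies
// its Parameters method.
type ReLU[T Numeric] struct {
	NoParameters[T]
}

func main() {
	// The promoted method has a pointer receiver, so *ReLU (not ReLU)
	// satisfies the interface.
	var p parameterized[float64] = &ReLU[float64]{}
	fmt.Println(len(p.Parameters())) // 0
}
```

Because `Parameters` is declared on `*NoParameters[T]`, only a pointer to the embedding struct gets the promoted method; storing nodes as pointers avoids a compile error.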

type Node

type Node[T tensor.Numeric] interface {
	// OpType returns the operation type of the node, e.g., "ReLU", "Dense".
	OpType() string

	// Attributes returns a map of the node's non-tensor attributes.
	Attributes() map[string]interface{}

	// Forward computes the output of the node given the inputs.
	Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Backward computes the gradients of the node with respect to its inputs.
	Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

	// Parameters returns the trainable parameters of the node.
	Parameters() []*Parameter[T]

	// OutputShape returns the shape of the output tensor.
	OutputShape() []int
}

Node represents a node in the computation graph.

type Parameter

type Parameter[T tensor.Numeric] struct {
	Name     string
	Value    *tensor.TensorNumeric[T]
	Gradient *tensor.TensorNumeric[T]
}

Parameter represents a trainable parameter in the graph.

func NewParameter

func NewParameter[T tensor.Numeric](name string, value *tensor.TensorNumeric[T], newTensorFn func([]int, []T) (*tensor.TensorNumeric[T], error)) (*Parameter[T], error)

NewParameter creates a new parameter with the given name and value.

func (*Parameter[T]) AddGradient

func (p *Parameter[T]) AddGradient(grad *tensor.TensorNumeric[T]) error

AddGradient adds the given gradient to the parameter's gradient.

func (*Parameter[T]) ClearGradient

func (p *Parameter[T]) ClearGradient()

ClearGradient resets the parameter's gradient to zero.
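A minimal sketch of gradient accumulation on a parameter, assuming `AddGradient` sums element by element and (as an added assumption) rejects a size mismatch. `Tensor` and `newParameter` are hypothetical; the real `NewParameter` also takes a `newTensorFn` allocator, which this sketch replaces with a direct `make`.

```go
package main

import (
	"errors"
	"fmt"
)

// Tensor is a hypothetical flat stand-in for tensor.TensorNumeric[T].
type Tensor struct {
	Shape []int
	Data  []float64
}

// Parameter mirrors the exported fields of graph.Parameter[T].
type Parameter struct {
	Name     string
	Value    *Tensor
	Gradient *Tensor
}

// newParameter is a hypothetical constructor that allocates a zeroed
// gradient with the same shape as value.
func newParameter(name string, value *Tensor) (*Parameter, error) {
	if value == nil {
		return nil, errors.New("value must not be nil")
	}
	grad := &Tensor{Shape: append([]int(nil), value.Shape...), Data: make([]float64, len(value.Data))}
	return &Parameter{Name: name, Value: value, Gradient: grad}, nil
}

// AddGradient accumulates grad elementwise, so several consumers of the
// same parameter each contribute to its total gradient. The size check
// is an assumption of this sketch.
func (p *Parameter) AddGradient(grad *Tensor) error {
	if len(grad.Data) != len(p.Gradient.Data) {
		return fmt.Errorf("gradient size %d does not match parameter size %d", len(grad.Data), len(p.Gradient.Data))
	}
	for i, g := range grad.Data {
		p.Gradient.Data[i] += g
	}
	return nil
}

// ClearGradient zeroes the accumulated gradient, e.g. before the next step.
func (p *Parameter) ClearGradient() {
	for i := range p.Gradient.Data {
		p.Gradient.Data[i] = 0
	}
}

func main() {
	w, _ := newParameter("w", &Tensor{Shape: []int{2}, Data: []float64{0.1, 0.2}})
	_ = w.AddGradient(&Tensor{Shape: []int{2}, Data: []float64{1, 2}})
	_ = w.AddGradient(&Tensor{Shape: []int{2}, Data: []float64{3, 4}})
	fmt.Println(w.Gradient.Data) // [4 6]
	w.ClearGradient()
	fmt.Println(w.Gradient.Data) // [0 0]
}
```

Accumulating rather than overwriting is what lets a parameter used by several nodes receive the sum of their gradients; clearing between steps keeps one step's gradients from leaking into the next.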
