gather

package
v0.3.0
Warning

This package is not in the latest version of its module.

Published: Aug 25, 2025 License: Apache-2.0 Imports: 8 Imported by: 0

Documentation

Overview

Package gather provides the Gather layer for the Zerfoo ML framework.

Constants

This section is empty.

Variables

This section is empty.

Functions

func BuildGather

func BuildGather[T tensor.Numeric](
	engine compute.Engine[T],
	_ numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	_ map[string]interface{},
) (graph.Node[T], error)

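BuildGather constructs a Gather layer. It tries to resolve the embedding weights from params using common parameter naming patterns; if no match is found, it falls back to a minimal dummy tensor. The numeric.Arithmetic argument and the final attributes map are ignored.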
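The lookup-with-fallback behaviour described above can be sketched in isolation. This is a minimal illustration, not the package's code: the helper names candidateKeys and resolveWeights, the key patterns, and the use of plain []float32 slices in place of *graph.Parameter[T] are all assumptions made for the example. The documentation does not list the patterns BuildGather actually checks.

```go
package main

import "fmt"

// candidateKeys lists hypothetical parameter names an embedding table might
// be stored under. These patterns are illustrative assumptions, not the
// names BuildGather actually checks.
func candidateKeys(name string) []string {
	return []string{
		name,
		name + ".weight",
		name + "_weight",
		"embed_tokens.weight",
	}
}

// resolveWeights returns the first parameter found under a candidate key.
// The boolean is false when nothing matched, which is the point where a
// builder could fall back to a placeholder tensor.
func resolveWeights(name string, params map[string][]float32) ([]float32, bool) {
	for _, key := range candidateKeys(name) {
		if w, ok := params[key]; ok {
			return w, true
		}
	}
	return nil, false
}

func main() {
	params := map[string][]float32{
		"tok_emb.weight": {0.1, 0.2, 0.3, 0.4},
	}
	if w, ok := resolveWeights("tok_emb", params); ok {
		fmt.Println("resolved weights of length", len(w))
	}
	if _, ok := resolveWeights("missing", params); !ok {
		fmt.Println("no match: fall back to a placeholder")
	}
}
```

Returning an explicit boolean rather than a nil check keeps the "not found" case visible to the caller, which matters when the fallback is a dummy tensor that would otherwise fail silently.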

Types

type Gather

type Gather[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Gather is a layer that gathers slices from a tensor.

func New

func New[T tensor.Numeric](engine compute.Engine[T]) *Gather[T]

New creates a new Gather layer.

func NewWithWeights

func NewWithWeights[T tensor.Numeric](engine compute.Engine[T], weights *tensor.TensorNumeric[T]) *Gather[T]

NewWithWeights creates a new Gather layer with embedded weights.

func (*Gather[T]) Attributes

func (g *Gather[T]) Attributes() map[string]interface{}

Attributes returns nil for the Gather layer.

func (*Gather[T]) Backward

func (g *Gather[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes the gradients for the Gather layer.
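For reference, the textbook gradient of a gather with respect to the table it reads from is a scatter-add: each row of the output gradient is accumulated into the table row its index selected, so repeated indices sum. The sketch below shows that rule on flat row-major slices. The function scatterAddRows is hypothetical, and whether Backward returns this gradient is not stated in the documentation (Parameters reports no trainable parameters).

```go
package main

import "fmt"

// scatterAddRows accumulates each row of outGrad, shaped
// [len(indices), cols], into the row of a zero-initialised [rows, cols]
// gradient selected by the matching index. Repeated indices add up.
func scatterAddRows(outGrad []float32, rows, cols int, indices []int) ([]float32, error) {
	if len(outGrad) != len(indices)*cols {
		return nil, fmt.Errorf("outGrad has %d elements, want %d", len(outGrad), len(indices)*cols)
	}
	grad := make([]float32, rows*cols)
	for i, idx := range indices {
		if idx < 0 || idx >= rows {
			return nil, fmt.Errorf("index %d out of range [0, %d)", idx, rows)
		}
		for j := 0; j < cols; j++ {
			grad[idx*cols+j] += outGrad[i*cols+j]
		}
	}
	return grad, nil
}

func main() {
	outGrad := []float32{
		1, 2, // gradient for the row gathered by index 2
		3, 4, // gradient for the row gathered by index 0
		5, 6, // gradient for the row gathered by index 2 again
	}
	grad, err := scatterAddRows(outGrad, 3, 2, []int{2, 0, 2})
	if err != nil {
		panic(err)
	}
	fmt.Println(grad) // [3 4 0 0 6 8]
}
```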

func (*Gather[T]) Forward

func (g *Gather[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the gather operation.
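As an illustration of what a gather along the first axis does, the sketch below selects rows from a flat row-major table by index, which is the usual embedding-lookup case. It is a standalone example on plain slices: gatherRows is a hypothetical helper, and the input order, index type, and axis handling that Forward expects are not specified in this documentation.

```go
package main

import (
	"errors"
	"fmt"
)

// gatherRows copies, for each index, one row of a row-major [rows, cols]
// table into the output, yielding a flat [len(indices), cols] result.
func gatherRows(table []float32, cols int, indices []int) ([]float32, error) {
	if cols <= 0 || len(table)%cols != 0 {
		return nil, errors.New("table length must be a multiple of a positive cols")
	}
	rows := len(table) / cols
	out := make([]float32, 0, len(indices)*cols)
	for _, idx := range indices {
		if idx < 0 || idx >= rows {
			return nil, fmt.Errorf("index %d out of range [0, %d)", idx, rows)
		}
		out = append(out, table[idx*cols:(idx+1)*cols]...)
	}
	return out, nil
}

func main() {
	// A 3x2 table stored row by row.
	table := []float32{
		0, 1,
		10, 11,
		20, 21,
	}
	out, err := gatherRows(table, 2, []int{2, 0, 2})
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // [20 21 0 1 20 21]
}
```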

func (*Gather[T]) HasEmbeddedWeights

func (g *Gather[T]) HasEmbeddedWeights() bool

HasEmbeddedWeights returns true if this Gather layer has embedded weights.

func (*Gather[T]) OpType

func (g *Gather[T]) OpType() string

OpType returns the operation type of the Gather layer.

func (*Gather[T]) OutputShape

func (g *Gather[T]) OutputShape() []int

OutputShape returns the output shape of the Gather layer.
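For a gather along axis 0, the conventional shape rule is the indices shape followed by the table shape with its first dimension dropped. The sketch below encodes that rule; gatherOutputShape is a hypothetical helper, and it is an assumption that OutputShape follows the same convention.

```go
package main

import (
	"errors"
	"fmt"
)

// gatherOutputShape returns indicesShape followed by tableShape[1:], the
// conventional result shape for a gather along axis 0.
func gatherOutputShape(tableShape, indicesShape []int) ([]int, error) {
	if len(tableShape) == 0 {
		return nil, errors.New("table must have at least one dimension")
	}
	out := make([]int, 0, len(indicesShape)+len(tableShape)-1)
	out = append(out, indicesShape...)
	out = append(out, tableShape[1:]...)
	return out, nil
}

func main() {
	// A [vocab, hidden] table indexed by a [batch, seq] tensor of token IDs.
	shape, err := gatherOutputShape([]int{50000, 768}, []int{4, 16})
	if err != nil {
		panic(err)
	}
	fmt.Println(shape) // [4 16 768]
}
```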

func (*Gather[T]) Parameters

func (g *Gather[T]) Parameters() []*graph.Parameter[T]

Parameters returns no trainable parameters for the Gather layer.
