Documentation
Overview
Package gather provides the Gather layer for the Zerfoo ML framework.
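Index
- func BuildGather[T tensor.Numeric](engine compute.Engine[T], _ numeric.Arithmetic[T], name string, ...) (graph.Node[T], error)
- type Gather
- func NewWithWeights[T tensor.Numeric](engine compute.Engine[T], weights *tensor.TensorNumeric[T]) *Gather[T]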
- func (g *Gather[T]) Attributes() map[string]interface{}
- func (g *Gather[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (g *Gather[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (g *Gather[T]) HasEmbeddedWeights() bool
- func (g *Gather[T]) OpType() string
- func (g *Gather[T]) OutputShape() []int
- func (g *Gather[T]) Parameters() []*graph.Parameter[T]
Constants
This section is empty.
Variables
This section is empty.
Functions
func BuildGather
func BuildGather[T tensor.Numeric](
	engine compute.Engine[T],
	_ numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	_ map[string]interface{},
) (graph.Node[T], error)
BuildGather constructs a Gather layer. It tries to resolve the embedding weights from params using common parameter naming patterns; if none match, it falls back to a minimal dummy tensor. The numeric.Arithmetic argument and the attributes map are unused (both are blank identifiers).
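The resolve-or-fall-back behaviour described above can be sketched in plain Go. This is a minimal illustration under stated assumptions, not the package's code: Param is a hypothetical stand-in for graph.Parameter[T], the candidate names are hypothetical examples (the actual patterns BuildGather checks are not documented here), and the 1x1 placeholder shape is an assumed stand-in for the "minimal dummy tensor".

```go
package main

import "fmt"

// Param is a hypothetical stand-in for graph.Parameter[T]; only the name and
// shape matter for this sketch.
type Param struct {
	Name  string
	Shape []int
}

// candidateNames lists hypothetical parameter names to try, in priority order.
// These are illustrative guesses, not the patterns BuildGather actually uses.
func candidateNames(layerName string) []string {
	return []string{
		layerName + ".weight",
		layerName,
		"embed_tokens.weight",
		"model.embed_tokens.weight",
	}
}

// resolveWeights returns the first parameter whose name matches a candidate.
// If nothing matches, it returns a 1x1 placeholder (an assumed shape for the
// dummy tensor) and reports false.
func resolveWeights(layerName string, params map[string]*Param) (*Param, bool) {
	for _, name := range candidateNames(layerName) {
		if p, ok := params[name]; ok {
			return p, true
		}
	}
	return &Param{Name: layerName + ".dummy", Shape: []int{1, 1}}, false
}

func main() {
	params := map[string]*Param{
		"model.embed_tokens.weight": {Name: "model.embed_tokens.weight", Shape: []int{32000, 4096}},
	}
	w, found := resolveWeights("gather_0", params)
	fmt.Println(w.Name, w.Shape, found) // model.embed_tokens.weight [32000 4096] true

	w, found = resolveWeights("gather_1", map[string]*Param{})
	fmt.Println(w.Name, w.Shape, found) // gather_1.dummy [1 1] false
}
```

Trying candidates in a fixed order keeps the lookup deterministic even when a checkpoint contains several names that could match.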
Types
type Gather
Gather is a layer that gathers slices from a tensor.
func NewWithWeights
func NewWithWeights[T tensor.Numeric](engine compute.Engine[T], weights *tensor.TensorNumeric[T]) *Gather[T]
NewWithWeights creates a new Gather layer with embedded weights.
func (*Gather[T]) Attributes
func (g *Gather[T]) Attributes() map[string]interface{}
Attributes returns nil for the Gather layer.
func (*Gather[T]) Backward
func (g *Gather[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for the Gather layer.
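For reference, the textbook gradient of a gather with respect to its data (weight) operand is a scatter-add: each row of the output gradient is added back into the row it was gathered from, so repeated indices accumulate. The sketch below shows that rule on flattened row-major slices. It is a general illustration, not necessarily what this Backward returns (Parameters reports no trainable parameters, so embedded weights may not receive gradients at all); scatterAddGrad is a hypothetical name, and integer indices receive no gradient.

```go
package main

import "fmt"

// scatterAddGrad computes dL/dData for an axis-0 gather over a [rows x dim]
// table. gradOut holds one row of dL/dOut per gathered index, row-major.
// Row indices[i] of the result receives gradOut row i; duplicates accumulate.
func scatterAddGrad(gradOut []float32, indices []int, rows, dim int) ([]float32, error) {
	if len(gradOut) != len(indices)*dim {
		return nil, fmt.Errorf("gradOut has %d elements, want %d", len(gradOut), len(indices)*dim)
	}
	gradData := make([]float32, rows*dim) // zero-initialised
	for i, idx := range indices {
		if idx < 0 || idx >= rows {
			return nil, fmt.Errorf("index %d out of range [0, %d)", idx, rows)
		}
		for j := 0; j < dim; j++ {
			gradData[idx*dim+j] += gradOut[i*dim+j]
		}
	}
	return gradData, nil
}

func main() {
	// Indices [2 0 2] gathered from a 3x2 table; row 2 was used twice.
	gradOut := []float32{1, 1, 0.5, 0.5, 2, 3}
	g, err := scatterAddGrad(gradOut, []int{2, 0, 2}, 3, 2)
	if err != nil {
		panic(err)
	}
	fmt.Println(g) // [0.5 0.5 0 0 3 4]
}
```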
func (*Gather[T]) Forward
func (g *Gather[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the gather operation.
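Assuming Forward follows ONNX Gather semantics along axis 0 (an assumption; the signature alone does not say which axis is used or how embedded weights and indices map onto inputs), the core of the operation is a row lookup: output row i is data row indices[i]. In this sketch, Numeric is a hypothetical stand-in for tensor.Numeric, gatherRows is a hypothetical helper, tensors are flattened row-major slices, and negative indices (which ONNX permits) are rejected for simplicity.

```go
package main

import "fmt"

// Numeric is a hypothetical stand-in for tensor.Numeric.
type Numeric interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 | ~float32 | ~float64
}

// gatherRows performs an axis-0 gather on a row-major [rows x dim] table:
// output row i is a copy of data row indices[i].
func gatherRows[T Numeric](data []T, rows, dim int, indices []int) ([]T, error) {
	if len(data) != rows*dim {
		return nil, fmt.Errorf("data has %d elements, want %d", len(data), rows*dim)
	}
	out := make([]T, 0, len(indices)*dim)
	for _, idx := range indices {
		if idx < 0 || idx >= rows {
			return nil, fmt.Errorf("index %d out of range [0, %d)", idx, rows)
		}
		out = append(out, data[idx*dim:(idx+1)*dim]...)
	}
	return out, nil
}

func main() {
	// A 3x2 embedding table; look up rows 2, 0 and 2 again.
	table := []float32{0, 1, 10, 11, 20, 21}
	out, err := gatherRows(table, 3, 2, []int{2, 0, 2})
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // [20 21 0 1 20 21]
}
```

Copying rows with append keeps the output independent of the source table, so later writes to the weights do not alias earlier results.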
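func (*Gather[T]) HasEmbeddedWeights
func (g *Gather[T]) HasEmbeddedWeights() bool
HasEmbeddedWeights reports whether this Gather layer has embedded weights.
func (*Gather[T]) OpType
func (g *Gather[T]) OpType() string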
func (*Gather[T]) OutputShape
func (g *Gather[T]) OutputShape() []int
OutputShape returns the output shape of the Gather layer.
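OutputShape takes no arguments, so it presumably reports a shape the layer already knows, for example one recorded during Forward; that is an assumption. Under ONNX Gather semantics the output shape is data.shape[:axis] followed by indices.shape followed by data.shape[axis+1:], so gathering indices of shape [1, 7] from a [32000, 4096] table along axis 0 gives [1, 7, 4096]. The hypothetical helper below computes that rule directly.

```go
package main

import "fmt"

// gatherOutputShape applies the ONNX Gather shape rule:
// out = dataShape[:axis] ++ indicesShape ++ dataShape[axis+1:].
// Negative axes count from the end, as in ONNX.
func gatherOutputShape(dataShape, indicesShape []int, axis int) ([]int, error) {
	r := len(dataShape)
	if axis < 0 {
		axis += r
	}
	if axis < 0 || axis >= r {
		return nil, fmt.Errorf("axis out of range for rank-%d data", r)
	}
	out := make([]int, 0, r-1+len(indicesShape))
	out = append(out, dataShape[:axis]...)
	out = append(out, indicesShape...)
	out = append(out, dataShape[axis+1:]...)
	return out, nil
}

func main() {
	shape, err := gatherOutputShape([]int{32000, 4096}, []int{1, 7}, 0)
	if err != nil {
		panic(err)
	}
	fmt.Println(shape) // [1 7 4096]
}
```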
func (*Gather[T]) Parameters
func (g *Gather[T]) Parameters() []*graph.Parameter[T]
Parameters returns no trainable parameters for the Gather layer.