Documentation ¶
Overview ¶
Package attention provides attention mechanisms for neural networks.
Index ¶
- func BuildGroupQueryAttention[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], name string, ...) (graph.Node[T], error)
- func QKNorm[T tensor.Numeric](_ context.Context, _ compute.Engine[T], q, k *tensor.TensorNumeric[T], ...) (qNorm, kNorm *tensor.TensorNumeric[T], err error)
- type AttentionHead
- func (ah *AttentionHead[T]) Attributes() map[string]interface{}
- func (ah *AttentionHead[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (ah *AttentionHead[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (ah *AttentionHead[T]) OpType() string
- func (ah *AttentionHead[T]) OutputShape() []int
- func (ah *AttentionHead[T]) Parameters() []*graph.Parameter[T]
- type AttentionHeadOption
- type AttentionHeadOptions
- type GQAOption
- type GQAOptions
- type GlobalAttention
- func (ga *GlobalAttention[T]) Attributes() map[string]interface{}
- func (ga *GlobalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (ga *GlobalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (ga *GlobalAttention[T]) OpType() string
- func (ga *GlobalAttention[T]) OutputShape() []int
- func (ga *GlobalAttention[T]) Parameters() []*graph.Parameter[T]
- func (ga *GlobalAttention[T]) ScaleRope(ctx context.Context, factor float64) error
- type GlobalAttentionOption
- type GlobalAttentionOptions
- type GroupedQueryAttention
- func (gqa *GroupedQueryAttention[T]) Attributes() map[string]interface{}
- func (gqa *GroupedQueryAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (gqa *GroupedQueryAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (gqa *GroupedQueryAttention[T]) OpType() string
- func (gqa *GroupedQueryAttention[T]) OutputShape() []int
- func (gqa *GroupedQueryAttention[T]) Parameters() []*graph.Parameter[T]
- func (gqa *GroupedQueryAttention[T]) ScaleRope(ctx context.Context, factor float64) error
- type LocalAttention
- func (la *LocalAttention[T]) Attributes() map[string]interface{}
- func (la *LocalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (la *LocalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (la *LocalAttention[T]) OpType() string
- func (la *LocalAttention[T]) OutputShape() []int
- func (la *LocalAttention[T]) Parameters() []*graph.Parameter[T]
- type LocalAttentionOption
- type LocalAttentionOptions
- type RopeScaler
- type ScaledDotProductAttention
- type ScaledDotProductAttentionOption
- type ScaledDotProductAttentionOptions
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func BuildGroupQueryAttention ¶ added in v0.3.0
func BuildGroupQueryAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attributes map[string]interface{},
) (graph.Node[T], error)
BuildGroupQueryAttention constructs a GroupedQueryAttention node for the model builder. Unused parameters are accepted to satisfy the common builder signature.
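For illustration only, a hypothetical call is sketched below. The expected keys of params and attributes are not documented here, so empty maps and the layer name "blk0.attn" are placeholders, and engine and ops are assumed to be created with the library's own constructors (not shown).

var (
	engine compute.Engine[float32]     // assumed: a concrete compute engine
	ops    numeric.Arithmetic[float32] // assumed: matching arithmetic ops
)

node, err := attention.BuildGroupQueryAttention(engine, ops, "blk0.attn",
	map[string]*graph.Parameter[float32]{},
	map[string]interface{}{},
)
if err != nil {
	// handle error
}
_ = node // a graph.Node[float32] wrapping a GroupedQueryAttention layer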
func QKNorm ¶ added in v0.3.0
func QKNorm[T tensor.Numeric](_ context.Context, _ compute.Engine[T], q, k *tensor.TensorNumeric[T], epsilon float64) (qNorm, kNorm *tensor.TensorNumeric[T], err error)
QKNorm applies a form of normalization to Query (Q) and Key (K) tensors to stabilize attention score scales, similar to RMSNorm. It normalizes Q and K independently by their respective RMS values.
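A minimal usage sketch follows. The engine and the projected q and k tensors are assumed to come from elsewhere, and the epsilon value 1e-6 is an arbitrary illustrative choice.

var (
	engine compute.Engine[float32]        // assumed: a concrete compute engine
	q, k   *tensor.TensorNumeric[float32] // assumed: projected query/key tensors
)

qNorm, kNorm, err := attention.QKNorm(context.Background(), engine, q, k, 1e-6)
if err != nil {
	// handle error
}
_, _ = qNorm, kNorm // independently RMS-normalized Q and K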
Types ¶
type AttentionHead ¶ added in v0.3.0
AttentionHead implements a single attention head, including linear projections for Query, Key, and Value, followed by scaled dot-product attention.
func NewAttentionHead ¶ added in v0.3.0
func NewAttentionHead[T tensor.Numeric](engine compute.Engine[T], inputDim, headDim int, opts ...AttentionHeadOption[T]) *AttentionHead[T]
NewAttentionHead creates a new AttentionHead instance. inputDim is the dimension of the input features. headDim is the dimension of the query, key, and value vectors for this head.
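A minimal sketch, assuming an existing engine and a (batch, seq_len, 512) input tensor; the dimensions 512 and 64 are arbitrary illustrative values.

var (
	engine compute.Engine[float32]        // assumed: a concrete compute engine
	x      *tensor.TensorNumeric[float32] // assumed shape: (batch, seq_len, 512)
)

head := attention.NewAttentionHead(engine, 512, 64)
out, err := head.Forward(context.Background(), x)
if err != nil {
	// handle error
}
_ = out // per OutputShape: (batch, seq_len, 64)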
func (*AttentionHead[T]) Attributes ¶ added in v0.3.0
func (ah *AttentionHead[T]) Attributes() map[string]interface{}
Attributes returns the attributes for the AttentionHead.
func (*AttentionHead[T]) Backward ¶ added in v0.3.0
func (ah *AttentionHead[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for the AttentionHead. dOut has shape (batch, seq_len, head_dim). inputs[0] has shape (batch, seq_len, input_dim).
func (*AttentionHead[T]) Forward ¶ added in v0.3.0
func (ah *AttentionHead[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the output of the attention head. input is expected to be a 3D tensor (batch_size, seq_len, input_dim).
func (*AttentionHead[T]) OpType ¶ added in v0.3.0
func (ah *AttentionHead[T]) OpType() string
OpType returns the operation type of the AttentionHead.
func (*AttentionHead[T]) OutputShape ¶ added in v0.3.0
func (ah *AttentionHead[T]) OutputShape() []int
OutputShape returns the output shape of the AttentionHead. It assumes the input shape is (batch_size, seq_len, input_dim). The output shape will be (batch_size, seq_len, head_dim).
func (*AttentionHead[T]) Parameters ¶ added in v0.3.0
func (ah *AttentionHead[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters of the AttentionHead.
type AttentionHeadOption ¶ added in v0.3.0
type AttentionHeadOption[T tensor.Numeric] func(*AttentionHeadOptions[T])
AttentionHeadOption applies an option to AttentionHeadOptions.
type AttentionHeadOptions ¶ added in v0.3.0
AttentionHeadOptions holds configuration options for AttentionHead.
type GQAOption ¶ added in v0.3.0
type GQAOption[T tensor.Numeric] func(*GQAOptions[T])
GQAOption is a function that applies an option to GQAOptions.
func WithMaxSeqLen ¶ added in v0.3.0
func WithMaxSeqLen[T tensor.Numeric](maxSeqLen int) GQAOption[T]
WithMaxSeqLen sets the maximum sequence length for Rotary Positional Embeddings.
type GQAOptions ¶ added in v0.3.0
GQAOptions holds configuration options for the GroupedQueryAttention layer.
type GlobalAttention ¶ added in v0.3.0
GlobalAttention wraps GroupedQueryAttention to provide a global attention interface.
func NewGlobalAttention ¶ added in v0.3.0
func NewGlobalAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads int,
	options ...GlobalAttentionOption,
) (*GlobalAttention[T], error)
NewGlobalAttention creates a new GlobalAttention layer.
Parameters:
- engine: compute engine for tensor operations
- ops: arithmetic operations for the numeric type
- modelDim: model dimension
- numQueryHeads: number of query heads
- numKeyValueHeads: number of key/value heads
- options: functional options for configuration

Default values:
- base: 10000.0
- maxSeqLen: 2048
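A minimal sketch that overrides both defaults; the dimensions, head counts, and option values are illustrative, and engine and ops are assumed to be constructed elsewhere.

var (
	engine compute.Engine[float32]     // assumed: a concrete compute engine
	ops    numeric.Arithmetic[float32] // assumed: matching arithmetic ops
)

ga, err := attention.NewGlobalAttention(engine, ops, 1024, 16, 4,
	attention.WithGlobalAttentionBase(500000.0),
	attention.WithGlobalAttentionMaxSeqLen(8192),
)
if err != nil {
	// handle error
}
_ = ga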
func NewGlobalAttentionFromParams ¶ added in v0.3.0
func NewGlobalAttentionFromParams[T tensor.Numeric](gqa *GroupedQueryAttention[T]) *GlobalAttention[T]
NewGlobalAttentionFromParams creates a new GlobalAttention layer from an existing GroupedQueryAttention layer.
func (*GlobalAttention[T]) Attributes ¶ added in v0.3.0
func (ga *GlobalAttention[T]) Attributes() map[string]interface{}
Attributes returns the attributes.
func (*GlobalAttention[T]) Backward ¶ added in v0.3.0
func (ga *GlobalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward delegates the backward pass to the wrapped GroupedQueryAttention.
func (*GlobalAttention[T]) Forward ¶ added in v0.3.0
func (ga *GlobalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the forward pass of the GlobalAttention layer.
func (*GlobalAttention[T]) OpType ¶ added in v0.3.0
func (ga *GlobalAttention[T]) OpType() string
OpType returns the operation type.
func (*GlobalAttention[T]) OutputShape ¶ added in v0.3.0
func (ga *GlobalAttention[T]) OutputShape() []int
OutputShape returns the output shape of the GlobalAttention layer.
func (*GlobalAttention[T]) Parameters ¶ added in v0.3.0
func (ga *GlobalAttention[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the GlobalAttention layer.
type GlobalAttentionOption ¶ added in v0.3.0
type GlobalAttentionOption func(*GlobalAttentionOptions)
GlobalAttentionOption is a function that configures GlobalAttentionOptions.
func WithGlobalAttentionBase ¶ added in v0.3.0
func WithGlobalAttentionBase(base float64) GlobalAttentionOption
WithGlobalAttentionBase sets the base (theta) parameter for rotary positional embeddings.
func WithGlobalAttentionMaxSeqLen ¶ added in v0.3.0
func WithGlobalAttentionMaxSeqLen(maxSeqLen int) GlobalAttentionOption
WithGlobalAttentionMaxSeqLen sets the maximum sequence length.
type GlobalAttentionOptions ¶ added in v0.3.0
GlobalAttentionOptions holds configuration options for the GlobalAttention layer.
type GroupedQueryAttention ¶
GroupedQueryAttention implements the grouped query attention (GQA) mechanism.
func NewGroupedQueryAttention ¶
func NewGroupedQueryAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads int,
	opts ...GQAOption[T],
) (*GroupedQueryAttention[T], error)
NewGroupedQueryAttention creates a new GroupedQueryAttention layer.
- modelDim: the dimension of the input and output of the block (d_model).
- numQueryHeads: the number of query heads.
- numKeyValueHeads: the number of key/value heads.
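A minimal sketch: a 512-dimensional block with 8 query heads sharing 2 key/value heads. The engine, ops, and input tensor are assumed placeholders, and the dimensions are illustrative.

var (
	engine compute.Engine[float32]        // assumed: a concrete compute engine
	ops    numeric.Arithmetic[float32]    // assumed: matching arithmetic ops
	x      *tensor.TensorNumeric[float32] // assumed shape: (batch, seq_len, 512)
)

gqa, err := attention.NewGroupedQueryAttention(engine, ops, 512, 8, 2)
if err != nil {
	// handle error
}
out, err := gqa.Forward(context.Background(), x)
if err != nil {
	// handle error
}
_ = out // the attended output of the block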
func NewGroupedQueryAttentionFromParams ¶ added in v0.3.0
func NewGroupedQueryAttentionFromParams[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads int,
	wq, wk, wv, wo *core.Dense[T],
	rope *embeddings.RotaryPositionalEmbedding[T],
) (*GroupedQueryAttention[T], error)
NewGroupedQueryAttentionFromParams creates a new GroupedQueryAttention layer from existing parameters.
func (*GroupedQueryAttention[T]) Attributes ¶ added in v0.3.0
func (gqa *GroupedQueryAttention[T]) Attributes() map[string]interface{}
Attributes returns the attributes.
func (*GroupedQueryAttention[T]) Backward ¶
func (gqa *GroupedQueryAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for GroupedQueryAttention.
func (*GroupedQueryAttention[T]) Forward ¶
func (gqa *GroupedQueryAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the grouped query attention.
func (*GroupedQueryAttention[T]) OpType ¶ added in v0.3.0
func (gqa *GroupedQueryAttention[T]) OpType() string
OpType returns the operation type.
func (*GroupedQueryAttention[T]) OutputShape ¶
func (gqa *GroupedQueryAttention[T]) OutputShape() []int
OutputShape returns the output shape of the GroupedQueryAttention.
func (*GroupedQueryAttention[T]) Parameters ¶
func (gqa *GroupedQueryAttention[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the GroupedQueryAttention layer.
type LocalAttention ¶ added in v0.3.0
LocalAttention implements a local, sliding-window self-attention mechanism.
func NewLocalAttention ¶ added in v0.3.0
func NewLocalAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads, windowSize int,
	opts ...LocalAttentionOption[T],
) (*LocalAttention[T], error)
NewLocalAttention creates a new LocalAttention layer.
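A minimal sketch with a 256-token window and explicit RoPE options; all numeric values are illustrative, and engine and ops are assumed placeholders.

var (
	engine compute.Engine[float32]     // assumed: a concrete compute engine
	ops    numeric.Arithmetic[float32] // assumed: matching arithmetic ops
)

la, err := attention.NewLocalAttention(engine, ops, 512, 8, 2, 256,
	attention.WithLocalMaxSeqLen[float32](4096),
	attention.WithLocalRopeBase[float32](10000.0),
)
if err != nil {
	// handle error
}
_ = la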
func (*LocalAttention[T]) Attributes ¶ added in v0.3.0
func (la *LocalAttention[T]) Attributes() map[string]interface{}
Attributes returns the attributes of the LocalAttention layer.
func (*LocalAttention[T]) Backward ¶ added in v0.3.0
func (la *LocalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward delegates the backward pass to the wrapped GroupedQueryAttention.
func (*LocalAttention[T]) Forward ¶ added in v0.3.0
func (la *LocalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the forward pass of the LocalAttention layer.
func (*LocalAttention[T]) OpType ¶ added in v0.3.0
func (la *LocalAttention[T]) OpType() string
OpType returns the operation type of the LocalAttention layer.
func (*LocalAttention[T]) OutputShape ¶ added in v0.3.0
func (la *LocalAttention[T]) OutputShape() []int
OutputShape returns the output shape of the LocalAttention layer.
func (*LocalAttention[T]) Parameters ¶ added in v0.3.0
func (la *LocalAttention[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the LocalAttention layer.
type LocalAttentionOption ¶ added in v0.3.0
type LocalAttentionOption[T tensor.Numeric] func(*LocalAttentionOptions[T])
LocalAttentionOption is a function that applies an option to LocalAttentionOptions.
func WithLocalMaxSeqLen ¶ added in v0.3.0
func WithLocalMaxSeqLen[T tensor.Numeric](maxSeqLen int) LocalAttentionOption[T]
WithLocalMaxSeqLen sets the maximum sequence length for Rotary Positional Embeddings.
maxSeqLen: The maximum sequence length for Rotary Positional Embeddings.
func WithLocalRopeBase ¶ added in v0.3.0
func WithLocalRopeBase[T tensor.Numeric](base float64) LocalAttentionOption[T]
WithLocalRopeBase sets the base for Rotary Positional Embeddings.
base: The base for Rotary Positional Embeddings.
type LocalAttentionOptions ¶ added in v0.3.0
LocalAttentionOptions holds configuration options for the LocalAttention layer.
type RopeScaler ¶ added in v0.3.0
type RopeScaler[T tensor.Numeric] interface {
	ScaleRope(ctx context.Context, factor float64) error
}
RopeScaler is an interface for layers that support scaling of RoPE.
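GlobalAttention and GroupedQueryAttention expose ScaleRope and therefore satisfy this interface. Below is a minimal sketch of scaling through the interface; the factor 2.0 is illustrative and the layer is assumed to have been constructed earlier.

var ga *attention.GlobalAttention[float32] // assumed: constructed earlier

var scaler attention.RopeScaler[float32] = ga
if err := scaler.ScaleRope(context.Background(), 2.0); err != nil {
	// handle error
}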
type ScaledDotProductAttention ¶
type ScaledDotProductAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}
ScaledDotProductAttention implements the scaled dot-product attention mechanism.
func NewScaledDotProductAttention ¶
func NewScaledDotProductAttention[T tensor.Numeric](engine compute.Engine[T], headDim int, opts ...ScaledDotProductAttentionOption[T]) *ScaledDotProductAttention[T]
NewScaledDotProductAttention creates a new ScaledDotProductAttention layer.
func (*ScaledDotProductAttention[T]) Backward ¶
func (sdpa *ScaledDotProductAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut, _, _, _ *tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for ScaledDotProductAttention. dOut is the gradient from the subsequent layer.
func (*ScaledDotProductAttention[T]) Forward ¶
func (sdpa *ScaledDotProductAttention[T]) Forward(ctx context.Context, q, k, v, mask *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the scaled dot-product attention. Q, K, V are expected to be 3D tensors (batch_size, seq_len, head_dim). mask is an optional 4D tensor (batch_size, num_heads, seq_len_q, seq_len_k).
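A minimal sketch over pre-projected Q, K, and V. The tensors and engine are assumed placeholders; since the mask is optional, nil is passed here on the assumption that it means no masking.

var (
	engine  compute.Engine[float32]        // assumed: a concrete compute engine
	q, k, v *tensor.TensorNumeric[float32] // assumed shape: (batch, seq_len, 64)
)

sdpa := attention.NewScaledDotProductAttention(engine, 64)
out, err := sdpa.Forward(context.Background(), q, k, v, nil)
if err != nil {
	// handle error
}
_ = out // attention-weighted values, shape (batch, seq_len, 64)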
type ScaledDotProductAttentionOption ¶ added in v0.3.0
type ScaledDotProductAttentionOption[T tensor.Numeric] func(*ScaledDotProductAttentionOptions[T])
ScaledDotProductAttentionOption applies an option to ScaledDotProductAttentionOptions.
type ScaledDotProductAttentionOptions ¶ added in v0.3.0
ScaledDotProductAttentionOptions holds configuration options for ScaledDotProductAttention.