Documentation
¶
Index ¶
- Constants
- Variables
- type Cache
- type Causal
- func (c *Causal) CanResume(seq int, pos int32) bool
- func (c *Causal) Close()
- func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32)
- func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
- func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
- func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor)
- func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error
- func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions)
- func (c *Causal) SetConfig(config ml.CacheConfig)
- func (c *Causal) SetLayer(layer int)
- func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
- type CausalOptions
- type CheckpointCache
- type EncoderCache
- func (c *EncoderCache) CanResume(seq int, pos int32) bool
- func (c *EncoderCache) Close()
- func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32)
- func (c *EncoderCache) EncoderCached() bool
- func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
- func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
- func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor)
- func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error
- func (c *EncoderCache) SetConfig(config ml.CacheConfig)
- func (c *EncoderCache) SetLayer(layer int)
- func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
- type Recurrent
- func (c *Recurrent) CanResume(seq int, pos int32) bool
- func (c *Recurrent) Close()
- func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error)
- func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32)
- func (c *Recurrent) EnsureWritable(ctx ml.Context) error
- func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
- func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
- func (c *Recurrent) IsSupportedForBatch() bool
- func (c *Recurrent) NumSeqs() int
- func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool)
- func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor)
- func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error)
- func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error)
- func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error
- func (c *Recurrent) SeqTokens() int
- func (c *Recurrent) Seqs() []int
- func (c *Recurrent) SetConfig(config ml.CacheConfig)
- func (c *Recurrent) SetLayer(layer int)
- func (c *Recurrent) SlotsTensor() ml.Tensor
- func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
- func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor)
- func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor)
- type RecurrentConfig
- type WrapperCache
- func (c *WrapperCache) CanResume(seq int, pos int32) bool
- func (c *WrapperCache) Close()
- func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32)
- func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
- func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
- func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor)
- func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error
- func (c *WrapperCache) SetConfig(config ml.CacheConfig)
- func (c *WrapperCache) SetLayer(layer int)
- func (c *WrapperCache) SetLayerType(layerType int)
- func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
- func (c *WrapperCache) UnderlyingCache() Cache
Constants ¶
const ( DefaultCheckpointCount = 24 DefaultCheckpointMinPos = int32(16) DefaultCheckpointInterval = int32(1664) )
Variables ¶
var ( ErrKvCacheFull = errors.New("could not find a kv cache slot") ErrNotSupported = errors.New("model does not support operation") )
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
Functions ¶
This section is empty.
Types ¶
type Cache ¶
type Cache interface {
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
type Causal ¶
Causal cache stores K and V tensors according to their position in the sequence. Returns the history and a mask for attending to past tokens
The tensors are of shape embed dim, kv heads, batch size. The mask is of shape history size, batch size.
func NewCausalCache ¶
func NewCausalCache(shift shiftFn) *Causal
func NewChunkedAttentionCache ¶ added in v0.6.7
func NewSWACache ¶
func NewSWAMemCache ¶ added in v0.11.0
func (*Causal) CopyPrefix ¶
func (*Causal) SetCausal ¶ added in v0.6.0
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions)
SetCausal disables causal mask generation for a particular range of indices in the current batch for subsequent calls to Get. The state resets for the next forward pass.
func (*Causal) SetConfig ¶ added in v0.5.13
func (c *Causal) SetConfig(config ml.CacheConfig)
type CausalOptions ¶ added in v0.6.0
type CausalOptions struct {
	// Except lists the batch indices for which the causal mask is not generated
Except []int
}
type CheckpointCache ¶ added in v0.15.5
CheckpointCache optionally supports restoring recurrent state to a prior position to avoid full prompt reprocessing when a prefix mismatch occurs. The returned position is the number of tokens that can be kept (prefix length).
type EncoderCache ¶
type EncoderCache struct {
// contains filtered or unexported fields
}
Encoder cache stores K and V tensors that are position independent
The tensors can be of any shape and will be returned as they were stored. The mask is currently always nil.
Not currently safe for multiple sequences
func NewEncoderCache ¶
func NewEncoderCache() *EncoderCache
func (*EncoderCache) CanResume ¶ added in v0.6.4
func (c *EncoderCache) CanResume(seq int, pos int32) bool
func (*EncoderCache) Close ¶
func (c *EncoderCache) Close()
func (*EncoderCache) CopyPrefix ¶
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32)
func (*EncoderCache) EncoderCached ¶
func (c *EncoderCache) EncoderCached() bool
func (*EncoderCache) Remove ¶
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error
func (*EncoderCache) SetConfig ¶ added in v0.5.13
func (c *EncoderCache) SetConfig(config ml.CacheConfig)
func (*EncoderCache) SetLayer ¶
func (c *EncoderCache) SetLayer(layer int)
func (*EncoderCache) StartForward ¶
type Recurrent ¶ added in v0.17.1
type Recurrent struct {
// contains filtered or unexported fields
}
Cache stores:
- a standard causal KV cache
- per-sequence conv state for recurrent operators
- per-sequence recurrent state for recurrent operators
Conv state shape (per layer, per sequence): [convDim, convChannels]. Recurrent state shape (per layer, per sequence): [recurrentStateSize].
func NewRecurrentCache ¶ added in v0.17.1
func NewRecurrentCache(config RecurrentConfig) *Recurrent
func (*Recurrent) ConvState ¶ added in v0.17.1
ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (*Recurrent) CopyPrefix ¶ added in v0.17.1
func (*Recurrent) EnsureWritable ¶ added in v0.17.1
EnsureWritable ensures sequences have private slots (copy-on-write).
func (*Recurrent) IsSupportedForBatch ¶ added in v0.17.1
IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (*Recurrent) PrepareRestore ¶ added in v0.17.1
func (*Recurrent) RecurrentState ¶ added in v0.17.1
RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
func (*Recurrent) RecurrentState4D ¶ added in v0.17.1
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error)
RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
func (*Recurrent) Seqs ¶ added in v0.17.1
Seqs returns the ordered unique sequences for the current forward pass.
func (*Recurrent) SetConfig ¶ added in v0.17.1
func (c *Recurrent) SetConfig(config ml.CacheConfig)
func (*Recurrent) SlotsTensor ¶ added in v0.17.1
func (*Recurrent) StartForward ¶ added in v0.17.1
func (*Recurrent) UpdateConvState ¶ added in v0.17.1
UpdateConvState writes new conv state for current batch sequences.
type RecurrentConfig ¶ added in v0.17.1
type RecurrentConfig struct {
Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
ConvDim int
ConvChannels int
RecurrentStateSize int
CheckpointLogPrefix string
}
RecurrentConfig configures a shared hybrid recurrent cache.
type WrapperCache ¶
type WrapperCache struct {
// contains filtered or unexported fields
}
Wrapper cache is a container for multiple types of caches, such as for the encoding and decoding portions of a model.
func NewWrapperCache ¶
func NewWrapperCache(caches ...Cache) *WrapperCache
func (*WrapperCache) CanResume ¶ added in v0.6.4
func (c *WrapperCache) CanResume(seq int, pos int32) bool
func (*WrapperCache) Close ¶
func (c *WrapperCache) Close()
func (*WrapperCache) CopyPrefix ¶
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32)
func (*WrapperCache) Remove ¶
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error
func (*WrapperCache) SetConfig ¶ added in v0.5.13
func (c *WrapperCache) SetConfig(config ml.CacheConfig)
func (*WrapperCache) SetLayer ¶
func (c *WrapperCache) SetLayer(layer int)
func (*WrapperCache) SetLayerType ¶
func (c *WrapperCache) SetLayerType(layerType int)
func (*WrapperCache) StartForward ¶
func (*WrapperCache) UnderlyingCache ¶
func (c *WrapperCache) UnderlyingCache() Cache