kvcache

package
v0.15.4 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Feb 1, 2026 License: MIT Imports: 6 Imported by: 11

Documentation

Index

Constants

This section is empty.

Variables

View Source
var (
	// ErrKvCacheFull is returned when no free slot can be found in the kv cache.
	ErrKvCacheFull  = errors.New("could not find a kv cache slot")
	// ErrNotSupported is returned when the model does not support the requested operation.
	ErrNotSupported = errors.New("model does not support operation")
)

Functions

This section is empty.

Types

type Cache

// Cache is the interface implemented by all kv cache variants
// (Causal, EncoderCache, WrapperCache).
type Cache interface {

	// SetLayer sets the active layer of the cache.
	SetLayer(layer int)

	// Get returns the history of key and value tensors plus a mask,
	// in that order: key, value, mask.
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

	// Put stores a batch of key and value tensors in the cache.
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Put(ctx ml.Context, key, value ml.Tensor)

	// SetConfig controls optimizations (mostly backend-specific) that may transform
	// the output of the cache to work better with specific kernels. If not called,
	// the backend settings will be used. This works well when calling Attention.
	//
	// The config can be overridden by models, especially if they require vanilla
	// output when implementing their own version of attention. To do this, pass
	// an empty ml.CacheConfig.
	//
	// Most models will not need to use this.
	SetConfig(ml.CacheConfig)

	// Init sets up runtime parameters.
	// backend: Used to allocate cache data storage and execute management operations (such as defrag)
	// dtype: The data type for storing cache entries
	// maxSequences: The maximum number of sequences stored in the cache - across all batches
	// capacity: The number of cache entries to store, per sequence
	// maxBatch: The maximum number of tokens that can occur in a single batch
	Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

	// Close closes the cache and frees resources associated with it.
	Close()

	// StartForward is called before the start of the model's forward pass.
	// For each token in the coming batch, there must be a corresponding
	// entry in positions and seqs. reserve is to preallocate memory
	// without actually storing data in the cache.
	StartForward(ctx ml.Context, batch input.Batch, reserve bool) error

	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq.
	CopyPrefix(srcSeq, dstSeq int, len int32)

	// CanResume returns true if the cache can continue with the next token at
	// the given position and sequence. Assumes that the caller has already
	// verified the contents of the cache.
	CanResume(seq int, pos int32) bool

	// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
	// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
	//
	// If an error occurs, the entire context for the sequence should be
	// removed by calling Remove(seq, 0, math.MaxInt32).
	Remove(seq int, beginIndex, endIndex int32) error
}

type Causal

// Causal is a kv cache that stores K and V tensors according to their
// position in the sequence, returning the history and a mask for
// attending to past tokens.
type Causal struct {
	// DType is the data type used for storing cache entries.
	DType ml.DType
	// contains filtered or unexported fields
}

Causal cache stores K and V tensors according to their position in the sequence. It returns the history and a mask for attending to past tokens.

The tensors are of shape (embed dim, kv heads, batch size). The mask is of shape (history size, batch size).

func NewCausalCache

func NewCausalCache(shift shiftFn) *Causal

func NewChunkedAttentionCache added in v0.6.7

func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal

func NewSWACache

func NewSWACache(windowSize int32, shift shiftFn) *Causal

func NewSWAMemCache added in v0.11.0

func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal

func (*Causal) CanResume added in v0.6.4

func (c *Causal) CanResume(seq int, pos int32) bool

func (*Causal) Close

func (c *Causal) Close()

func (*Causal) CopyPrefix

func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32)

func (*Causal) Get

func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

func (*Causal) Init

func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

func (*Causal) Put

func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor)

func (*Causal) Remove

func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error

func (*Causal) SetCausal added in v0.6.0

func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions)

SetCausal disables causal mask generation for a particular range of indices in the current batch for subsequent calls to Get. The state resets for the next forward pass.

func (*Causal) SetConfig added in v0.5.13

func (c *Causal) SetConfig(config ml.CacheConfig)

func (*Causal) SetLayer

func (c *Causal) SetLayer(layer int)

func (*Causal) StartForward

func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error

type CausalOptions added in v0.6.0

// CausalOptions configures causal mask generation for a batch; see
// Causal.SetCausal.
type CausalOptions struct {
	// Except lists the batch indices for which the causal mask is NOT
	// generated (causal masking is disabled for these indices).
	// NOTE(review): the previous comment referred to a removed/renamed
	// "Enabled" field; verify the exact semantics against SetCausal.
	Except []int
}

type EncoderCache

// EncoderCache is a kv cache that stores K and V tensors that are
// position independent. Tensors are returned as they were stored and
// the mask is always nil. Not currently safe for multiple sequences.
type EncoderCache struct {
	// contains filtered or unexported fields
}

Encoder cache stores K and V tensors that are position independent

The tensors can be of any shape and will be returned as they were stored. The mask is currently always nil.

Not currently safe for multiple sequences

func NewEncoderCache

func NewEncoderCache() *EncoderCache

func (*EncoderCache) CanResume added in v0.6.4

func (c *EncoderCache) CanResume(seq int, pos int32) bool

func (*EncoderCache) Close

func (c *EncoderCache) Close()

func (*EncoderCache) CopyPrefix

func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32)

func (*EncoderCache) EncoderCached

func (c *EncoderCache) EncoderCached() bool

func (*EncoderCache) Get

func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

func (*EncoderCache) Init

func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

func (*EncoderCache) Put

func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor)

func (*EncoderCache) Remove

func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error

func (*EncoderCache) SetConfig added in v0.5.13

func (c *EncoderCache) SetConfig(config ml.CacheConfig)

func (*EncoderCache) SetLayer

func (c *EncoderCache) SetLayer(layer int)

func (*EncoderCache) StartForward

func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error

type WrapperCache

// WrapperCache is a container for multiple types of caches, such as
// for the encoding and decoding portions of a model.
type WrapperCache struct {
	// contains filtered or unexported fields
}

Wrapper cache is a container for multiple types of caches, such as for the encoding and decoding portions of a model.

func NewWrapperCache

func NewWrapperCache(caches ...Cache) *WrapperCache

func (*WrapperCache) CanResume added in v0.6.4

func (c *WrapperCache) CanResume(seq int, pos int32) bool

func (*WrapperCache) Close

func (c *WrapperCache) Close()

func (*WrapperCache) CopyPrefix

func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32)

func (*WrapperCache) Get

func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

func (*WrapperCache) Init

func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

func (*WrapperCache) Put

func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor)

func (*WrapperCache) Remove

func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error

func (*WrapperCache) SetConfig added in v0.5.13

func (c *WrapperCache) SetConfig(config ml.CacheConfig)

func (*WrapperCache) SetLayer

func (c *WrapperCache) SetLayer(layer int)

func (*WrapperCache) SetLayerType

func (c *WrapperCache) SetLayerType(layerType int)

func (*WrapperCache) StartForward

func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error

func (*WrapperCache) UnderlyingCache

func (c *WrapperCache) UnderlyingCache() Cache

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL