Documentation ¶
Overview ¶
Package generate provides text generation utilities for LLMs.
It implements sampling strategies and inference pipelines for autoregressive text generation.
Index ¶
- type GenerateConfig
- func DefaultGenerateConfig() GenerateConfig
- type GenerateResult
- type GeneratorOption
- func WithMaxSeqLen(n int) GeneratorOption
- type KVCache
- type LLMModel
- type Sampler
- func NewSampler(config SamplingConfig) *Sampler
- func (s *Sampler) Sample(logits *tensor.RawTensor, previousTokens []int32) int32
- type SamplingConfig
- func DefaultSamplingConfig() SamplingConfig
- type SpeculativeConfig
- func DefaultSpeculativeConfig(draftModel, targetModel LLMModel) SpeculativeConfig
- type SpeculativeGenerator
- func NewSpeculativeGenerator(config SpeculativeConfig) *SpeculativeGenerator
- func (sg *SpeculativeGenerator) AcceptanceRate() float32
- func (sg *SpeculativeGenerator) ClearCaches()
- func (sg *SpeculativeGenerator) Generate(inputIDs []int32, maxTokens int) ([]int32, float32, error)
- func (sg *SpeculativeGenerator) SetCaches(draft, target KVCache)
- func (sg *SpeculativeGenerator) Stats() (drafted, accepted int, rate float32)
- type TextGenerator
- func NewTextGenerator(model LLMModel, tok tokenizer.Tokenizer, samplingConfig SamplingConfig, ...) *TextGenerator
- func (g *TextGenerator) Chat(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, ...) (string, error)
- func (g *TextGenerator) ChatStream(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, ...) (<-chan GenerateResult, error)
- func (g *TextGenerator) ClearCache()
- func (g *TextGenerator) Generate(prompt string, config GenerateConfig) (string, error)
- func (g *TextGenerator) GenerateStream(prompt string, config GenerateConfig) (<-chan GenerateResult, error)
- func (g *TextGenerator) SetCache(cache KVCache)
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
This section is empty.
Types ¶
type GenerateConfig ¶
type GenerateConfig struct {
    // MaxTokens is the maximum number of tokens to generate.
    MaxTokens int
    // MinTokens is the minimum number of tokens before stopping.
    MinTokens int
    // StopStrings are strings that trigger stopping.
    StopStrings []string
    // StopTokens are token IDs that trigger stopping.
    StopTokens []int32
    // Stream enables streaming generation.
    Stream bool
    // EchoPrompt includes the prompt in output.
    EchoPrompt bool
    // Sampling is the sampling configuration.
    Sampling SamplingConfig
}
GenerateConfig configures text generation.
func DefaultGenerateConfig ¶
func DefaultGenerateConfig() GenerateConfig
DefaultGenerateConfig returns sensible defaults for generation.
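A common pattern is to start from the defaults and override individual fields; the values below are illustrative only:
    cfg := DefaultGenerateConfig()
    cfg.MaxTokens = 256
    cfg.StopStrings = []string{"\n\n"}
    cfg.Sampling.Temperature = 0.7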
type GenerateResult ¶
type GenerateResult struct {
    Token   string // Decoded token text
    TokenID int32  // Token ID
    Done    bool   // Is generation complete
    Reason  string // Stop reason: "eos", "max_tokens", "stop_string", "stop_token"
    Error   error  // Error if any
}
GenerateResult is a single result from streaming generation.
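For illustration, a consumer loop over the streaming channel could look like this (g stands for a *TextGenerator constructed elsewhere):
    results, err := g.GenerateStream("Hello", DefaultGenerateConfig())
    if err != nil {
        log.Fatal(err)
    }
    for r := range results {
        if r.Error != nil {
            log.Fatal(r.Error)
        }
        fmt.Print(r.Token)
        if r.Done {
            fmt.Printf("\n(stopped: %s)\n", r.Reason)
        }
    }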
type GeneratorOption ¶
type GeneratorOption func(*generatorOptions)
GeneratorOption configures a TextGenerator.
func WithMaxSeqLen ¶
func WithMaxSeqLen(n int) GeneratorOption
WithMaxSeqLen sets the maximum sequence length.
type KVCache ¶
type KVCache interface {
    Clear()
}
KVCache is an interface for key-value caches used in generation.
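Because the interface only requires Clear, a minimal stand-in is easy to sketch; nopCache below is hypothetical and not part of this package:
    // nopCache satisfies KVCache but stores nothing.
    type nopCache struct{}

    func (nopCache) Clear() {}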
type LLMModel ¶
type LLMModel interface {
    // Forward runs a forward pass and returns logits.
    // Input shape: [batch, seq_len]
    // Output shape: [batch, seq_len, vocab_size]
    Forward(input *tensor.RawTensor, cache KVCache, startPos int) *tensor.RawTensor
    // VocabSize returns the vocabulary size.
    VocabSize() int
}
LLMModel is the interface for language models used in generation.
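As a sketch of the required method set, a hypothetical stub (not part of this package) might look like:
    // stubModel satisfies LLMModel without doing real inference.
    type stubModel struct{ vocab int }

    // A real implementation would run the transformer and return logits of
    // shape [batch, seq_len, vocab_size]; the stub returns nil.
    func (m *stubModel) Forward(input *tensor.RawTensor, cache KVCache, startPos int) *tensor.RawTensor {
        return nil
    }

    func (m *stubModel) VocabSize() int { return m.vocab }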
type Sampler ¶
type Sampler struct {
    // contains filtered or unexported fields
}
Sampler samples tokens from logits using configurable strategies.
func NewSampler ¶
func NewSampler(config SamplingConfig) *Sampler
NewSampler creates a new sampler with the given configuration.
func (*Sampler) Sample ¶
func (s *Sampler) Sample(logits *tensor.RawTensor, previousTokens []int32) int32
Sample returns the next token ID from logits.
Parameters:
- logits: raw model output, shape [vocab_size] or [..., vocab_size]
- previousTokens: tokens generated so far (for repetition penalty)
The sampling process, applied in order (a standalone sketch follows the list):
- Apply repetition penalty
- Apply temperature scaling
- Apply Top-K filtering
- Apply Top-P (nucleus) filtering
- Apply Min-P filtering
- Sample from distribution (or argmax if temperature=0)
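The following is a rough, self-contained illustration of the core steps above: temperature scaling, Top-K filtering, and the final draw. It assumes only the standard library's math, math/rand, and sort packages and is not this package's actual implementation, which also applies the penalties, Top-P, and Min-P:
    // sampleSketch picks a token index from raw logits.
    func sampleSketch(logits []float32, temperature float32, topK int, rng *rand.Rand) int32 {
        // Greedy decoding: temperature 0 means pure argmax.
        if temperature == 0 {
            best := 0
            for i, l := range logits {
                if l > logits[best] {
                    best = i
                }
            }
            return int32(best)
        }
        // Temperature scaling: divide logits by the temperature.
        scaled := make([]float64, len(logits))
        for i, l := range logits {
            scaled[i] = float64(l) / float64(temperature)
        }
        // Top-K filtering: mask everything below the K-th largest logit.
        if topK > 0 && topK < len(scaled) {
            sorted := append([]float64(nil), scaled...)
            sort.Sort(sort.Reverse(sort.Float64Slice(sorted)))
            threshold := sorted[topK-1]
            for i := range scaled {
                if scaled[i] < threshold {
                    scaled[i] = math.Inf(-1)
                }
            }
        }
        // Softmax (shifted by the max for numerical stability), then sample.
        maxLogit := math.Inf(-1)
        for _, s := range scaled {
            if s > maxLogit {
                maxLogit = s
            }
        }
        weights := make([]float64, len(scaled))
        var total float64
        for i, s := range scaled {
            weights[i] = math.Exp(s - maxLogit)
            total += weights[i]
        }
        r := rng.Float64() * total
        for i, w := range weights {
            r -= w
            if r <= 0 {
                return int32(i)
            }
        }
        return int32(len(weights) - 1)
    }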
type SamplingConfig ¶
type SamplingConfig struct {
    // Temperature controls randomness. 0 = greedy, 1 = normal, >1 = more random.
    Temperature float32
    // TopK limits sampling to the top K tokens. 0 = disabled.
    TopK int
    // TopP (nucleus sampling) keeps the smallest set of tokens whose
    // cumulative probability reaches P. 1.0 = disabled.
    TopP float32
    // MinP filters tokens with prob < max_prob * MinP. 0 = disabled.
    MinP float32
    // Repetition control
    RepeatPenalty    float32 // Penalty for repeated tokens. 1.0 = no penalty.
    FrequencyPenalty float32 // Penalty proportional to token frequency. 0 = disabled.
    PresencePenalty  float32 // Penalty for tokens already present. 0 = disabled.
    RepeatWindow     int     // Number of recent tokens to consider. 0 = all.
    // Seed for reproducibility. -1 = random.
    Seed int64
}
SamplingConfig configures the sampling strategy for text generation.
func DefaultSamplingConfig ¶
func DefaultSamplingConfig() SamplingConfig
DefaultSamplingConfig returns sensible defaults for text generation.
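For illustration, sampling could be configured like this (the values are arbitrary; a fixed Seed makes runs reproducible):
    cfg := DefaultSamplingConfig()
    cfg.Temperature = 0.8
    cfg.TopK = 40
    cfg.TopP = 0.95
    cfg.Seed = 42 // fixed seed for reproducible runs
    sampler := NewSampler(cfg)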
type SpeculativeConfig ¶ added in v0.7.0
type SpeculativeConfig struct {
    // DraftModel is the small, fast model used for speculation.
    DraftModel LLMModel
    // TargetModel is the large, accurate model used for verification.
    TargetModel LLMModel
    // NumSpeculate is the number of tokens to speculate (default: 5).
    NumSpeculate int
    // Sampling is the sampling configuration.
    Sampling SamplingConfig
}
SpeculativeConfig configures speculative decoding.
func DefaultSpeculativeConfig ¶ added in v0.7.0
func DefaultSpeculativeConfig(draftModel, targetModel LLMModel) SpeculativeConfig
DefaultSpeculativeConfig returns sensible defaults for speculative decoding.
type SpeculativeGenerator ¶ added in v0.7.0
type SpeculativeGenerator struct {
    // contains filtered or unexported fields
}
SpeculativeGenerator generates text using speculative decoding.
Algorithm:
- Draft model generates K tokens quickly
- Target model verifies all K tokens in parallel (single forward pass)
- Accept matching tokens using modified rejection sampling
- Repeat from first rejected token
The speedup comes from three properties:
- The draft model is smaller, so drafting tokens is fast
- The target model verifies all K tokens in a single (parallel) forward pass
- Only tokens after a rejection fall back to sequential decoding
Example:
    config := SpeculativeConfig{
        DraftModel:   smallModel, // Fast 1B model
        TargetModel:  largeModel, // Accurate 7B model
        NumSpeculate: 5,          // Speculate 5 tokens ahead
        Sampling:     DefaultSamplingConfig(),
    }
    generator := NewSpeculativeGenerator(config)
    generator.SetCaches(draftCache, targetCache)
    tokens, acceptRate, err := generator.Generate(inputIDs, maxTokens)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("Generated %d tokens with %.1f%% acceptance rate\n",
        len(tokens), acceptRate*100)
func NewSpeculativeGenerator ¶ added in v0.7.0
func NewSpeculativeGenerator(config SpeculativeConfig) *SpeculativeGenerator
NewSpeculativeGenerator creates a new speculative decoding generator.
func (*SpeculativeGenerator) AcceptanceRate ¶ added in v0.7.0
func (sg *SpeculativeGenerator) AcceptanceRate() float32
AcceptanceRate returns the current acceptance rate.
func (*SpeculativeGenerator) ClearCaches ¶ added in v0.7.0
func (sg *SpeculativeGenerator) ClearCaches()
ClearCaches clears both KV caches.
func (*SpeculativeGenerator) Generate ¶ added in v0.7.0
func (sg *SpeculativeGenerator) Generate(inputIDs []int32, maxTokens int) ([]int32, float32, error)
Generate generates text using speculative decoding. Returns generated tokens and acceptance rate.
func (*SpeculativeGenerator) SetCaches ¶ added in v0.7.0
func (sg *SpeculativeGenerator) SetCaches(draft, target KVCache)
SetCaches sets KV caches for draft and target models.
func (*SpeculativeGenerator) Stats ¶ added in v0.7.0
func (sg *SpeculativeGenerator) Stats() (drafted, accepted int, rate float32)
Stats returns detailed generation statistics.
type TextGenerator ¶
type TextGenerator struct {
    // contains filtered or unexported fields
}
TextGenerator generates text using an LLM.
func NewTextGenerator ¶
func NewTextGenerator(model LLMModel, tok tokenizer.Tokenizer, samplingConfig SamplingConfig, opts ...GeneratorOption) *TextGenerator
NewTextGenerator creates a new text generator.
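For example (model and tok stand for an LLMModel implementation and a tokenizer.Tokenizer obtained elsewhere):
    gen := NewTextGenerator(model, tok, DefaultSamplingConfig(), WithMaxSeqLen(4096))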
func (*TextGenerator) Chat ¶
func (g *TextGenerator) Chat(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, config GenerateConfig) (string, error)
Chat generates a response for chat messages. Uses the provided ChatTemplate to format messages into a prompt.
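A sketch of a chat call; g and template stand for a *TextGenerator and a tokenizer.ChatTemplate, and the Role/Content field names are an assumption about the tokenizer package's ChatMessage type:
    messages := []tokenizer.ChatMessage{
        {Role: "system", Content: "You are a helpful assistant."},
        {Role: "user", Content: "Summarize this package in one sentence."},
    }
    reply, err := g.Chat(messages, template, DefaultGenerateConfig())
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(reply)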
func (*TextGenerator) ChatStream ¶
func (g *TextGenerator) ChatStream(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, config GenerateConfig) (<-chan GenerateResult, error)
ChatStream generates a streaming response for chat messages. Uses the provided ChatTemplate to format messages into a prompt.
func (*TextGenerator) ClearCache ¶
func (g *TextGenerator) ClearCache()
ClearCache clears the KV cache.
func (*TextGenerator) Generate ¶
func (g *TextGenerator) Generate(prompt string, config GenerateConfig) (string, error)
Generate generates text from a prompt.
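For example (g stands for a *TextGenerator):
    text, err := g.Generate("Once upon a time", DefaultGenerateConfig())
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(text)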
func (*TextGenerator) GenerateStream ¶
func (g *TextGenerator) GenerateStream(prompt string, config GenerateConfig) (<-chan GenerateResult, error)
GenerateStream generates text and returns a channel of results.
func (*TextGenerator) SetCache ¶
func (g *TextGenerator) SetCache(cache KVCache)
SetCache sets the KV cache for generation.