generate

package v0.7.7

Published: Jan 6, 2026 License: Apache-2.0 Imports: 7 Imported by: 0

Documentation

Overview

Package generate provides text generation utilities for LLMs.

This package implements sampling strategies and inference pipelines for autoregressive text generation.
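
A minimal end-to-end sketch, assuming model (an LLMModel) and tok (a tokenizer.Tokenizer) have already been loaded elsewhere; model loading is outside this package's scope:

gen := NewTextGenerator(model, tok, DefaultSamplingConfig())

config := DefaultGenerateConfig()
config.MaxTokens = 128 // illustrative bound, not the documented default

text, err := gen.Generate("Once upon a time", config)
if err != nil {
    log.Fatal(err)
}
fmt.Println(text)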

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type GenerateConfig

type GenerateConfig struct {
	// MaxTokens is the maximum number of tokens to generate.
	MaxTokens int

	// MinTokens is the minimum number of tokens before stopping.
	MinTokens int

	// StopStrings are strings that trigger stopping.
	StopStrings []string

	// StopTokens are token IDs that trigger stopping.
	StopTokens []int32

	// Stream enables streaming generation.
	Stream bool

	// EchoPrompt includes the prompt in output.
	EchoPrompt bool

	// Sampling is the sampling configuration.
	Sampling SamplingConfig
}

GenerateConfig configures text generation.

func DefaultGenerateConfig

func DefaultGenerateConfig() GenerateConfig

DefaultGenerateConfig returns sensible defaults for generation.
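
For example, a streaming configuration that stops at a blank line (the override values are illustrative, not the documented defaults):

config := DefaultGenerateConfig()
config.MaxTokens = 256
config.Stream = true
config.StopStrings = []string{"\n\n"}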

type GenerateResult

type GenerateResult struct {
	Token   string // Decoded token text
	TokenID int32  // Token ID
	Done    bool   // Is generation complete
	Reason  string // Stop reason: "eos", "max_tokens", "stop_string", "stop_token"
	Error   error  // Error if any
}

GenerateResult is a single result from streaming generation.
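
A typical consumer drains the stream and checks each result; a sketch, assuming the channel is closed once generation finishes:

results, err := gen.GenerateStream(prompt, config)
if err != nil {
    log.Fatal(err)
}
for r := range results {
    if r.Error != nil {
        log.Fatal(r.Error)
    }
    fmt.Print(r.Token)
    if r.Done {
        fmt.Printf("\n[stop reason: %s]\n", r.Reason)
    }
}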

type GeneratorOption

type GeneratorOption func(*generatorOptions)

GeneratorOption configures a TextGenerator.

func WithMaxSeqLen

func WithMaxSeqLen(n int) GeneratorOption

WithMaxSeqLen sets the maximum sequence length.
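
Options are passed variadically to NewTextGenerator, for example:

gen := NewTextGenerator(model, tok, DefaultSamplingConfig(), WithMaxSeqLen(4096))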

type KVCache

type KVCache interface {
	Clear()
}

KVCache is an interface for key-value caches used in generation.
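
Any type with a Clear method satisfies the interface; a no-op cache is enough when a model manages its own state:

type nopCache struct{}

func (nopCache) Clear() {}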

type LLMModel

type LLMModel interface {
	// Forward runs a forward pass and returns logits.
	// Input shape: [batch, seq_len]
	// Output shape: [batch, seq_len, vocab_size]
	Forward(input *tensor.RawTensor, cache KVCache, startPos int) *tensor.RawTensor

	// VocabSize returns the vocabulary size.
	VocabSize() int
}

LLMModel is the interface for language models used in generation.
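
A minimal test stub satisfying the interface; it returns canned logits, so constructing a real tensor.RawTensor is sidestepped here:

type stubModel struct {
    logits *tensor.RawTensor // pre-computed [1, seq_len, vocab_size] logits
    vocab  int
}

// Forward ignores its input and returns the canned logits.
func (m *stubModel) Forward(input *tensor.RawTensor, cache KVCache, startPos int) *tensor.RawTensor {
    return m.logits
}

func (m *stubModel) VocabSize() int { return m.vocab }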

type Sampler

type Sampler struct {
	// contains filtered or unexported fields
}

Sampler samples tokens from logits using configurable strategies.

func NewSampler

func NewSampler(config SamplingConfig) *Sampler

NewSampler creates a new sampler with the given configuration.

func (*Sampler) Sample

func (s *Sampler) Sample(logits []float32, previousTokens []int32) int32

Sample returns the next token ID from logits.

Parameters:

  • logits: raw model output, shape [vocab_size] or [..., vocab_size]
  • previousTokens: tokens generated so far (for repetition penalty)

The sampling process, applied in order (a decode-loop sketch follows the list):

  1. Apply repetition penalty
  2. Apply temperature scaling
  3. Apply Top-K filtering
  4. Apply Top-P (nucleus) filtering
  5. Apply Min-P filtering
  6. Sample from distribution (or argmax if temperature=0)
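
A sketch of a decode loop built on Sample; logitsFor is a hypothetical helper that runs the model and returns the last position's logits as a []float32 of length vocab_size, and promptIDs and maxNew are placeholders:

sampler := NewSampler(DefaultSamplingConfig())
tokens := append([]int32(nil), promptIDs...)
for i := 0; i < maxNew; i++ {
    logits := logitsFor(tokens)            // hypothetical helper
    next := sampler.Sample(logits, tokens) // previousTokens feed the repetition penalty
    tokens = append(tokens, next)
}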

func (*Sampler) SampleTensor

func (s *Sampler) SampleTensor(logits *tensor.Tensor[float32, tensor.Backend], previousTokens []int32) int32

SampleTensor samples from a tensor of logits.

type SamplingConfig

type SamplingConfig struct {
	// Temperature controls randomness. 0 = greedy, 1 = normal, >1 = more random.
	Temperature float32

	// TopK limits sampling to top K tokens. 0 = disabled.
	TopK int

	// TopP (nucleus sampling) restricts sampling to the smallest set of tokens
	// whose cumulative probability exceeds P. 1.0 = disabled.
	TopP float32

	// MinP filters tokens with prob < max_prob * MinP. 0 = disabled.
	MinP float32

	// Repetition control
	RepeatPenalty    float32 // Penalty for repeated tokens. 1.0 = no penalty.
	FrequencyPenalty float32 // Penalty based on frequency. 0 = disabled.
	PresencePenalty  float32 // Penalty for presence. 0 = disabled.
	RepeatWindow     int     // Number of tokens to consider. 0 = all.

	// Seed for reproducibility. -1 = random.
	Seed int64
}

SamplingConfig configures the sampling strategy for text generation.
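
Two illustrative configurations (the values are examples, not recommendations):

// Deterministic: greedy decoding, penalties disabled.
greedy := SamplingConfig{Temperature: 0, TopP: 1.0, RepeatPenalty: 1.0, Seed: 42}

// Exploratory: nucleus sampling with a mild repeat penalty.
creative := SamplingConfig{
    Temperature:   0.9,
    TopK:          40,
    TopP:          0.95,
    RepeatPenalty: 1.1,
    Seed:          -1, // random seed
}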

func DefaultSamplingConfig

func DefaultSamplingConfig() SamplingConfig

DefaultSamplingConfig returns sensible defaults for text generation.

type SpeculativeConfig added in v0.7.0

type SpeculativeConfig struct {
	// DraftModel is the small fast model for speculation.
	DraftModel LLMModel

	// TargetModel is the large accurate model for verification.
	TargetModel LLMModel

	// NumSpeculate is the number of tokens to speculate (default: 5).
	NumSpeculate int

	// Sampling is the sampling configuration.
	Sampling SamplingConfig
}

SpeculativeConfig configures speculative decoding.

func DefaultSpeculativeConfig added in v0.7.0

func DefaultSpeculativeConfig(draftModel, targetModel LLMModel) SpeculativeConfig

DefaultSpeculativeConfig returns sensible defaults for speculative decoding.
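
The returned config can be adjusted before use, for example (the value is illustrative):

config := DefaultSpeculativeConfig(smallModel, largeModel)
config.NumSpeculate = 8 // speculate further ahead than the default of 5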

type SpeculativeGenerator added in v0.7.0

type SpeculativeGenerator struct {
	// contains filtered or unexported fields
}

SpeculativeGenerator generates text using speculative decoding.

Algorithm:

  1. Draft model generates K tokens quickly
  2. Target model verifies all K tokens in parallel (single forward pass)
  3. Accept matching tokens using modified rejection sampling
  4. Repeat from first rejected token

The speedup comes from three properties:

  • The draft model is smaller, so its sequential decoding is cheap
  • The target model verifies all K draft tokens in a single (parallel) forward pass
  • Only tokens from the first rejection onward must be re-decoded sequentially

Example:

config := SpeculativeConfig{
    DraftModel:   smallModel,   // Fast 1B model
    TargetModel:  largeModel,   // Accurate 7B model
    NumSpeculate: 5,            // Speculate 5 tokens ahead
    Sampling:     DefaultSamplingConfig(),
}
generator := NewSpeculativeGenerator(config)
generator.SetCaches(draftCache, targetCache)

tokens, acceptRate, err := generator.Generate(inputIDs, maxTokens)
if err != nil {
    log.Fatal(err)
}
fmt.Printf("Generated %d tokens with %.1f%% acceptance rate\n",
    len(tokens), acceptRate*100)

func NewSpeculativeGenerator added in v0.7.0

func NewSpeculativeGenerator(config SpeculativeConfig) *SpeculativeGenerator

NewSpeculativeGenerator creates a new speculative decoding generator.

func (*SpeculativeGenerator) AcceptanceRate added in v0.7.0

func (sg *SpeculativeGenerator) AcceptanceRate() float32

AcceptanceRate returns the current token acceptance rate.

func (*SpeculativeGenerator) ClearCaches added in v0.7.0

func (sg *SpeculativeGenerator) ClearCaches()

ClearCaches clears both KV caches.

func (*SpeculativeGenerator) Generate added in v0.7.0

func (sg *SpeculativeGenerator) Generate(
	inputIDs []int32,
	maxTokens int,
) ([]int32, float32, error)

Generate generates text using speculative decoding. Returns generated tokens and acceptance rate.

func (*SpeculativeGenerator) SetCaches added in v0.7.0

func (sg *SpeculativeGenerator) SetCaches(draft, target KVCache)

SetCaches sets KV caches for draft and target models.

func (*SpeculativeGenerator) Stats added in v0.7.0

func (sg *SpeculativeGenerator) Stats() (drafted, accepted int, rate float32)

Stats returns detailed generation statistics.

type TextGenerator

type TextGenerator struct {
	// contains filtered or unexported fields
}

TextGenerator generates text using an LLM.

func NewTextGenerator

func NewTextGenerator(
	model LLMModel,
	tok tokenizer.Tokenizer,
	samplingConfig SamplingConfig,
	opts ...GeneratorOption,
) *TextGenerator

NewTextGenerator creates a new text generator.

func (*TextGenerator) Chat

func (g *TextGenerator) Chat(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, config GenerateConfig) (string, error)

Chat generates a response for chat messages. Uses the provided ChatTemplate to format messages into a prompt.
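
A sketch, assuming tokenizer.ChatMessage carries Role and Content fields (check that package's definition) and that template matches the model's chat format:

messages := []tokenizer.ChatMessage{
    {Role: "system", Content: "You are a helpful assistant."},
    {Role: "user", Content: "Explain KV caches briefly."},
}
reply, err := gen.Chat(messages, template, DefaultGenerateConfig())
if err != nil {
    log.Fatal(err)
}
fmt.Println(reply)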

func (*TextGenerator) ChatStream

func (g *TextGenerator) ChatStream(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, config GenerateConfig) (<-chan GenerateResult, error)

ChatStream generates a streaming response for chat messages. Uses the provided ChatTemplate to format messages into a prompt.

func (*TextGenerator) ClearCache

func (g *TextGenerator) ClearCache()

ClearCache clears the KV cache.

func (*TextGenerator) Generate

func (g *TextGenerator) Generate(prompt string, config GenerateConfig) (string, error)

Generate generates text from a prompt.

func (*TextGenerator) GenerateStream

func (g *TextGenerator) GenerateStream(prompt string, config GenerateConfig) (<-chan GenerateResult, error)

GenerateStream generates text and returns a channel of results.

func (*TextGenerator) SetCache

func (g *TextGenerator) SetCache(cache KVCache)

SetCache sets the KV cache for generation.
