generate

package
v0.5.4 Latest
Warning

This package is not in the latest version of its module.

Published: Dec 3, 2025 License: Apache-2.0 Imports: 7 Imported by: 0

Documentation

Overview

Package generate provides text generation utilities for LLMs.

This package implements sampling strategies and inference pipelines for autoregressive text generation.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type GenerateConfig

type GenerateConfig struct {
	// MaxTokens is the maximum number of tokens to generate.
	MaxTokens int

	// MinTokens is the minimum number of tokens before stopping.
	MinTokens int

	// StopStrings are strings that trigger stopping.
	StopStrings []string

	// StopTokens are token IDs that trigger stopping.
	StopTokens []int32

	// Stream enables streaming generation.
	Stream bool

	// EchoPrompt includes the prompt in output.
	EchoPrompt bool

	// Sampling is the sampling configuration.
	Sampling SamplingConfig
}

GenerateConfig configures text generation.

func DefaultGenerateConfig

func DefaultGenerateConfig() GenerateConfig

DefaultGenerateConfig returns sensible defaults for generation.

type GenerateResult

type GenerateResult struct {
	Token   string // Decoded token text
	TokenID int32  // Token ID
	Done    bool   // Is generation complete
	Reason  string // Stop reason: "eos", "max_tokens", "stop_string", "stop_token"
	Error   error  // Error if any
}

GenerateResult is a single result from streaming generation.

type GeneratorOption

type GeneratorOption func(*generatorOptions)

GeneratorOption configures a TextGenerator.

func WithMaxSeqLen

func WithMaxSeqLen(n int) GeneratorOption

WithMaxSeqLen sets the maximum sequence length.

type KVCache

type KVCache interface {
	Clear()
}

KVCache is an interface for key-value caches used in generation.
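Since KVCache only requires Clear, any cache type can satisfy it. The following toy implementation is an illustration only (a map keyed by position, not the package's actual cache, which would hold per-layer attention key/value tensors):

```go
package main

import "fmt"

// KVCache mirrors the documented interface.
type KVCache interface {
	Clear()
}

// mapCache is a toy stand-in for a real KV cache: it stores
// per-position values and satisfies KVCache via Clear.
type mapCache struct {
	entries map[int][]float32
}

func newMapCache() *mapCache {
	return &mapCache{entries: make(map[int][]float32)}
}

// Clear drops all cached entries, resetting the cache for a new sequence.
func (c *mapCache) Clear() {
	c.entries = make(map[int][]float32)
}

func main() {
	c := newMapCache()
	c.entries[0] = []float32{1, 2}

	var cache KVCache = c // interface satisfaction check
	cache.Clear()

	fmt.Println(len(c.entries)) // cache is empty after Clear
}
```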

type LLMModel

type LLMModel interface {
	// Forward runs a forward pass and returns logits.
	// Input shape: [batch, seq_len]
	// Output shape: [batch, seq_len, vocab_size]
	Forward(input *tensor.RawTensor, cache KVCache, startPos int) *tensor.RawTensor

	// VocabSize returns the vocabulary size.
	VocabSize() int
}

LLMModel is the interface for language models used in generation.

type Sampler

type Sampler struct {
	// contains filtered or unexported fields
}

Sampler samples tokens from logits using configurable strategies.

func NewSampler

func NewSampler(config SamplingConfig) *Sampler

NewSampler creates a new sampler with the given configuration.

func (*Sampler) Sample

func (s *Sampler) Sample(logits []float32, previousTokens []int32) int32

Sample returns the next token ID from logits.

Parameters:

  • logits: raw model output, shape [vocab_size] or [..., vocab_size]
  • previousTokens: tokens generated so far (for repetition penalty)

The sampling process:

  1. Apply repetition penalty
  2. Apply temperature scaling
  3. Apply Top-K filtering
  4. Apply Top-P (nucleus) filtering
  5. Apply Min-P filtering
  6. Sample from distribution (or argmax if temperature=0)

func (*Sampler) SampleTensor

func (s *Sampler) SampleTensor(logits *tensor.Tensor[float32, tensor.Backend], previousTokens []int32) int32

SampleTensor samples from a tensor of logits.

type SamplingConfig

type SamplingConfig struct {
	// Temperature controls randomness. 0 = greedy, 1 = normal, >1 = more random.
	Temperature float32

	// TopK limits sampling to top K tokens. 0 = disabled.
	TopK int

	// TopP (nucleus sampling) keeps the smallest set of tokens whose cumulative probability reaches P. 1.0 = disabled.
	TopP float32

	// MinP filters tokens with prob < max_prob * MinP. 0 = disabled.
	MinP float32

	// Repetition control
	RepeatPenalty    float32 // Penalty for repeated tokens. 1.0 = no penalty.
	FrequencyPenalty float32 // Penalty based on frequency. 0 = disabled.
	PresencePenalty  float32 // Penalty for presence. 0 = disabled.
	RepeatWindow     int     // Number of tokens to consider. 0 = all.

	// Seed for reproducibility. -1 = random.
	Seed int64
}

SamplingConfig configures the sampling strategy for text generation.

func DefaultSamplingConfig

func DefaultSamplingConfig() SamplingConfig

DefaultSamplingConfig returns sensible defaults for text generation.

type TextGenerator

type TextGenerator struct {
	// contains filtered or unexported fields
}

TextGenerator generates text using an LLM.

func NewTextGenerator

func NewTextGenerator(
	model LLMModel,
	tok tokenizer.Tokenizer,
	samplingConfig SamplingConfig,
	opts ...GeneratorOption,
) *TextGenerator

NewTextGenerator creates a new text generator.

func (*TextGenerator) Chat

func (g *TextGenerator) Chat(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, config GenerateConfig) (string, error)

Chat generates a response for chat messages. Uses the provided ChatTemplate to format messages into a prompt.

func (*TextGenerator) ChatStream

func (g *TextGenerator) ChatStream(messages []tokenizer.ChatMessage, template tokenizer.ChatTemplate, config GenerateConfig) (<-chan GenerateResult, error)

ChatStream generates a streaming response for chat messages. Uses the provided ChatTemplate to format messages into a prompt.

func (*TextGenerator) ClearCache

func (g *TextGenerator) ClearCache()

ClearCache clears the KV cache.

func (*TextGenerator) Generate

func (g *TextGenerator) Generate(prompt string, config GenerateConfig) (string, error)

Generate generates text from a prompt.

func (*TextGenerator) GenerateStream

func (g *TextGenerator) GenerateStream(prompt string, config GenerateConfig) (<-chan GenerateResult, error)

GenerateStream generates text and returns a channel of results.

func (*TextGenerator) SetCache

func (g *TextGenerator) SetCache(cache KVCache)

SetCache sets the KV cache for generation.
