openai

package
v0.10.15
Published: May 8, 2026 License: MIT Imports: 21 Imported by: 0

Documentation

Overview

Package openai implements an agent.Provider for any OpenAI-compatible API endpoint (LM Studio, Ollama, OpenAI, Azure, Groq, Together, OpenRouter).

Index

Constants

This section is empty.

Variables

var ErrEndpointUnreachable = errors.New("openai: endpoint unreachable")

ErrEndpointUnreachable is the sentinel error callers match via errors.Is to distinguish "this endpoint is down / unreachable" from everything else. It is wrapped by ReachabilityError.
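
A minimal sketch of the check from calling code (the provider value and request variables are illustrative, not part of this package):

    resp, err := provider.Chat(ctx, messages, nil, opts)
    if errors.Is(err, openai.ErrEndpointUnreachable) {
        // Endpoint down or unreachable; try the next routing candidate.
    }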

Functions

func ClassifyHTTPStatus

func ClassifyHTTPStatus(endpoint, operation string, statusCode int, body string) error

ClassifyHTTPStatus returns a ReachabilityError when the status code is server-side (5xx) and nil otherwise. Callers combine it with ClassifyNetwork for the full picture.

func ClassifyNetwork

func ClassifyNetwork(endpoint, operation string, cause error) error

ClassifyNetwork wraps a transport-level error as a ReachabilityError when it looks like the endpoint was unreachable: connection refused, dial timeout, TLS handshake failure, DNS resolution failure, or a mid-response reset. Context cancellation and context-deadline-exceeded are NOT wrapped; those reflect the caller's own deadline and should bubble up as the caller's concern. Returns nil when cause is nil.
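
A sketch of combining the two classifiers around a raw HTTP call (the endpoint, resp, and body variables are illustrative; "chat_completions" follows the typical Operation values documented on ReachabilityError):

    if err != nil {
        return openai.ClassifyNetwork(endpoint, "chat_completions", err)
    }
    if rerr := openai.ClassifyHTTPStatus(endpoint, "chat_completions", resp.StatusCode, string(body)); rerr != nil {
        return rerr // server-side 5xx, classified as a ReachabilityError
    }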

func DiscoverModels

func DiscoverModels(ctx context.Context, baseURL, apiKey string) ([]string, error)

DiscoverModels queries the generic /v1/models endpoint through the shared OpenAI-compatible SDK. It is kept here as a package-level compatibility wrapper for existing provider tests and callers inside this package.
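
For example, probing a local LM Studio endpoint (the base URL follows the Config field comment; the empty API key relies on keys being optional for local providers):

    ids, err := openai.DiscoverModels(ctx, "http://localhost:1234/v1", "")
    if err != nil {
        log.Fatal(err)
    }
    for _, id := range ids {
        fmt.Println(id)
    }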

func IsModelNotFound

func IsModelNotFound(body string) bool

IsModelNotFound returns true when the response body from a 404 on /v1/chat/completions contains a model-specific error code. The caller decides whether to fail over based on this result and the pinned flag.
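
A sketch of the intended call site (resp and body are illustrative):

    if resp.StatusCode == http.StatusNotFound && openai.IsModelNotFound(string(body)) {
        // The endpoint is up but does not serve this model; a failover candidate.
    }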

func MatchModelIDs

func MatchModelIDs(requested string, catalog []string) []string

MatchModelIDs returns every catalog entry whose normalized form contains the normalized request as a substring. Normalization lowercases, strips a single leading vendor namespace, and removes all non-alphanumeric separators, so "qwen/qwen3.6", "Qwen3.6", and "qwen3.6" all match both "Qwen3.6-35B-A3B-4bit" and "Qwen3.6-35B-A3B-nvfp4".

The returned slice preserves original catalog case and order. An empty slice means no match; callers are responsible for deciding whether to pass the original request through to the provider unchanged or to escalate.

This is the primary matching primitive since v0.9.2 — it replaces the scalar logic previously in NormalizeModelID. NormalizeModelID is retained as a backward-compatible wrapper.
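
For example, using the IDs from the normalization paragraph above:

    catalog := []string{"Qwen3.6-35B-A3B-4bit", "Qwen3.6-35B-A3B-nvfp4"}
    matches := openai.MatchModelIDs("qwen/qwen3.6", catalog)
    // matches contains both entries, in original catalog case and order.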

func NormalizeModelID

func NormalizeModelID(requested string, catalog []string) (string, error)

NormalizeModelID resolves a caller-supplied model name against the server's canonical model catalog (the IDs returned by GET /v1/models).

Prefer MatchModelIDs for new code; this wrapper is retained for backward compatibility with the v0.9.1 call signature. Behaviour (sketched after this list):

  • 0 matches → returns the original requested string, no error
  • 1 match → returns the catalog entry, no error
  • 2+ matches → returns "" and an ambiguity error listing the candidates
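
A sketch of the outcomes, reusing the two-entry catalog from the MatchModelIDs example:

    id, err := openai.NormalizeModelID("qwen3.6", catalog)
    // Both catalog entries match, so id == "" and err lists the candidates.

    id, err = openai.NormalizeModelID("gemma", catalog)
    // No match: id == "gemma", err == nil.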

func SelectModel

func SelectModel(ranked []ScoredModel) string

SelectModel picks the preferred model ID from a ranked list. Returns "" if the list is empty.

func ShouldFailover

func ShouldFailover(err error, pinned bool) bool

ShouldFailover returns true iff err warrants trying the next routing candidate instead of bubbling up. Returns false when:

  • the request is pinned (operator explicitly chose a provider)
  • the error is a caller-level concern (context.Canceled / context.DeadlineExceeded)
  • the error is a request-validation error the next endpoint can't fix (400 Bad Request, or a generic 404 that likely indicates a wrong URL)

Returns true for:

  • ReachabilityError (5xx, dial failures, TLS, network resets)
  • 401 / 403 auth failures (another endpoint might have valid auth)
  • 404 responses whose body indicates the model specifically isn't served (IsModelNotFound)
  • 429 rate-limit / quota errors

pinned should be the value of req.Provider != "" captured before the failover loop begins.
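
A sketch of the failover loop this is designed for (req, candidates, messages, tools, and opts are illustrative):

    pinned := req.Provider != "" // capture before the loop
    var lastErr error
    for _, p := range candidates {
        resp, err := p.Chat(ctx, messages, tools, opts)
        if err == nil {
            return resp, nil
        }
        if !openai.ShouldFailover(err, pinned) {
            return agent.Response{}, err
        }
        lastErr = err
    }
    return agent.Response{}, lastErr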

Types

type Config

type Config struct {
	BaseURL      string // e.g., "http://localhost:1234/v1" for LM Studio
	APIKey       string // optional for local providers
	Model        string // e.g., "qwen3.5-7b", "gpt-4o". Empty = auto-discover.
	ProviderName string // logical provider identity; default "openai"
	// ProviderSystem is the telemetry/cost system identity. When empty, it
	// defaults to "openai". Concrete provider wrappers set their own type.
	ProviderSystem string
	ModelPattern   string // case-insensitive regex to prefer among auto-discovered models
	// KnownModels maps concrete model IDs to catalog target IDs for the
	// agent.openai surface. Models present in this map are ranked higher during
	// auto-selection. Populated by the config layer from the model catalog;
	// nil disables catalog-aware ranking.
	KnownModels map[string]string
	Headers     map[string]string // extra HTTP headers (OpenRouter, Azure, etc.)
	Reasoning   reasoningpolicy.Reasoning
	// Capabilities supplies provider-owned protocol capability claims. When nil,
	// direct openai.Provider callers use OpenAI protocol defaults.
	Capabilities *ProtocolCapabilities
	// UsageCostAttribution extracts provider-owned gateway cost metadata from
	// the raw usage object, when that provider reports one.
	UsageCostAttribution func(rawUsage string) (*agent.CostAttribution, bool)
	// ModelReasoningWire maps a concrete model ID to the catalog
	// reasoning_wire value for that model. Recognized values are "provider"
	// (default), "model_id" (model name encodes reasoning level — strip the
	// reasoning field at serialization), and "none" (model has no reasoning
	// surface — reject explicit non-off requests pre-flight). Models not
	// listed default to "provider", preserving existing behavior.
	ModelReasoningWire map[string]string
	// QuotaHeaderParser, when set, overrides the default OpenAI rate-limit
	// header parser. OpenRouter uses this to install
	// quotaheaders.ParseOpenRouter even though it goes through this
	// OpenAI-compatible provider implementation.
	QuotaHeaderParser func(http.Header, time.Time) quotaheaders.Signal
	// QuotaSignalObserver receives parsed rate-limit signals on every
	// response. The service layer wires this to the provider quota state
	// machine. Both QuotaHeaderParser and QuotaSignalObserver must be set
	// for header-driven exhaustion tracking to activate.
	QuotaSignalObserver func(quotaheaders.Signal)
}

Config holds configuration for the OpenAI-compatible provider.
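
A minimal construction for a local LM Studio endpoint (the ProviderName value is illustrative; unset fields take the documented defaults):

    p := openai.New(openai.Config{
        BaseURL:      "http://localhost:1234/v1",
        APIKey:       "", // optional for local providers
        Model:        "", // empty: auto-discover
        ProviderName: "lmstudio",
    })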

type HTTPStatusError

type HTTPStatusError struct {
	Endpoint   string
	Operation  string
	StatusCode int
	Body       string
}

HTTPStatusError is a typed client-level HTTP error that carries the status code and response body so ShouldFailover can classify beyond ReachabilityError (which covers only 5xx / network failures). Callers that surface 4xx errors from OpenAI-compatible servers should wrap them as HTTPStatusError to participate in the failover policy.
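
A sketch of that wrapping (endpoint, resp, and body are illustrative):

    if resp.StatusCode >= 400 && resp.StatusCode < 500 {
        return &openai.HTTPStatusError{
            Endpoint:   endpoint,
            Operation:  "chat_completions",
            StatusCode: resp.StatusCode,
            Body:       string(body),
        }
    }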

func (*HTTPStatusError) Error

func (e *HTTPStatusError) Error() string

type ProtocolCapabilities

type ProtocolCapabilities struct {
	Tools            bool
	Stream           bool
	StructuredOutput bool
	// Thinking reports whether the provider accepts non-standard body fields
	// (openai-go WithJSONSet) used to control model-side reasoning.
	Thinking       bool
	ThinkingFormat ThinkingWireFormat
	// StrictThinkingModelMatch, when true, makes the openai layer return an
	// error if the request carries an explicit reasoning policy while the
	// model does not match the provider's wire format family (e.g. a Qwen
	// wire format with a non-Qwen model). Set true for providers that only
	// serve a single model family (OMLX → Qwen MLX). Providers that host
	// mixed model families (LM Studio can load Qwen, Gemma, Llama, etc.)
	// should leave this false so reasoning controls silently no-op on
	// non-matching models instead of failing the request.
	StrictThinkingModelMatch bool
	// ImplicitGenerationConfig declares that the inference server applies
	// the model's HuggingFace `generation_config.json` automatically when
	// the request omits sampler fields. vLLM does this by default; most
	// other local servers (omlx, lmstudio, lucebox) ship custom presets and
	// either ignore upstream defaults or replace them at repackage time
	// (MLX / GGUF strip generation_config.json), which is why ADR-007's
	// catalog sampling_profiles exist. Routing and the catalog-stale
	// nudge use this flag to distinguish "server has a sane default" from
	// "server will decode greedy" when no catalog profile is supplied.
	ImplicitGenerationConfig bool
}

ProtocolCapabilities declares provider-owned protocol capability claims.

var (
	OpenAIProtocolCapabilities  = ProtocolCapabilities{Tools: true, Stream: true, StructuredOutput: true, Thinking: false}
	UnknownProtocolCapabilities ProtocolCapabilities
)

type Provider

type Provider struct {
	// contains filtered or unexported fields
}

Provider implements agent.Provider for OpenAI-compatible APIs.

func New

func New(cfg Config) *Provider

New creates a new OpenAI-compatible provider.

func (*Provider) Chat

func (p *Provider) Chat(ctx context.Context, messages []agent.Message, tools []agent.ToolDef, opts agent.Options) (agent.Response, error)

func (*Provider) ChatStartMetadata

func (p *Provider) ChatStartMetadata() (string, string, int)

ChatStartMetadata reports the resolved provider system and upstream server identity known when the provider is constructed.

func (*Provider) ChatStream

func (p *Provider) ChatStream(ctx context.Context, messages []agent.Message, tools []agent.ToolDef, opts agent.Options) (<-chan agent.StreamDelta, error)

ChatStream implements agent.StreamingProvider for token-level streaming.

func (*Provider) DiscoveredModels

func (p *Provider) DiscoveredModels() []ScoredModel

DiscoveredModels returns the full ranked list of models discovered from the server's /v1/models endpoint. Returns nil if the provider has a statically configured model or if discovery has not yet run (i.e. no request has been made yet). Call EnsureDiscovered to force discovery without making a chat request.

func (*Provider) EnsureDiscovered

func (p *Provider) EnsureDiscovered(ctx context.Context) error

EnsureDiscovered probes the server's /v1/models endpoint and caches the full ranked model list. It is a no-op when the provider has a statically configured model or when discovery has already run.
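
A sketch of forcing discovery up front and inspecting the ranked list:

    if err := p.EnsureDiscovered(ctx); err != nil {
        log.Fatal(err)
    }
    for _, m := range p.DiscoveredModels() {
        fmt.Println(m.ID, m.Score, m.CatalogRef)
    }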

func (*Provider) ImplicitGenerationConfig

func (p *Provider) ImplicitGenerationConfig() bool

ImplicitGenerationConfig reports whether the provider's server applies a model-card-derived sampler bundle (`generation_config.json`) when the request omits sampler fields. Used by the agent CLI to adjust the tone of the catalog-stale nudge per ADR-007 §7.

func (*Provider) SessionStartMetadata

func (p *Provider) SessionStartMetadata() (string, string)

SessionStartMetadata reports the broad provider identity and configured model that should be recorded on session.start events.

func (*Provider) SupportsStream

func (p *Provider) SupportsStream() bool

SupportsStream reports whether `stream: true` returns a well-formed SSE stream with incremental `choices[0].delta` chunks.

func (*Provider) SupportsStructuredOutput

func (p *Provider) SupportsStructuredOutput() bool

SupportsStructuredOutput reports whether the provider honors `response_format: json_object` / tool-use-required semantics to produce a structured (JSON-shaped) response.

func (*Provider) SupportsThinking

func (p *Provider) SupportsThinking() bool

SupportsThinking reports whether the provider accepts non-standard request body fields used to cap or disable model-side reasoning. Providers returning false MUST have those fields stripped at serialization time.

func (*Provider) SupportsTools

func (p *Provider) SupportsTools() bool

SupportsTools reports whether the concrete provider accepts a `tools` field on `/v1/chat/completions` and returns structured `tool_calls` in the response.

type ReachabilityError

type ReachabilityError struct {
	// Endpoint is the base URL of the provider that failed.
	Endpoint string
	// Operation identifies what was being attempted when the failure
	// occurred. Typical values: "probe_models", "chat_completions".
	Operation string
	// StatusCode is the HTTP status code if the failure was HTTP-level;
	// 0 for non-HTTP failures (dial error, timeout, TLS handshake, etc.).
	StatusCode int
	// Cause is the underlying error.
	Cause error
}

ReachabilityError describes a failure attributable to the endpoint being unreachable or returning a server-side 5xx — distinct from request-level client errors (4xx) or model-specific errors. Callers use errors.Is(err, ErrEndpointUnreachable) to detect it.

func (*ReachabilityError) Error

func (e *ReachabilityError) Error() string

func (*ReachabilityError) Is

func (e *ReachabilityError) Is(target error) bool

Is reports whether target is ErrEndpointUnreachable. This makes errors.Is(err, openai.ErrEndpointUnreachable) return true for any *ReachabilityError regardless of the underlying cause.

func (*ReachabilityError) Unwrap

func (e *ReachabilityError) Unwrap() error

Unwrap returns the underlying cause so errors.As / errors.Unwrap traverse through the reachability wrapper.

type ScoredModel

type ScoredModel struct {
	// ID is the model identifier returned by the server's /v1/models endpoint.
	ID string
	// CatalogRef is the catalog target ID if this model is recognized in the
	// model catalog for the provider's surface. Empty for unrecognized models.
	CatalogRef string
	// PatternMatch is true when this model matched the configured model_pattern.
	PatternMatch bool
	// Score summarises the selection preference: 3 = catalog-recognized,
	// 2 = pattern-matched, 1 = uncategorized.
	Score int
}

ScoredModel is a discovered model with a selection preference score. Higher scores are preferred by the auto-selection logic.

func RankModels

func RankModels(candidates []string, knownModels map[string]string, pattern string) ([]ScoredModel, error)

RankModels scores and sorts a list of discovered model IDs by selection preference:

  • Score 3 — catalog-recognized: the model ID appears in knownModels (a map from concrete model ID to catalog target ID, e.g. from Catalog.AllConcreteModels). These are explicitly tracked models; prefer them when auto-selecting.
  • Score 2 — pattern-matched: the model ID matches the case-insensitive pattern regex (pattern == "" means this tier is skipped).
  • Score 1 — uncategorized: known to the server but not in the catalog or pattern.

Within each score tier, the original server-returned order is preserved. Returns an error only if pattern is non-empty and fails to compile.
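
A sketch combining RankModels with SelectModel (ids and known are illustrative; the pattern is matched case-insensitively per the tiers above):

    ranked, err := openai.RankModels(ids, known, "qwen")
    if err != nil {
        return err // pattern failed to compile
    }
    best := openai.SelectModel(ranked) // "" when ranked is empty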

type ThinkingWireFormat

type ThinkingWireFormat string

const (
	// ThinkingWireFormatThinkingMap sends `thinking: {type, budget_tokens}`.
	ThinkingWireFormatThinkingMap ThinkingWireFormat = "thinking_map"
	// ThinkingWireFormatQwen sends Qwen-family controls:
	// `enable_thinking` and `thinking_budget`.
	ThinkingWireFormatQwen ThinkingWireFormat = "qwen"
	// ThinkingWireFormatOpenRouter sends OpenRouter's nested `reasoning`
	// object with `effort`, `max_tokens`, or `exclude`.
	ThinkingWireFormatOpenRouter ThinkingWireFormat = "openrouter"
)
