extensions

package
v1.0.2
Published: Nov 29, 2025 License: MIT Imports: 5 Imported by: 0

README

Backend Extensions

Framework for adding custom functionality to the AI Provider Kit backend server.

Overview

Extensions provide a way to hook into the request lifecycle and add custom behavior without modifying core backend code. Use extensions for:

  • Request/response modification - Transform data before/after generation
  • Metrics and monitoring - Track usage, performance, errors
  • Caching - Cache responses to reduce API calls
  • Rate limiting - Per-user or global rate limits
  • Content filtering - Block inappropriate content
  • Custom routing - Add domain-specific endpoints
  • Logging - Advanced logging, audit trails
  • Provider selection - Dynamic provider routing logic

Extension Interface

All extensions implement the Extension interface:

type Extension interface {
    // Metadata
    Name() string
    Version() string
    Description() string
    Dependencies() []string

    // Lifecycle
    Initialize(config map[string]interface{}) error
    Shutdown(ctx context.Context) error

    // Routes
    RegisterRoutes(registrar RouteRegistrar) error

    // Hooks - Generation lifecycle
    BeforeGenerate(ctx context.Context, req *GenerateRequest) error
    AfterGenerate(ctx context.Context, req *GenerateRequest, resp *GenerateResponse) error

    // Hooks - Provider events
    OnProviderError(ctx context.Context, provider types.Provider, err error) error
    OnProviderSelected(ctx context.Context, provider types.Provider) error
}

Lifecycle Hooks

Initialize

Called once during server startup with extension configuration.

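func (e *MyExtension) Initialize(config map[string]interface{}) error {
    // Checked assertions return an error for a missing or mistyped key instead of panicking
    apiKey, ok := config["api_key"].(string)
    if !ok {
        return fmt.Errorf("api_key must be a string")
    }
    e.apiKey = apiKey
    e.enabled, _ = config["enabled"].(bool) // stays false when the key is absent
    return nil
}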

Use for:

  • Loading configuration
  • Initializing clients/connections
  • Setting up resources

Shutdown

Called during graceful server shutdown with a timeout context.

func (e *MyExtension) Shutdown(ctx context.Context) error {
    return e.db.Close()
}

Use for:

  • Closing connections
  • Flushing buffers
  • Cleaning up resources

RegisterRoutes

Called during server initialization to register custom HTTP routes.

func (e *MyExtension) RegisterRoutes(r RouteRegistrar) error {
    r.HandleFunc("/api/metrics", e.handleMetrics)
    r.HandleFunc("/api/cache/clear", e.handleClearCache)
    return nil
}

Use for:

  • Adding custom endpoints
  • Exposing extension functionality via HTTP
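
Routes registered this way can be exercised without starting a server. The sketch below is one possible approach, not the backend's own wiring: `muxRegistrar`, `registerStatus`, `callRoute`, and the `/api/status` path are hypothetical names, and `RouteRegistrar` is redeclared locally to keep the example self-contained. An adapter is needed because `(*http.ServeMux).HandleFunc` takes a plain `func(http.ResponseWriter, *http.Request)` rather than `http.HandlerFunc`, so a `*http.ServeMux` does not satisfy the interface directly.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// RouteRegistrar mirrors the interface documented by this package.
type RouteRegistrar interface {
	Handle(pattern string, handler http.Handler)
	HandleFunc(pattern string, handler http.HandlerFunc)
}

// muxRegistrar (hypothetical) adapts *http.ServeMux to RouteRegistrar.
type muxRegistrar struct{ mux *http.ServeMux }

func (m muxRegistrar) Handle(pattern string, handler http.Handler) {
	m.mux.Handle(pattern, handler)
}

func (m muxRegistrar) HandleFunc(pattern string, handler http.HandlerFunc) {
	// http.HandlerFunc implements http.Handler, so Handle accepts it.
	m.mux.Handle(pattern, handler)
}

// registerStatus stands in for an extension's RegisterRoutes method.
func registerStatus(r RouteRegistrar) error {
	r.HandleFunc("/api/status", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
	})
	return nil
}

// callRoute registers the routes on a fresh mux and issues one in-memory GET.
func callRoute(path string) (int, string, error) {
	mux := http.NewServeMux()
	if err := registerStatus(muxRegistrar{mux: mux}); err != nil {
		return 0, "", err
	}
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil))
	return rec.Code, rec.Body.String(), nil
}

func main() {
	code, body, err := callRoute("/api/status")
	fmt.Println(code, body, err)
}
```

The same adapter can back a unit test for any extension that registers routes, since the recorder captures status, headers, and body.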

Generation Hooks

BeforeGenerate

Called before sending the request to the provider. Can modify the request.

func (e *MyExtension) BeforeGenerate(ctx context.Context, req *GenerateRequest) error {
    // Add custom metadata
    if req.Metadata == nil {
        req.Metadata = make(map[string]interface{})
    }
    req.Metadata["request_id"] = generateID()
    req.Metadata["timestamp"] = time.Now()

    // Modify the prompt
    req.Prompt = sanitizePrompt(req.Prompt)

    return nil
}

Use for:

  • Request validation
  • Prompt modification
  • Adding metadata
  • Content filtering

Return error to: Abort the request

AfterGenerate

Called after receiving the response from the provider. Can modify the response.

func (e *MyExtension) AfterGenerate(ctx context.Context, req *GenerateRequest, resp *GenerateResponse) error {
    // Add metadata to response
    if resp.Metadata == nil {
        resp.Metadata = make(map[string]interface{})
    }
    resp.Metadata["cached"] = false

    // Read the start time that BeforeGenerate stored under "timestamp"
    if startTime, ok := req.Metadata["timestamp"].(time.Time); ok {
        resp.Metadata["processing_time_ms"] = time.Since(startTime).Milliseconds()
    }

    // Modify content
    resp.Content = postProcess(resp.Content)

    return nil
}

Use for:

  • Response modification
  • Caching responses
  • Recording metrics
  • Content filtering

Return error to: Return error to client instead of response

Provider Hooks

OnProviderSelected

Called after a provider is selected but before the generation request.

func (e *MyExtension) OnProviderSelected(ctx context.Context, provider types.Provider) error {
    // Record provider selection
    e.metrics.IncrementProviderUsage(provider.Name())

    // Check provider health
    if err := provider.HealthCheck(ctx); err != nil {
        return fmt.Errorf("selected provider unhealthy: %w", err)
    }

    return nil
}

Use for:

  • Recording provider usage
  • Provider health checks
  • Provider-specific setup

Return error to: Abort the generation request

OnProviderError

Called when a provider returns an error.

func (e *MyExtension) OnProviderError(ctx context.Context, provider types.Provider, err error) error {
    // Log the error
    e.logger.Error("Provider error",
        "provider", provider.Name(),
        "error", err.Error())

    // Record metrics
    e.metrics.IncrementProviderErrors(provider.Name())

    // Send alert if error rate is high
    if e.metrics.ErrorRate(provider.Name()) > 0.5 {
        e.alerting.Send("High error rate for " + provider.Name())
    }

    return nil
}

Use for:

  • Error logging
  • Alerting
  • Error metrics
  • Error recovery

Return value: Not used (error already occurred)
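
The example above calls `e.metrics.ErrorRate`, which the README does not define. A minimal sketch of what such a helper might look like follows; `providerCounters` and its constructor are hypothetical names, the method names are borrowed from the snippets above, and the rate is assumed to be errors divided by selections over the lifetime of the process.

```go
package main

import (
	"fmt"
	"sync"
)

// providerCounters (hypothetical) tracks selections and errors per provider.
type providerCounters struct {
	mu     sync.Mutex
	usage  map[string]int
	errors map[string]int
}

func newProviderCounters() *providerCounters {
	return &providerCounters{
		usage:  make(map[string]int),
		errors: make(map[string]int),
	}
}

// IncrementProviderUsage would be called from OnProviderSelected.
func (c *providerCounters) IncrementProviderUsage(name string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.usage[name]++
}

// IncrementProviderErrors would be called from OnProviderError.
func (c *providerCounters) IncrementProviderErrors(name string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.errors[name]++
}

// ErrorRate returns errors/usage, or 0 when the provider has not been used,
// which avoids a division by zero for unknown providers.
func (c *providerCounters) ErrorRate(name string) float64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.usage[name] == 0 {
		return 0
	}
	return float64(c.errors[name]) / float64(c.usage[name])
}

func main() {
	counters := newProviderCounters()
	for i := 0; i < 4; i++ {
		counters.IncrementProviderUsage("openai")
	}
	for i := 0; i < 3; i++ {
		counters.IncrementProviderErrors("openai")
	}
	fmt.Println(counters.ErrorRate("openai") > 0.5)
}
```

A lifetime ratio reacts slowly once many requests have accumulated; a sliding window would alert faster but is beyond this sketch.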

BaseExtension

Use BaseExtension to avoid implementing unused methods:

type MyExtension struct {
    extensions.BaseExtension
    config map[string]interface{}
}

// Only implement what you need
func (e *MyExtension) Name() string { return "my-extension" }
func (e *MyExtension) Version() string { return "1.0.0" }
func (e *MyExtension) Description() string { return "My custom extension" }

func (e *MyExtension) Initialize(config map[string]interface{}) error {
    e.config = config
    return nil
}

func (e *MyExtension) BeforeGenerate(ctx context.Context, req *GenerateRequest) error {
    // Custom logic
    return nil
}

// All other methods have default implementations from BaseExtension

Example Extensions

1. Metrics Extension

Tracks request counts, latencies, and errors.

package main

import (
    "context"
    "encoding/json"
    "net/http"
    "sync"
    "sync/atomic"
    "time"

    "github.com/cecil-the-coder/ai-provider-kit/pkg/backend/extensions"
    "github.com/cecil-the-coder/ai-provider-kit/pkg/types"
)

type MetricsExtension struct {
    extensions.BaseExtension
    requestCount  int64
    errorCount    int64
    totalLatency  int64
    mu            sync.RWMutex // guards the providerStats map
    providerStats map[string]*ProviderMetrics
}

type ProviderMetrics struct {
    Requests int64
    Errors   int64
}

func NewMetricsExtension() *MetricsExtension {
    return &MetricsExtension{
        providerStats: make(map[string]*ProviderMetrics),
    }
}

func (m *MetricsExtension) Name() string        { return "metrics" }
func (m *MetricsExtension) Version() string     { return "1.0.0" }
func (m *MetricsExtension) Description() string { return "Tracks request metrics" }

func (m *MetricsExtension) Initialize(config map[string]interface{}) error {
    return nil
}

func (m *MetricsExtension) BeforeGenerate(ctx context.Context, req *extensions.GenerateRequest) error {
    atomic.AddInt64(&m.requestCount, 1)

    // Store start time in metadata
    if req.Metadata == nil {
        req.Metadata = make(map[string]interface{})
    }
    req.Metadata["start_time"] = time.Now()

    return nil
}

func (m *MetricsExtension) AfterGenerate(ctx context.Context, req *extensions.GenerateRequest, resp *extensions.GenerateResponse) error {
    // Calculate latency
    if startTime, ok := req.Metadata["start_time"].(time.Time); ok {
        latency := time.Since(startTime).Milliseconds()
        atomic.AddInt64(&m.totalLatency, latency)
    }

    return nil
}

func (m *MetricsExtension) OnProviderSelected(ctx context.Context, provider types.Provider) error {
    // Hooks run concurrently, so map reads and writes must hold the lock
    m.mu.Lock()
    stats, ok := m.providerStats[provider.Name()]
    if !ok {
        stats = &ProviderMetrics{}
        m.providerStats[provider.Name()] = stats
    }
    m.mu.Unlock()

    atomic.AddInt64(&stats.Requests, 1)
    return nil
}

func (m *MetricsExtension) OnProviderError(ctx context.Context, provider types.Provider, err error) error {
    atomic.AddInt64(&m.errorCount, 1)

    m.mu.RLock()
    stats, ok := m.providerStats[provider.Name()]
    m.mu.RUnlock()
    if ok {
        atomic.AddInt64(&stats.Errors, 1)
    }

    return nil
}

func (m *MetricsExtension) RegisterRoutes(r extensions.RouteRegistrar) error {
    r.HandleFunc("/api/metrics", m.handleMetrics)
    return nil
}

func (m *MetricsExtension) handleMetrics(w http.ResponseWriter, r *http.Request) {
    requests := atomic.LoadInt64(&m.requestCount)
    errors := atomic.LoadInt64(&m.errorCount)
    latency := atomic.LoadInt64(&m.totalLatency)

    // Guard both divisions: 0/0 yields NaN, which encoding/json refuses to encode
    avgLatency := int64(0)
    errorRate := 0.0
    if requests > 0 {
        avgLatency = latency / requests
        errorRate = float64(errors) / float64(requests)
    }

    // Snapshot per-provider counters under the lock before encoding
    m.mu.RLock()
    providerStats := make(map[string]ProviderMetrics, len(m.providerStats))
    for name, stats := range m.providerStats {
        providerStats[name] = ProviderMetrics{
            Requests: atomic.LoadInt64(&stats.Requests),
            Errors:   atomic.LoadInt64(&stats.Errors),
        }
    }
    m.mu.RUnlock()

    response := map[string]interface{}{
        "total_requests":     requests,
        "total_errors":       errors,
        "average_latency_ms": avgLatency,
        "error_rate":         errorRate,
        "provider_stats":     providerStats,
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(response)
}

2. Caching Extension

Caches responses to reduce API calls and costs.

package main

import (
    "context"
    "crypto/sha256"
    "encoding/hex"
    "encoding/json"
    "net/http"
    "sync"
    "time"

    "github.com/cecil-the-coder/ai-provider-kit/pkg/backend/extensions"
)

type CachingExtension struct {
    extensions.BaseExtension
    cache map[string]*CacheEntry
    mu    sync.RWMutex
    ttl   time.Duration
}

type CacheEntry struct {
    Response  *extensions.GenerateResponse
    Timestamp time.Time
}

func NewCachingExtension() *CachingExtension {
    return &CachingExtension{
        cache: make(map[string]*CacheEntry),
        ttl:   5 * time.Minute,
    }
}

func (c *CachingExtension) Name() string        { return "cache" }
func (c *CachingExtension) Version() string     { return "1.0.0" }
func (c *CachingExtension) Description() string { return "Caches responses" }

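func (c *CachingExtension) Initialize(config map[string]interface{}) error {
    // The numeric type depends on how the config was built or decoded
    switch ttl := config["ttl_seconds"].(type) {
    case int:
        c.ttl = time.Duration(ttl) * time.Second
    case float64: // numbers decoded by encoding/json arrive as float64
        c.ttl = time.Duration(ttl * float64(time.Second))
    }
    return nil
}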

func (c *CachingExtension) BeforeGenerate(ctx context.Context, req *extensions.GenerateRequest) error {
    // Check cache
    key := c.cacheKey(req)

    c.mu.RLock()
    entry, exists := c.cache[key]
    c.mu.RUnlock()

    if exists && time.Since(entry.Timestamp) < c.ttl {
        // Cache hit - store in metadata to skip provider
        if req.Metadata == nil {
            req.Metadata = make(map[string]interface{})
        }
        req.Metadata["cached_response"] = entry.Response
    }

    return nil
}

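func (c *CachingExtension) AfterGenerate(ctx context.Context, req *extensions.GenerateRequest, resp *extensions.GenerateResponse) error {
    // Check if response came from cache
    if cached, ok := req.Metadata["cached_response"].(*extensions.GenerateResponse); ok {
        // Use a copy of the cached response so the shared entry is never mutated
        *resp = copyResponse(cached)
        resp.Metadata["from_cache"] = true
        return nil
    }

    // Store a copy in cache so later changes to resp do not alter the entry
    stored := copyResponse(resp)
    key := c.cacheKey(req)
    c.mu.Lock()
    c.cache[key] = &CacheEntry{
        Response:  &stored,
        Timestamp: time.Now(),
    }
    c.mu.Unlock()

    if resp.Metadata == nil {
        resp.Metadata = make(map[string]interface{})
    }
    resp.Metadata["from_cache"] = false

    return nil
}

// copyResponse returns a copy of resp with its own, non-nil Metadata map.
func copyResponse(resp *extensions.GenerateResponse) extensions.GenerateResponse {
    out := *resp
    out.Metadata = make(map[string]interface{}, len(resp.Metadata)+1)
    for k, v := range resp.Metadata {
        out.Metadata[k] = v
    }
    return out
}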

func (c *CachingExtension) cacheKey(req *extensions.GenerateRequest) string {
    data, _ := json.Marshal(map[string]interface{}{
        "provider": req.Provider,
        "model":    req.Model,
        "prompt":   req.Prompt,
        "temp":     req.Temperature,
        "max":      req.MaxTokens,
    })

    hash := sha256.Sum256(data)
    return hex.EncodeToString(hash[:])
}

func (c *CachingExtension) RegisterRoutes(r extensions.RouteRegistrar) error {
    r.HandleFunc("/api/cache/clear", c.handleClearCache)
    r.HandleFunc("/api/cache/stats", c.handleStats)
    return nil
}

func (c *CachingExtension) handleClearCache(w http.ResponseWriter, r *http.Request) {
    c.mu.Lock()
    c.cache = make(map[string]*CacheEntry)
    c.mu.Unlock()

    w.WriteHeader(http.StatusOK)
    json.NewEncoder(w).Encode(map[string]string{"status": "cleared"})
}

func (c *CachingExtension) handleStats(w http.ResponseWriter, r *http.Request) {
    c.mu.RLock()
    size := len(c.cache)
    c.mu.RUnlock()

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(map[string]interface{}{
        "entries": size,
        "ttl_seconds": int(c.ttl.Seconds()),
    })
}

3. Content Filter Extension

Filters inappropriate content from requests and responses.

package main

import (
    "context"
    "fmt"
    "regexp"
    "strings"

    "github.com/cecil-the-coder/ai-provider-kit/pkg/backend/extensions"
)

type ContentFilterExtension struct {
    extensions.BaseExtension
    blockedPatterns []*regexp.Regexp
    replacements    map[string]string
}

func NewContentFilterExtension() *ContentFilterExtension {
    return &ContentFilterExtension{
        blockedPatterns: []*regexp.Regexp{
            regexp.MustCompile(`(?i)offensive-word-1`),
            regexp.MustCompile(`(?i)offensive-word-2`),
        },
        replacements: map[string]string{
            "badword": "***",
        },
    }
}

func (f *ContentFilterExtension) Name() string        { return "content-filter" }
func (f *ContentFilterExtension) Version() string     { return "1.0.0" }
func (f *ContentFilterExtension) Description() string { return "Filters inappropriate content" }

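func (f *ContentFilterExtension) Initialize(config map[string]interface{}) error {
    // Load custom patterns from config
    var patterns []string
    switch v := config["blocked_patterns"].(type) {
    case []string: // set programmatically
        patterns = v
    case []interface{}: // lists decoded from YAML or JSON arrive in this form
        for _, item := range v {
            if s, ok := item.(string); ok {
                patterns = append(patterns, s)
            }
        }
    }
    for _, pattern := range patterns {
        re, err := regexp.Compile(pattern)
        if err != nil {
            // Fail loudly: silently dropping a pattern would weaken the filter
            return fmt.Errorf("invalid blocked pattern %q: %w", pattern, err)
        }
        f.blockedPatterns = append(f.blockedPatterns, re)
    }
    return nil
}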

func (f *ContentFilterExtension) BeforeGenerate(ctx context.Context, req *extensions.GenerateRequest) error {
    // Check for blocked content
    for _, pattern := range f.blockedPatterns {
        if pattern.MatchString(req.Prompt) {
            return fmt.Errorf("request contains inappropriate content")
        }
    }

    // Apply replacements
    for bad, good := range f.replacements {
        req.Prompt = strings.ReplaceAll(req.Prompt, bad, good)
    }

    return nil
}

func (f *ContentFilterExtension) AfterGenerate(ctx context.Context, req *extensions.GenerateRequest, resp *extensions.GenerateResponse) error {
    // Filter response content
    for _, pattern := range f.blockedPatterns {
        if pattern.MatchString(resp.Content) {
            return fmt.Errorf("response contains inappropriate content")
        }
    }

    // Apply replacements to response
    for bad, good := range f.replacements {
        resp.Content = strings.ReplaceAll(resp.Content, bad, good)
    }

    return nil
}

Registration

Programmatic Registration

package main

import (
    "github.com/cecil-the-coder/ai-provider-kit/pkg/backend"
    "github.com/cecil-the-coder/ai-provider-kit/pkg/backendtypes"
)

func main() {
    config := backendtypes.BackendConfig{
        Server: backendtypes.ServerConfig{
            Host: "0.0.0.0",
            Port: 8080,
        },
    }

    // providers is assumed to be constructed elsewhere; its setup is not shown here
    server := backend.NewServer(config, providers)

    // Register extensions
    server.RegisterExtension(NewMetricsExtension())
    server.RegisterExtension(NewCachingExtension())
    server.RegisterExtension(NewContentFilterExtension())

    server.Start()
}

Configuration-Based Registration

extensions:
  metrics:
    enabled: true
    config: {}

  cache:
    enabled: true
    config:
      ttl_seconds: 300

  content-filter:
    enabled: true
    config:
      blocked_patterns:
        - "(?i)spam"
        - "(?i)inappropriate"

config := backendtypes.BackendConfig{
    Extensions: map[string]backendtypes.ExtensionConfig{
        "metrics": {
            Enabled: true,
            Config:  map[string]interface{}{},
        },
        "cache": {
            Enabled: true,
            Config: map[string]interface{}{
                "ttl_seconds": 300,
            },
        },
    },
}

server := backend.NewServer(config, providers)
// Extensions are automatically initialized from config

Extension Dependencies

Extensions can declare dependencies on other extensions:

type MyExtension struct {
    extensions.BaseExtension
    metrics *MetricsExtension
}

func (e *MyExtension) Dependencies() []string {
    return []string{"metrics"}
}

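func (e *MyExtension) Initialize(config map[string]interface{}) error {
    // Get metrics extension from registry, checking each step instead of panicking
    registry, ok := config["registry"].(extensions.ExtensionRegistry)
    if !ok {
        return fmt.Errorf("registry not provided in config")
    }
    metricsExt, _ := registry.Get("metrics")
    metrics, ok := metricsExt.(*MetricsExtension)
    if !ok {
        return fmt.Errorf("metrics extension missing or of unexpected type %T", metricsExt)
    }
    e.metrics = metrics

    return nil
}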

The registry ensures extensions are initialized in dependency order.
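
The README does not say how the registry computes this order. One common approach is a depth-first topological sort over the names returned by Dependencies(); the sketch below assumes that approach, and `initOrder` is a hypothetical function, not part of the package.

```go
package main

import (
	"fmt"
	"sort"
)

// initOrder (hypothetical) takes a map from extension name to the names it
// depends on and returns an order in which every extension appears after all
// of its dependencies. It reports unknown dependencies and cycles as errors.
func initOrder(deps map[string][]string) ([]string, error) {
	const (
		unvisited = iota
		visiting
		done
	)
	state := make(map[string]int, len(deps))
	order := make([]string, 0, len(deps))

	var visit func(name string) error
	visit = func(name string) error {
		switch state[name] {
		case done:
			return nil
		case visiting:
			return fmt.Errorf("dependency cycle involving %q", name)
		}
		required, known := deps[name]
		if !known {
			return fmt.Errorf("extension %q is not registered", name)
		}
		state[name] = visiting
		for _, dep := range required {
			if err := visit(dep); err != nil {
				return err
			}
		}
		state[name] = done
		order = append(order, name)
		return nil
	}

	// Sort the names so the result does not depend on map iteration order.
	names := make([]string, 0, len(deps))
	for name := range deps {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		if err := visit(name); err != nil {
			return nil, err
		}
	}
	return order, nil
}

func main() {
	order, err := initOrder(map[string][]string{
		"audit":   {"metrics", "cache"},
		"cache":   {"metrics"},
		"metrics": nil,
	})
	fmt.Println(order, err)
}
```

Shutdown would typically walk the same list in reverse so that an extension is stopped before the ones it depends on.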

Best Practices

  1. Keep extensions focused - One responsibility per extension
  2. Use BaseExtension - Only implement hooks you need
  3. Handle errors gracefully - Don't crash the server
  4. Make extensions configurable - Use Initialize() config
  5. Add custom routes - Expose extension functionality
  6. Thread safety - Use mutexes for shared state
  7. Document configuration - Explain config options
  8. Test thoroughly - Extensions can break the server
  9. Version your extensions - Track compatibility
  10. Clean up resources - Use Shutdown() properly

Common Patterns

Conditional Execution

func (e *MyExtension) BeforeGenerate(ctx context.Context, req *GenerateRequest) error {
    // Only run for specific providers
    if req.Provider != "openai" {
        return nil
    }

    // Only run for specific models
    if !strings.HasPrefix(req.Model, "gpt-4") {
        return nil
    }

    // Custom logic
    return nil
}

Metadata Communication

// In BeforeGenerate
req.Metadata["custom_flag"] = true

// In AfterGenerate (checked assertion: a missing key must not panic)
if flag, ok := req.Metadata["custom_flag"].(bool); ok && flag {
    // Custom processing
}
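
Because Metadata is a map[string]interface{}, every read involves a type assertion, and an unchecked assertion panics when the key is missing or holds another type. One way to centralize the check is a small generic helper; `metaValue` below is a hypothetical name, and the sketch assumes Go 1.18 or later for type parameters.

```go
package main

import "fmt"

// metaValue (hypothetical) returns meta[key] as type T. The boolean is false
// when the key is absent or the stored value has a different type, in which
// case the zero value of T is returned. Reading from a nil map is safe in Go.
func metaValue[T any](meta map[string]interface{}, key string) (T, bool) {
	value, ok := meta[key].(T)
	return value, ok
}

func main() {
	meta := map[string]interface{}{
		"custom_flag": true,
		"request_id":  "abc-123",
	}

	if flag, ok := metaValue[bool](meta, "custom_flag"); ok && flag {
		fmt.Println("custom processing enabled")
	}

	// Wrong type: the helper reports false instead of panicking.
	_, ok := metaValue[int](meta, "request_id")
	fmt.Println(ok)
}
```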

Error Handling

func (e *MyExtension) BeforeGenerate(ctx context.Context, req *GenerateRequest) error {
    if err := e.validate(req); err != nil {
        // Log but don't fail the request
        log.Printf("Validation warning: %v", err)
        return nil
    }

    if err := e.criticalCheck(req); err != nil {
        // Abort the request
        return fmt.Errorf("critical error: %w", err)
    }

    return nil
}

Troubleshooting

Extension not called

  • Check if extension is registered
  • Verify extension is enabled in config
  • Ensure Initialize() succeeded
  • Check logs for initialization errors

Extension causing errors

  • Add logging to track execution
  • Return nil from hooks unless aborting is required
  • Check for nil pointer dereferences
  • Verify thread safety with concurrent requests

Extension slowing down requests

  • Profile extension code
  • Minimize work in BeforeGenerate/AfterGenerate
  • Use goroutines for non-blocking operations
  • Consider caching expensive operations

Dependencies not working

  • Verify dependency names match exactly
  • Check initialization order in logs
  • Ensure dependent extension is registered
  • Use registry.Get() to access dependencies
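
For the advice under "Extension slowing down requests" to use goroutines for non-blocking operations, one possible shape is a buffered channel drained by a single background worker, so a hook only enqueues and returns. `asyncRecorder` and its methods are hypothetical, and dropping events when the buffer is full is an assumed policy that suits metrics better than audit logs.

```go
package main

import (
	"fmt"
	"sync"
)

// asyncRecorder (hypothetical) moves slow work off the request path.
type asyncRecorder struct {
	events chan string
	wg     sync.WaitGroup
	mu     sync.Mutex
	seen   []string
}

// newAsyncRecorder starts one worker goroutine that drains the queue.
func newAsyncRecorder(buffer int) *asyncRecorder {
	r := &asyncRecorder{events: make(chan string, buffer)}
	r.wg.Add(1)
	go func() {
		defer r.wg.Done()
		for ev := range r.events {
			// Stand-in for slow work such as writing to a database.
			r.mu.Lock()
			r.seen = append(r.seen, ev)
			r.mu.Unlock()
		}
	}()
	return r
}

// Enqueue never blocks the caller: when the buffer is full the event is
// dropped and false is returned. It must not be called after Close.
func (r *asyncRecorder) Enqueue(ev string) bool {
	select {
	case r.events <- ev:
		return true
	default:
		return false
	}
}

// Close stops accepting events, waits for the worker to drain the queue, and
// returns everything processed. An extension would call it from Shutdown.
func (r *asyncRecorder) Close() []string {
	close(r.events)
	r.wg.Wait()
	return r.seen
}

func main() {
	rec := newAsyncRecorder(16)
	rec.Enqueue("request finished")
	rec.Enqueue("response cached")
	fmt.Println(rec.Close())
}
```

Waiting in Close ties the worker's lifetime to the extension, which avoids leaking the goroutine and losing queued events during a graceful shutdown.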

Testing Extensions

package main

import (
    "context"
    "testing"

    "github.com/cecil-the-coder/ai-provider-kit/pkg/backend/extensions"
)

func TestMetricsExtension(t *testing.T) {
    ext := NewMetricsExtension()

    // Test initialization
    err := ext.Initialize(map[string]interface{}{})
    if err != nil {
        t.Fatalf("Initialize failed: %v", err)
    }

    // Test BeforeGenerate
    req := &extensions.GenerateRequest{
        Prompt: "test",
    }

    err = ext.BeforeGenerate(context.Background(), req)
    if err != nil {
        t.Fatalf("BeforeGenerate failed: %v", err)
    }

    // Verify request count
    if ext.requestCount != 1 {
        t.Errorf("Expected 1 request, got %d", ext.requestCount)
    }
}

Documentation

Overview

Package extensions provides a plugin system for extending backend functionality. It defines the Extension interface for lifecycle management, route registration, and hooks for generation events, along with an ExtensionRegistry for managing multiple extensions.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type BaseExtension

type BaseExtension struct{}

BaseExtension provides default implementations for optional methods

func (*BaseExtension) AfterGenerate

func (b *BaseExtension) AfterGenerate(ctx context.Context, req *GenerateRequest, resp *GenerateResponse) error

func (*BaseExtension) BeforeGenerate

func (b *BaseExtension) BeforeGenerate(ctx context.Context, req *GenerateRequest) error

func (*BaseExtension) Dependencies

func (b *BaseExtension) Dependencies() []string

func (*BaseExtension) OnProviderError

func (b *BaseExtension) OnProviderError(ctx context.Context, provider types.Provider, err error) error

func (*BaseExtension) OnProviderSelected

func (b *BaseExtension) OnProviderSelected(ctx context.Context, provider types.Provider) error

func (*BaseExtension) RegisterRoutes

func (b *BaseExtension) RegisterRoutes(r RouteRegistrar) error

func (*BaseExtension) Shutdown

func (b *BaseExtension) Shutdown(ctx context.Context) error

type Extension

type Extension interface {
	Name() string
	Version() string
	Description() string
	Dependencies() []string

	Initialize(config map[string]interface{}) error
	Shutdown(ctx context.Context) error

	RegisterRoutes(registrar RouteRegistrar) error

	BeforeGenerate(ctx context.Context, req *GenerateRequest) error
	AfterGenerate(ctx context.Context, req *GenerateRequest, resp *GenerateResponse) error

	OnProviderError(ctx context.Context, provider types.Provider, err error) error
	OnProviderSelected(ctx context.Context, provider types.Provider) error
}

Extension defines the interface for backend extensions

type ExtensionConfig

type ExtensionConfig struct {
	Enabled bool                   `yaml:"enabled"`
	Config  map[string]interface{} `yaml:"config"`
}

ExtensionConfig is a local type until backendtypes is ready

type ExtensionRegistry

type ExtensionRegistry interface {
	Register(ext Extension) error
	Get(name string) (Extension, bool)
	List() []Extension
	Initialize(configs map[string]ExtensionConfig) error
	Shutdown(ctx context.Context) error
}

ExtensionRegistry manages extension lifecycle

func NewRegistry

func NewRegistry() ExtensionRegistry

type GenerateRequest

type GenerateRequest struct {
	Provider    string                 `json:"provider,omitempty"`
	Model       string                 `json:"model,omitempty"`
	Prompt      string                 `json:"prompt"`
	MaxTokens   int                    `json:"max_tokens,omitempty"`
	Temperature float64                `json:"temperature,omitempty"`
	Stream      bool                   `json:"stream,omitempty"`
	Metadata    map[string]interface{} `json:"metadata,omitempty"`
}

GenerateRequest is a local type until backendtypes is ready

type GenerateResponse

type GenerateResponse struct {
	Content  string                 `json:"content"`
	Model    string                 `json:"model"`
	Provider string                 `json:"provider"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

GenerateResponse is a local type until backendtypes is ready

type RouteRegistrar

type RouteRegistrar interface {
	Handle(pattern string, handler http.Handler)
	HandleFunc(pattern string, handler http.HandlerFunc)
}

RouteRegistrar allows extensions to register custom routes
