middleware

package
v1.0.64 Latest
Warning

This package is not in the latest version of its module.

Published: Dec 25, 2025 License: MIT Imports: 3 Imported by: 0

Documentation

Package middleware provides a flexible middleware infrastructure for AI provider requests and responses.

Overview

The middleware package enables request transformation, response processing, logging, metrics collection, and other cross-cutting concerns through a composable middleware chain pattern. Middleware can process requests before they are sent to AI providers and responses after they are received.

Key Components

The package provides several key interfaces and types:

  • RequestMiddleware: Processes HTTP requests before sending
  • ResponseMiddleware: Processes HTTP responses after receiving
  • Middleware: Marker type for values that implement RequestMiddleware, ResponseMiddleware, or both
  • MiddlewareChain: Manages ordered middleware execution
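
Because Middleware is declared as an empty interface (see the Types section), a chain has to discover at run time which processing interface a value implements. The following is a minimal sketch of that dispatch, assuming local stand-in interfaces that copy the documented signatures; the headerOnly type and the describe helper are hypothetical and not part of the package:

```go
package main

import (
	"context"
	"fmt"
	"net/http"
)

// Local stand-ins mirroring the documented interfaces (assumption: the
// shapes are copied from this page rather than imported).
type RequestMiddleware interface {
	ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
}

type ResponseMiddleware interface {
	ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
}

// Middleware is an empty interface, as in the package's type listing.
type Middleware interface{}

// headerOnly is a hypothetical middleware that implements only the request side.
type headerOnly struct{}

func (headerOnly) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	req.Header.Set("User-Agent", "AI-Provider-Kit/1.0")
	return ctx, req, nil
}

// describe is a hypothetical helper that reports which sides mw implements.
func describe(mw Middleware) string {
	_, isReq := mw.(RequestMiddleware)
	_, isResp := mw.(ResponseMiddleware)
	switch {
	case isReq && isResp:
		return "request+response"
	case isReq:
		return "request"
	case isResp:
		return "response"
	default:
		return "neither"
	}
}

func main() {
	fmt.Println(describe(headerOnly{})) // request
	fmt.Println(describe(42))           // neither
}
```

A value that implements neither interface still satisfies the empty interface, so in a design like this the compiler cannot catch a wrong argument to Add.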

Basic Usage

Creating and using a middleware chain:

// Create a new middleware chain
chain := middleware.NewMiddlewareChain()

// Add middleware to the chain
chain.Add(loggingMiddleware).
      Add(metricsMiddleware).
      Add(retryMiddleware)

// Process a request
ctx := context.Background()
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", body)
if err != nil {
    // Handle error
}
newCtx, newReq, err := chain.ProcessRequest(ctx, req)
if err != nil {
    // Handle error
}

// Make the HTTP call
resp, err := client.Do(newReq)
if err != nil {
    // Handle error
}

// Process the response
newCtx, newResp, err := chain.ProcessResponse(newCtx, newReq, resp)

Creating Custom Middleware

Implementing RequestMiddleware:

type LoggingMiddleware struct {
    logger *log.Logger
}

func (m *LoggingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
    m.logger.Printf("Request: %s %s", req.Method, req.URL)
    // Add request ID to context
    ctx = context.WithValue(ctx, middleware.ContextKeyRequestID, generateID())
    return ctx, req, nil
}

Implementing ResponseMiddleware:

type MetricsMiddleware struct {
    metrics MetricsCollector
}

func (m *MetricsMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
    // Record metrics
    m.metrics.RecordStatusCode(resp.StatusCode)
    return ctx, resp, nil
}

Implementing both interfaces:

type TimingMiddleware struct {
    logger *log.Logger
}

func (m *TimingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
    ctx = context.WithValue(ctx, middleware.ContextKeyStartTime, time.Now())
    return ctx, req, nil
}

func (m *TimingMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
    if startTime, ok := ctx.Value(middleware.ContextKeyStartTime).(time.Time); ok {
        duration := time.Since(startTime)
        m.logger.Printf("Request took %v", duration)
    }
    return ctx, resp, nil
}

Using Function Adapters

For simple middleware, use function adapters:

// Request middleware function
headerMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
    req.Header.Set("User-Agent", "AI-Provider-Kit/1.0")
    return ctx, req, nil
})

// Response middleware function
statusMiddleware := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
    if resp.StatusCode >= 500 {
        return ctx, resp, fmt.Errorf("server error: %d", resp.StatusCode)
    }
    return ctx, resp, nil
})

chain.Add(headerMiddleware).Add(statusMiddleware)

Advanced Chain Operations

Inserting middleware at specific positions:

// Add before another middleware
chain.AddBefore(existingMiddleware, newMiddleware)

// Add after another middleware
chain.AddAfter(existingMiddleware, newMiddleware)

// Remove middleware
chain.Remove(middlewareToRemove)

// Clear all middleware
chain.Clear()
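
AddBefore, AddAfter, and Remove return a bool that is false when the target is not found, so the result is worth checking. This page does not say how the chain matches a target. The sketch below assumes pointer identity and uses a hypothetical stand-in list rather than the package's chain:

```go
package main

import "fmt"

// named is a simplified stand-in for a middleware value.
type named struct{ name string }

// list is a hypothetical ordered collection, not the package's chain.
type list struct{ items []*named }

// indexOf finds target by pointer identity (an assumption of this sketch).
func (l *list) indexOf(target *named) int {
	for i, it := range l.items {
		if it == target {
			return i
		}
	}
	return -1
}

// AddBefore inserts mw ahead of target and reports whether target was found.
func (l *list) AddBefore(target, mw *named) bool {
	i := l.indexOf(target)
	if i < 0 {
		return false
	}
	l.items = append(l.items, nil)   // grow by one slot
	copy(l.items[i+1:], l.items[i:]) // shift the tail right
	l.items[i] = mw
	return true
}

// order returns the names in execution order.
func (l *list) order() string {
	s := ""
	for _, it := range l.items {
		s += it.name
	}
	return s
}

// demo returns the resulting order plus the two AddBefore results.
func demo() (string, bool, bool) {
	a, b, c := &named{"A"}, &named{"B"}, &named{"C"}
	l := &list{items: []*named{a, c}}
	found := l.AddBefore(c, b)
	// A distinct value with the same name is not the same pointer.
	missing := l.AddBefore(&named{"C"}, &named{"X"})
	return l.order(), found, missing
}

func main() {
	fmt.Println(demo()) // ABC true false
}
```

Under this assumption, callers need to keep a reference to the exact value they added if they intend to insert around it or remove it later.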

Context Keys

The package provides standard context keys for passing data between middleware:

  • ContextKeyRequestID: Unique request identifier
  • ContextKeyStartTime: Request start time
  • ContextKeyProvider: Provider name (e.g., "openai", "anthropic")
  • ContextKeyModel: Model name (e.g., "gpt-4", "claude-3")
  • ContextKeyMetadata: Arbitrary metadata map
  • ContextKeyError: Error information
  • ContextKeyRetryCount: Retry attempt count

Using context keys:

// Store data in context
ctx = context.WithValue(ctx, middleware.ContextKeyProvider, "openai")
ctx = context.WithValue(ctx, middleware.ContextKeyModel, "gpt-4")
ctx = context.WithValue(ctx, middleware.ContextKeyRetryCount, 0)

// Retrieve data from context
if provider, ok := ctx.Value(middleware.ContextKeyProvider).(string); ok {
    // Use provider
}
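
Each raw ctx.Value call needs its own type assertion. A small pair of typed accessors can keep that in one place. This is a sketch only: ContextKey and ContextKeyRetryCount are local copies of the documented definitions, and withRetryCount, retryCount, and base are hypothetical names:

```go
package main

import (
	"context"
	"fmt"
)

// ContextKey and the constant below are local copies of the documented
// definitions (assumption: copied from this page rather than imported).
type ContextKey string

const ContextKeyRetryCount ContextKey = "middleware:retry_count"

// base is the root context used by the demo.
var base = context.Background()

// withRetryCount is a hypothetical helper that stores n under the retry key.
func withRetryCount(ctx context.Context, n int) context.Context {
	return context.WithValue(ctx, ContextKeyRetryCount, n)
}

// retryCount is a hypothetical helper that returns the stored count, or 0
// when the key is absent or holds a value of another type.
func retryCount(ctx context.Context) int {
	n, _ := ctx.Value(ContextKeyRetryCount).(int)
	return n
}

// plainStringMisses reports whether a lookup with an untyped string key
// fails to find the ContextKey-typed entry.
func plainStringMisses(ctx context.Context) bool {
	return ctx.Value("middleware:retry_count") == nil
}

func main() {
	fmt.Println(retryCount(base)) // 0: key not set yet

	ctx := withRetryCount(base, retryCount(base)+1)
	fmt.Println(retryCount(ctx)) // 1

	// The key's type is part of its identity, so a plain string with the
	// same text does not match.
	fmt.Println(plainStringMisses(ctx)) // true
}
```

The dedicated ContextKey type is what prevents collisions with keys that other packages store in the same context.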

Execution Order

The middleware chain executes in a specific order:

  • Request middleware: Execute in the order they were added (first to last)
  • Response middleware: Execute in reverse order (last to first)

This ensures symmetric processing, similar to nested function calls:

Request:  MW1 -> MW2 -> MW3 -> [HTTP Call]
Response: MW1 <- MW2 <- MW3 <- [HTTP Call]

Error Handling

Middleware can return errors to abort the chain:

func (m *ValidationMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
    if req.Header.Get("Authorization") == "" {
        return ctx, req, errors.New("missing authorization header")
    }
    return ctx, req, nil
}

When an error is returned:

  • Request middleware: Subsequent middleware are not executed
  • Response middleware: Subsequent middleware (earlier in chain) are not executed
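
The abort rule for request middleware can be sketched as a loop that stops at the first error. The code below assumes a hypothetical runRequest stand-in for the chain's request pass and a hypothetical ErrMissingAuth sentinel; wrapping with %w lets the caller still match the cause with errors.Is:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net/http"
)

// RequestFunc is a local stand-in with the documented ProcessRequest shape.
type RequestFunc func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

// ErrMissingAuth is a hypothetical sentinel so callers can match the failure.
var ErrMissingAuth = errors.New("missing authorization header")

// runRequest is a hypothetical stand-in for a chain's request pass: it stops
// at the first error and reports how many middleware actually ran.
func runRequest(ctx context.Context, req *http.Request, mws ...RequestFunc) (int, error) {
	for i, mw := range mws {
		var err error
		if ctx, req, err = mw(ctx, req); err != nil {
			return i + 1, fmt.Errorf("middleware %d: %w", i, err)
		}
	}
	return len(mws), nil
}

func requireAuth(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	if req.Header.Get("Authorization") == "" {
		return ctx, req, ErrMissingAuth
	}
	return ctx, req, nil
}

func noop(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	return ctx, req, nil
}

// attempt builds a request, optionally with a token, and runs three middleware.
func attempt(token string) (ran int, aborted bool) {
	req, err := http.NewRequest("GET", "http://example.com/api", nil)
	if err != nil {
		panic(err) // static URL, not expected to fail
	}
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	ran, err = runRequest(context.Background(), req, noop, requireAuth, noop)
	return ran, errors.Is(err, ErrMissingAuth)
}

func main() {
	fmt.Println(attempt(""))       // 2 true: the third middleware never ran
	fmt.Println(attempt("secret")) // 3 false
}
```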

Thread Safety

DefaultMiddlewareChain is thread-safe and can be used concurrently:

  • Adding/removing middleware uses write locks
  • Processing requests/responses uses read locks
  • Multiple goroutines can process requests/responses simultaneously
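
One way such a locking scheme could look is sketched below. This is not the package's implementation: safeChain and step are hypothetical, and copying the slice under the read lock before running the steps is a design choice of this sketch, made so that a slow step never holds a lock:

```go
package main

import (
	"fmt"
	"sync"
)

// step is a simplified stand-in for a middleware: it transforms a string.
type step func(string) string

// safeChain is a hypothetical sketch, not the package's implementation.
type safeChain struct {
	mu    sync.RWMutex
	steps []step
}

// Add takes the write lock because it mutates the slice.
func (c *safeChain) Add(s step) *safeChain {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.steps = append(c.steps, s)
	return c
}

// Process copies the slice under the read lock, then runs the steps without
// holding any lock, so a slow step never blocks Add.
func (c *safeChain) Process(in string) string {
	c.mu.RLock()
	snapshot := make([]step, len(c.steps))
	copy(snapshot, c.steps)
	c.mu.RUnlock()

	for _, s := range snapshot {
		in = s(in)
	}
	return in
}

// runConcurrently processes n inputs in parallel and counts correct results.
func runConcurrently(n int) int {
	c := &safeChain{}
	c.Add(func(s string) string { return s + "a" }).
		Add(func(s string) string { return s + "b" })

	var wg sync.WaitGroup
	results := make([]string, n)
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			results[i] = c.Process("x") // each goroutine writes its own slot
		}(i)
	}
	wg.Wait()

	ok := 0
	for _, r := range results {
		if r == "xab" {
			ok++
		}
	}
	return ok
}

func main() {
	fmt.Println(runConcurrently(8)) // 8
}
```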

Performance Considerations

The middleware chain is designed for efficiency:

  • Minimal allocations during processing
  • Processing takes only a read lock, so concurrent requests do not block one another once the chain is built
  • Benchmarks show ~900ns per request with 10 middleware
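
To check a figure like this on your own hardware, testing.Benchmark can be called from an ordinary program. The sketch below times a hypothetical stand-in loop over no-op functions, not the package's chain, so its numbers are not comparable to the figure quoted above; buildChain and measure are assumed names:

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
)

// RequestFunc is a local stand-in with the documented ProcessRequest shape.
type RequestFunc func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

// buildChain returns n no-op middleware (hypothetical helper).
func buildChain(n int) []RequestFunc {
	mws := make([]RequestFunc, n)
	for i := range mws {
		mws[i] = func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
			return ctx, req, nil
		}
	}
	return mws
}

// measure benchmarks one request pass over n middleware and returns the
// iteration count and the measured nanoseconds per pass.
func measure(n int) (iterations int, nsPerOp int64) {
	mws := buildChain(n)
	req := httptest.NewRequest("GET", "http://example.com/api", nil)
	res := testing.Benchmark(func(b *testing.B) {
		b.ReportAllocs()
		for i := 0; i < b.N; i++ {
			ctx := context.Background()
			r := req
			for _, mw := range mws {
				var err error
				if ctx, r, err = mw(ctx, r); err != nil {
					b.Fatal(err)
				}
			}
		}
	})
	return res.N, res.NsPerOp()
}

func main() {
	n, ns := measure(10)
	fmt.Printf("%d iterations, %d ns per 10-middleware pass\n", n, ns)
}
```

Replacing the inner loop with a call to the real chain's ProcessRequest would measure the package itself.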

Complete Example

package main

import (
    "context"
    "log"
    "net/http"
    "time"

    "github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

type LoggingMiddleware struct {
    logger *log.Logger
}

func (m *LoggingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
    m.logger.Printf("-> %s %s", req.Method, req.URL)
    ctx = context.WithValue(ctx, middleware.ContextKeyStartTime, time.Now())
    return ctx, req, nil
}

func (m *LoggingMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
    if startTime, ok := ctx.Value(middleware.ContextKeyStartTime).(time.Time); ok {
        m.logger.Printf("<- %s %s (%d) in %v", req.Method, req.URL, resp.StatusCode, time.Since(startTime))
    }
    return ctx, resp, nil
}

func main() {
    // Create middleware chain
    chain := middleware.NewMiddlewareChain()

    // Add middleware
    chain.Add(&LoggingMiddleware{logger: log.Default()})

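    // Process request
    req, err := http.NewRequest("GET", "https://api.example.com/v1/models", nil)
    if err != nil {
        log.Fatal(err)
    }
    ctx, req, err := chain.ProcessRequest(context.Background(), req)
    if err != nil {
        log.Fatal(err)
    }

    // Make HTTP call
    client := &http.Client{}
    resp, err := client.Do(req)
    if err != nil {
        log.Fatal(err) // resp is nil on error, so stop before touching resp.Body
    }
    defer resp.Body.Close()

    // Process response
    if _, _, err := chain.ProcessResponse(ctx, req, resp); err != nil {
        log.Printf("response middleware failed: %v", err)
    }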
}

Best Practices

  1. Keep middleware focused on a single concern
  2. Use context keys to pass data between middleware
  3. Handle errors appropriately (abort vs. log and continue)
  4. Consider middleware order carefully
  5. Use function adapters for simple, stateless middleware
  6. Avoid holding locks or blocking in middleware
  7. Test middleware in isolation and as part of a chain
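
Testing a middleware in isolation (item 7) needs no chain and no network: a hand-built http.Response is enough. This sketch assumes a local ResponseFunc type with the documented signature and a hypothetical failures helper; rejectServerErrors mirrors the status-checking adapter shown earlier:

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// ResponseFunc is a local stand-in with the documented ProcessResponse shape.
type ResponseFunc func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)

// rejectServerErrors returns an error for any 5xx status code.
func rejectServerErrors(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
	if resp.StatusCode >= 500 {
		return ctx, resp, fmt.Errorf("server error: %d", resp.StatusCode)
	}
	return ctx, resp, nil
}

// failures is a hypothetical helper: it runs mw against each status code in
// isolation and returns the codes for which mw reported an error.
func failures(mw ResponseFunc, codes ...int) []int {
	req := httptest.NewRequest("GET", "http://example.com/api", nil)
	var failed []int
	for _, code := range codes {
		resp := &http.Response{StatusCode: code, Header: make(http.Header)}
		if _, _, err := mw(context.Background(), req, resp); err != nil {
			failed = append(failed, code)
		}
	}
	return failed
}

func main() {
	fmt.Println(failures(rejectServerErrors, 200, 404, 500, 503)) // [500 503]
}
```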

Types

type CombinedMiddleware

type CombinedMiddleware struct {
	RequestProcessor  RequestMiddleware
	ResponseProcessor ResponseMiddleware
}

CombinedMiddleware is a middleware that implements both RequestMiddleware and ResponseMiddleware

Example

Example of combining multiple middleware types

package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

func main() {
	// Header middleware
	headerMw := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
		req.Header.Set("User-Agent", "AI-Provider-Kit/1.0")
		fmt.Println("Added User-Agent header")
		return ctx, req, nil
	})

	// Response validation middleware
	validationMw := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
		if resp.StatusCode >= 400 {
			fmt.Printf("Error response: %d\n", resp.StatusCode)
		} else {
			fmt.Println("Successful response")
		}
		return ctx, resp, nil
	})

	// Combine them
	combined := middleware.NewCombinedMiddleware(headerMw, validationMw)

	chain := middleware.NewMiddlewareChain()
	chain.Add(combined)

	req := httptest.NewRequest("POST", "http://example.com/api", nil)
	ctx := context.Background()

	// Process request
	_, req, _ = chain.ProcessRequest(ctx, req)

	// Process response
	resp := &http.Response{StatusCode: 200, Header: make(http.Header)}
	_, _, _ = chain.ProcessResponse(ctx, req, resp)

}
Output:
Added User-Agent header
Successful response

func NewCombinedMiddleware

func NewCombinedMiddleware(reqProcessor RequestMiddleware, respProcessor ResponseMiddleware) *CombinedMiddleware

NewCombinedMiddleware creates a new combined middleware

func (*CombinedMiddleware) ProcessRequest

func (cm *CombinedMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

ProcessRequest implements RequestMiddleware

func (*CombinedMiddleware) ProcessResponse

func (cm *CombinedMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)

ProcessResponse implements ResponseMiddleware

type ContextKey

type ContextKey string

ContextKey is the type used for middleware context keys

Example

Example of using context keys to pass data between middleware

package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"

	"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

func main() {
	// First middleware adds request ID
	requestIDMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
		ctx = context.WithValue(ctx, middleware.ContextKeyRequestID, "req-12345")
		ctx = context.WithValue(ctx, middleware.ContextKeyProvider, "openai")
		return ctx, req, nil
	})

	// Second middleware uses the request ID
	logMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
		requestID := ctx.Value(middleware.ContextKeyRequestID)
		provider := ctx.Value(middleware.ContextKeyProvider)
		fmt.Printf("Processing request %v for provider %v\n", requestID, provider)
		return ctx, req, nil
	})

	chain := middleware.NewMiddlewareChain()
	chain.Add(requestIDMiddleware).Add(logMiddleware)

	req := httptest.NewRequest("GET", "http://example.com/api", nil)
	ctx := context.Background()

	_, _, err := chain.ProcessRequest(ctx, req)
	if err != nil {
		log.Fatal(err)
	}

}
Output:
Processing request req-12345 for provider openai

const (
	// ContextKeyRequestID stores a unique request identifier
	ContextKeyRequestID ContextKey = "middleware:request_id"
	// ContextKeyStartTime stores the request start time
	ContextKeyStartTime ContextKey = "middleware:start_time"
	// ContextKeyProvider stores the provider name
	ContextKeyProvider ContextKey = "middleware:provider"
	// ContextKeyModel stores the model name
	ContextKeyModel ContextKey = "middleware:model"
	// ContextKeyMetadata stores arbitrary metadata
	ContextKeyMetadata ContextKey = "middleware:metadata"
	// ContextKeyError stores error information
	ContextKeyError ContextKey = "middleware:error"
	// ContextKeyRetryCount stores the retry attempt count
	ContextKeyRetryCount ContextKey = "middleware:retry_count"
)

Common context keys for passing data between middleware

type DefaultMiddlewareChain

type DefaultMiddlewareChain struct {
	// contains filtered or unexported fields
}

DefaultMiddlewareChain is the default implementation of MiddlewareChain

Example (ExecutionOrder)

Example showing execution order: requests forward, responses reverse

package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

func main() {
	mw1 := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
		fmt.Println("Request: MW1")
		return ctx, req, nil
	})

	mw2 := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
		fmt.Println("Request: MW2")
		return ctx, req, nil
	})

	respMw1 := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
		fmt.Println("Response: MW1")
		return ctx, resp, nil
	})

	respMw2 := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
		fmt.Println("Response: MW2")
		return ctx, resp, nil
	})

	// Create combined middleware
	combined1 := middleware.NewCombinedMiddleware(mw1, respMw1)
	combined2 := middleware.NewCombinedMiddleware(mw2, respMw2)

	chain := middleware.NewMiddlewareChain()
	chain.Add(combined1).Add(combined2)

	req := httptest.NewRequest("GET", "http://example.com/api", nil)
	resp := &http.Response{StatusCode: 200, Header: make(http.Header)}
	ctx := context.Background()

	ctx, req, _ = chain.ProcessRequest(ctx, req)
	fmt.Println("--- HTTP Call ---")
	_, _, _ = chain.ProcessResponse(ctx, req, resp)

}
Output:
Request: MW1
Request: MW2
--- HTTP Call ---
Response: MW2
Response: MW1

func NewMiddlewareChain

func NewMiddlewareChain() *DefaultMiddlewareChain

NewMiddlewareChain creates a new middleware chain

Example

Example of creating a timing middleware that implements both interfaces

package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"time"

	"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

// TimingMiddleware is a middleware that tracks request duration
type TimingMiddleware struct{}

func (m *TimingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	ctx = context.WithValue(ctx, middleware.ContextKeyStartTime, time.Now())
	fmt.Println("Request started")
	return ctx, req, nil
}

func (m *TimingMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
	if startTime, ok := ctx.Value(middleware.ContextKeyStartTime).(time.Time); ok {
		duration := time.Since(startTime)

		if duration >= time.Millisecond {
			fmt.Println("Request completed")
		}
	}
	return ctx, resp, nil
}

func main() {
	// Create chain and add middleware
	chain := middleware.NewMiddlewareChain()
	chain.Add(&TimingMiddleware{})

	// Process request
	req := httptest.NewRequest("GET", "http://example.com/api", nil)
	ctx := context.Background()
	ctx, req, _ = chain.ProcessRequest(ctx, req)

	// Simulate some work
	time.Sleep(1 * time.Millisecond)

	// Process response
	resp := &http.Response{StatusCode: 200, Header: make(http.Header)}
	_, _, _ = chain.ProcessResponse(ctx, req, resp)

}
Output:
Request started
Request completed

func (*DefaultMiddlewareChain) Add

func (c *DefaultMiddlewareChain) Add(middleware Middleware) MiddlewareChain

Add appends middleware to the end of the chain

func (*DefaultMiddlewareChain) AddAfter

func (c *DefaultMiddlewareChain) AddAfter(target Middleware, middleware Middleware) bool

AddAfter inserts middleware after another middleware in the chain

func (*DefaultMiddlewareChain) AddBefore

func (c *DefaultMiddlewareChain) AddBefore(target Middleware, middleware Middleware) bool

AddBefore inserts middleware before another middleware in the chain

Example

Example of adding middleware in specific positions

package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"

	"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

// ExampleMiddleware demonstrates a simple middleware implementation
type ExampleMiddleware struct {
	name string
}

func (m *ExampleMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	fmt.Printf("Middleware %s\n", m.name)
	return ctx, req, nil
}

func main() {
	mw1 := &ExampleMiddleware{name: "1"}
	mw2 := &ExampleMiddleware{name: "2"}
	mw3 := &ExampleMiddleware{name: "3"}

	chain := middleware.NewMiddlewareChain()
	chain.Add(mw1).Add(mw3)

	// Insert mw2 before mw3
	chain.AddBefore(mw3, mw2)

	req := httptest.NewRequest("GET", "http://example.com/api", nil)
	ctx := context.Background()

	_, _, err := chain.ProcessRequest(ctx, req)
	if err != nil {
		log.Fatal(err)
	}

}
Output:
Middleware 1
Middleware 2
Middleware 3

func (*DefaultMiddlewareChain) Clear

func (c *DefaultMiddlewareChain) Clear()

Clear removes all middleware from the chain

func (*DefaultMiddlewareChain) Len

func (c *DefaultMiddlewareChain) Len() int

Len returns the number of middleware in the chain

func (*DefaultMiddlewareChain) ProcessRequest

func (c *DefaultMiddlewareChain) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

ProcessRequest executes all request middleware in order

func (*DefaultMiddlewareChain) ProcessResponse

func (c *DefaultMiddlewareChain) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)

ProcessResponse executes all response middleware in reverse order

func (*DefaultMiddlewareChain) Remove

func (c *DefaultMiddlewareChain) Remove(middleware Middleware) bool

Remove removes middleware from the chain

type Middleware

type Middleware interface{}

Middleware is a combined interface for both request and response processing. Middleware implementations can implement either or both interfaces.

type MiddlewareChain

type MiddlewareChain interface {
	// Add appends middleware to the end of the chain
	Add(middleware Middleware) MiddlewareChain

	// AddBefore inserts middleware before another middleware in the chain
	// Returns false if the target middleware is not found
	AddBefore(target Middleware, middleware Middleware) bool

	// AddAfter inserts middleware after another middleware in the chain
	// Returns false if the target middleware is not found
	AddAfter(target Middleware, middleware Middleware) bool

	// Remove removes middleware from the chain
	// Returns false if the middleware is not found
	Remove(middleware Middleware) bool

	// ProcessRequest executes all request middleware in order
	ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

	// ProcessResponse executes all response middleware in reverse order
	ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)

	// Clear removes all middleware from the chain
	Clear()

	// Len returns the number of middleware in the chain
	Len() int
}

MiddlewareChain manages an ordered collection of middleware

type RequestMiddleware

type RequestMiddleware interface {
	// ProcessRequest processes an HTTP request before sending
	// It can modify the request, context, or return an error to abort the request
	ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
}

RequestMiddleware transforms requests before they are sent to the provider

type RequestMiddlewareFunc

type RequestMiddlewareFunc func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

RequestMiddlewareFunc is a function adapter for RequestMiddleware

Example

Example of creating a simple logging middleware

package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"

	"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

func main() {
	logMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
		fmt.Printf("Request: %s %s\n", req.Method, req.URL.Path)
		return ctx, req, nil
	})

	chain := middleware.NewMiddlewareChain()
	chain.Add(logMiddleware)

	req := httptest.NewRequest("GET", "http://example.com/api/v1/test", nil)
	ctx := context.Background()

	_, _, err := chain.ProcessRequest(ctx, req)
	if err != nil {
		log.Fatal(err)
	}

}
Output:
Request: GET /api/v1/test

func (RequestMiddlewareFunc) ProcessRequest

func (f RequestMiddlewareFunc) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

ProcessRequest implements RequestMiddleware
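
The adapter follows the same idea as http.HandlerFunc in the standard library: a function type gains a method so that a plain function satisfies the interface. The method body is not shown on this page, so the one below is an assumption, written against local copies of the documented types; tagged is a hypothetical helper:

```go
package main

import (
	"context"
	"fmt"
	"net/http"
)

// Local copies of the documented declarations.
type RequestMiddleware interface {
	ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
}

type RequestMiddlewareFunc func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

// ProcessRequest calls f itself. This body is an assumption modeled on
// http.HandlerFunc, not the package's source.
func (f RequestMiddlewareFunc) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	return f(ctx, req)
}

// Compile-time check that the adapter satisfies the interface.
var _ RequestMiddleware = RequestMiddlewareFunc(nil)

// tagged runs an adapted closure through the interface and returns the
// header value it set.
func tagged() string {
	var mw RequestMiddleware = RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
		req.Header.Set("X-Example", "1")
		return ctx, req, nil
	})

	req, err := http.NewRequest("GET", "http://example.com/api", nil)
	if err != nil {
		panic(err) // static URL, not expected to fail
	}
	if _, req, err = mw.ProcessRequest(context.Background(), req); err != nil {
		panic(err)
	}
	return req.Header.Get("X-Example")
}

func main() {
	fmt.Println(tagged()) // 1
}
```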

type ResponseMiddleware

type ResponseMiddleware interface {
	// ProcessResponse processes an HTTP response after receiving
	// It can modify the response, context, or return an error
	ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
}

ResponseMiddleware transforms responses after they are received from the provider

type ResponseMiddlewareFunc

type ResponseMiddlewareFunc func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)

ResponseMiddlewareFunc is a function adapter for ResponseMiddleware

Example

Example of creating a response status checking middleware

package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"

	"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)

func main() {
	statusMiddleware := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
		fmt.Printf("Response status: %d\n", resp.StatusCode)
		return ctx, resp, nil
	})

	chain := middleware.NewMiddlewareChain()
	chain.Add(statusMiddleware)

	req := httptest.NewRequest("GET", "http://example.com/api", nil)
	resp := &http.Response{
		StatusCode: 200,
		Header:     make(http.Header),
	}
	ctx := context.Background()

	_, _, err := chain.ProcessResponse(ctx, req, resp)
	if err != nil {
		log.Fatal(err)
	}

}
Output:
Response status: 200

func (ResponseMiddlewareFunc) ProcessResponse

func (f ResponseMiddlewareFunc) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)

ProcessResponse implements ResponseMiddleware
