Overview ¶
Package middleware provides a flexible middleware infrastructure for AI provider requests and responses.
The middleware package enables request transformation, response processing, logging, metrics collection, and other cross-cutting concerns through a composable middleware chain pattern. Middleware can process requests before they are sent to AI providers and responses after they are received.
Key Components ¶
The package provides several key interfaces and types:
- RequestMiddleware: Processes HTTP requests before they are sent
- ResponseMiddleware: Processes HTTP responses after they are received
- Middleware: Combined interface for both request and response processing
- MiddlewareChain: Manages ordered middleware execution
Basic Usage ¶
Creating and using a middleware chain:
// Create a new middleware chain
chain := middleware.NewMiddlewareChain()
// Add middleware to the chain; Add returns the chain, so calls can be chained
chain.Add(loggingMiddleware).
	Add(metricsMiddleware).
	Add(retryMiddleware)
// Build the request
ctx := context.Background()
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", body)
if err != nil {
	// Handle error
}
// Run the request middleware
ctx, req, err = chain.ProcessRequest(ctx, req)
if err != nil {
	// Handle error (a middleware aborted the request)
}
// Make the HTTP call
resp, err := client.Do(req)
if err != nil {
	// Handle error
}
defer resp.Body.Close()
// Run the response middleware
ctx, resp, err = chain.ProcessResponse(ctx, req, resp)
if err != nil {
	// Handle error
}
Creating Custom Middleware ¶
Implementing RequestMiddleware:
type LoggingMiddleware struct {
logger *log.Logger
}
func (m *LoggingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
m.logger.Printf("Request: %s %s", req.Method, req.URL)
// Add a request ID to the context (generateID is an application-supplied helper)
ctx = context.WithValue(ctx, middleware.ContextKeyRequestID, generateID())
return ctx, req, nil
}
Implementing ResponseMiddleware:
type MetricsMiddleware struct {
metrics MetricsCollector
}
func (m *MetricsMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
// Record metrics
m.metrics.RecordStatusCode(resp.StatusCode)
return ctx, resp, nil
}
Implementing both interfaces:
type TimingMiddleware struct {
logger *log.Logger
}
func (m *TimingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
ctx = context.WithValue(ctx, middleware.ContextKeyStartTime, time.Now())
return ctx, req, nil
}
func (m *TimingMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
if startTime, ok := ctx.Value(middleware.ContextKeyStartTime).(time.Time); ok {
duration := time.Since(startTime)
m.logger.Printf("Request took %v", duration)
}
return ctx, resp, nil
}
Using Function Adapters ¶
For simple middleware, use function adapters:
// Request middleware function
headerMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
req.Header.Set("User-Agent", "AI-Provider-Kit/1.0")
return ctx, req, nil
})
// Response middleware function
statusMiddleware := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
if resp.StatusCode >= 500 {
return ctx, resp, fmt.Errorf("server error: %d", resp.StatusCode)
}
return ctx, resp, nil
})
chain.Add(headerMiddleware).Add(statusMiddleware)
Advanced Chain Operations ¶
Inserting middleware at specific positions:
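// Add before another middleware (returns false if existingMiddleware is not in the chain)
chain.AddBefore(existingMiddleware, newMiddleware)
// Add after another middleware (returns false if existingMiddleware is not in the chain)
chain.AddAfter(existingMiddleware, newMiddleware)
// Remove middleware (returns false if it is not in the chain)
chain.Remove(middlewareToRemove)
// Clear all middleware
chain.Clear()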
Context Keys ¶
The package provides standard context keys for passing data between middleware:
- ContextKeyRequestID: Unique request identifier
- ContextKeyStartTime: Request start time
- ContextKeyProvider: Provider name (e.g., "openai", "anthropic")
- ContextKeyModel: Model name (e.g., "gpt-4", "claude-3")
- ContextKeyMetadata: Arbitrary metadata map
- ContextKeyError: Error information
- ContextKeyRetryCount: Retry attempt count
Using context keys:
// Store data in context
ctx = context.WithValue(ctx, middleware.ContextKeyProvider, "openai")
ctx = context.WithValue(ctx, middleware.ContextKeyModel, "gpt-4")
ctx = context.WithValue(ctx, middleware.ContextKeyRetryCount, 0)
// Retrieve data from context
if provider, ok := ctx.Value(middleware.ContextKeyProvider).(string); ok {
	log.Printf("provider: %s", provider) // use provider
}
Execution Order ¶
The middleware chain executes in a specific order:
- Request middleware: Execute in the order they were added (first to last)
- Response middleware: Execute in reverse order (last to first)
This ensures symmetric processing, similar to nested function calls:
Request:  MW1 -> MW2 -> MW3 -> [HTTP Call]
Response: MW1 <- MW2 <- MW3 <- [HTTP Call]
Error Handling ¶
Middleware can return errors to abort the chain:
func (m *ValidationMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
if req.Header.Get("Authorization") == "" {
return ctx, req, errors.New("missing authorization header")
}
return ctx, req, nil
}
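When a middleware returns an error, the chain is aborted:
- Request middleware: middleware added after the failing one are not executed
- Response middleware: middleware added before the failing one (which would run next, because responses are processed in reverse order) are not executed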
Thread Safety ¶
DefaultMiddlewareChain is thread-safe and can be used concurrently:
- Adding/removing middleware uses write locks
- Processing requests/responses uses read locks
- Multiple goroutines can process requests/responses simultaneously
Performance Considerations ¶
The middleware chain is designed for efficiency:
- Minimal allocations during processing
- No write locks during execution once the chain is built; processing takes only a shared read lock
- The package's benchmarks report about 900 ns per request with 10 middleware
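The benchmark figure above depends on hardware and on the work each middleware does. To measure chain overhead in your own environment, testing.Benchmark can be called from an ordinary program. This sketch times a plain slice of no-op function middleware, a stand-in for the package's chain rather than the package itself; the reqFunc and noop names are illustrative:

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
)

// reqFunc mirrors the documented request middleware function signature.
type reqFunc func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)

// noop passes the request through unchanged.
func noop(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	return ctx, req, nil
}

// benchChain measures running n no-op middleware per simulated request.
func benchChain(n int) testing.BenchmarkResult {
	chain := make([]reqFunc, n)
	for i := range chain {
		chain[i] = noop
	}
	req := httptest.NewRequest("GET", "http://example.com/", nil)
	return testing.Benchmark(func(b *testing.B) {
		ctx := context.Background()
		for i := 0; i < b.N; i++ {
			c, r := ctx, req
			for _, mw := range chain {
				c, r, _ = mw(c, r)
			}
		}
	})
}

func main() {
	res := benchChain(10)
	fmt.Printf("%d ns/op, %d allocs/op\n", res.NsPerOp(), res.AllocsPerOp())
}
```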
Complete Example ¶
package main
import (
"context"
"log"
"net/http"
"time"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
type LoggingMiddleware struct {
logger *log.Logger
}
func (m *LoggingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
m.logger.Printf("-> %s %s", req.Method, req.URL)
ctx = context.WithValue(ctx, middleware.ContextKeyStartTime, time.Now())
return ctx, req, nil
}
func (m *LoggingMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
if startTime, ok := ctx.Value(middleware.ContextKeyStartTime).(time.Time); ok {
m.logger.Printf("<- %s %s (%d) in %v", req.Method, req.URL, resp.StatusCode, time.Since(startTime))
}
return ctx, resp, nil
}
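func main() {
	// Create the middleware chain and register the logging middleware
	chain := middleware.NewMiddlewareChain()
	chain.Add(&LoggingMiddleware{logger: log.Default()})
	// Build the request
	req, err := http.NewRequest("GET", "https://api.example.com/v1/models", nil)
	if err != nil {
		log.Fatal(err)
	}
	// Run the request middleware
	ctx, req, err := chain.ProcessRequest(context.Background(), req)
	if err != nil {
		log.Fatal(err)
	}
	// Make the HTTP call; a timeout keeps a stalled provider from hanging the program
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	// Run the response middleware
	if _, _, err := chain.ProcessResponse(ctx, req, resp); err != nil {
		log.Fatal(err)
	}
}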
Best Practices ¶
- Keep middleware focused on a single concern
- Use context keys to pass data between middleware
- Handle errors appropriately (abort vs. log and continue)
- Consider middleware order carefully
- Use function adapters for simple, stateless middleware
- Avoid holding locks or blocking in middleware
- Test middleware in isolation and as part of a chain
Index ¶
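- type CombinedMiddleware
- func NewCombinedMiddleware(reqProcessor RequestMiddleware, respProcessor ResponseMiddleware) *CombinedMiddleware
- type ContextKey
- type DefaultMiddlewareChain
- func NewMiddlewareChain() *DefaultMiddlewareChain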
- func (c *DefaultMiddlewareChain) Add(middleware Middleware) MiddlewareChain
- func (c *DefaultMiddlewareChain) AddAfter(target Middleware, middleware Middleware) bool
- func (c *DefaultMiddlewareChain) AddBefore(target Middleware, middleware Middleware) bool
- func (c *DefaultMiddlewareChain) Clear()
- func (c *DefaultMiddlewareChain) Len() int
- func (c *DefaultMiddlewareChain) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
- func (c *DefaultMiddlewareChain) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
- func (c *DefaultMiddlewareChain) Remove(middleware Middleware) bool
- type Middleware
- type MiddlewareChain
- type RequestMiddleware
- type RequestMiddlewareFunc
- type ResponseMiddleware
- type ResponseMiddlewareFunc
Types ¶
type CombinedMiddleware ¶
type CombinedMiddleware struct {
RequestProcessor RequestMiddleware
ResponseProcessor ResponseMiddleware
}
CombinedMiddleware is a middleware that implements both RequestMiddleware and ResponseMiddleware
Example ¶
Example of combining multiple middleware types
package main
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
func main() {
// Header middleware
headerMw := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
req.Header.Set("User-Agent", "AI-Provider-Kit/1.0")
fmt.Println("Added User-Agent header")
return ctx, req, nil
})
// Response validation middleware
validationMw := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
if resp.StatusCode >= 400 {
fmt.Printf("Error response: %d\n", resp.StatusCode)
} else {
fmt.Println("Successful response")
}
return ctx, resp, nil
})
// Combine them
combined := middleware.NewCombinedMiddleware(headerMw, validationMw)
chain := middleware.NewMiddlewareChain()
chain.Add(combined)
req := httptest.NewRequest("POST", "http://example.com/api", nil)
ctx := context.Background()
// Process request
_, req, _ = chain.ProcessRequest(ctx, req)
// Process response
resp := &http.Response{StatusCode: 200, Header: make(http.Header)}
_, _, _ = chain.ProcessResponse(ctx, req, resp)
}
Output:
Added User-Agent header
Successful response
func NewCombinedMiddleware ¶
func NewCombinedMiddleware(reqProcessor RequestMiddleware, respProcessor ResponseMiddleware) *CombinedMiddleware
NewCombinedMiddleware creates a new combined middleware
type ContextKey ¶
type ContextKey string
ContextKey is the type used for middleware context keys
Example ¶
Example of using context keys to pass data between middleware
package main
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
func main() {
// First middleware adds request ID
requestIDMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
ctx = context.WithValue(ctx, middleware.ContextKeyRequestID, "req-12345")
ctx = context.WithValue(ctx, middleware.ContextKeyProvider, "openai")
return ctx, req, nil
})
// Second middleware uses the request ID
logMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
requestID := ctx.Value(middleware.ContextKeyRequestID)
provider := ctx.Value(middleware.ContextKeyProvider)
fmt.Printf("Processing request %v for provider %v\n", requestID, provider)
return ctx, req, nil
})
chain := middleware.NewMiddlewareChain()
chain.Add(requestIDMiddleware).Add(logMiddleware)
req := httptest.NewRequest("GET", "http://example.com/api", nil)
ctx := context.Background()
_, _, err := chain.ProcessRequest(ctx, req)
if err != nil {
log.Fatal(err)
}
}
Output: Processing request req-12345 for provider openai
const (
	// ContextKeyRequestID stores a unique request identifier
	ContextKeyRequestID ContextKey = "middleware:request_id"
	// ContextKeyStartTime stores the request start time
	ContextKeyStartTime ContextKey = "middleware:start_time"
	// ContextKeyProvider stores the provider name
	ContextKeyProvider ContextKey = "middleware:provider"
	// ContextKeyModel stores the model name
	ContextKeyModel ContextKey = "middleware:model"
	// ContextKeyMetadata stores arbitrary metadata
	ContextKeyMetadata ContextKey = "middleware:metadata"
	// ContextKeyError stores error information
	ContextKeyError ContextKey = "middleware:error"
	// ContextKeyRetryCount stores the retry attempt count
	ContextKeyRetryCount ContextKey = "middleware:retry_count"
)
Common context keys for passing data between middleware
type DefaultMiddlewareChain ¶
type DefaultMiddlewareChain struct {
// contains filtered or unexported fields
}
DefaultMiddlewareChain is the default implementation of MiddlewareChain
Example (ExecutionOrder) ¶
Example showing execution order: requests forward, responses reverse
package main
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
func main() {
mw1 := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
fmt.Println("Request: MW1")
return ctx, req, nil
})
mw2 := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
fmt.Println("Request: MW2")
return ctx, req, nil
})
respMw1 := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
fmt.Println("Response: MW1")
return ctx, resp, nil
})
respMw2 := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
fmt.Println("Response: MW2")
return ctx, resp, nil
})
// Create combined middleware
combined1 := middleware.NewCombinedMiddleware(mw1, respMw1)
combined2 := middleware.NewCombinedMiddleware(mw2, respMw2)
chain := middleware.NewMiddlewareChain()
chain.Add(combined1).Add(combined2)
req := httptest.NewRequest("GET", "http://example.com/api", nil)
resp := &http.Response{StatusCode: 200, Header: make(http.Header)}
ctx := context.Background()
ctx, req, _ = chain.ProcessRequest(ctx, req)
fmt.Println("--- HTTP Call ---")
_, _, _ = chain.ProcessResponse(ctx, req, resp)
}
Output:
Request: MW1
Request: MW2
--- HTTP Call ---
Response: MW2
Response: MW1
func NewMiddlewareChain ¶
func NewMiddlewareChain() *DefaultMiddlewareChain
NewMiddlewareChain creates a new middleware chain
Example ¶
Example of creating a timing middleware that implements both interfaces
package main
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"time"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
// TimingMiddleware is a middleware that tracks request duration
type TimingMiddleware struct{}
func (m *TimingMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
ctx = context.WithValue(ctx, middleware.ContextKeyStartTime, time.Now())
fmt.Println("Request started")
return ctx, req, nil
}
func (m *TimingMiddleware) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
if startTime, ok := ctx.Value(middleware.ContextKeyStartTime).(time.Time); ok {
duration := time.Since(startTime)
if duration >= time.Millisecond {
fmt.Println("Request completed")
}
}
return ctx, resp, nil
}
func main() {
// Create chain and add middleware
chain := middleware.NewMiddlewareChain()
chain.Add(&TimingMiddleware{})
// Process request
req := httptest.NewRequest("GET", "http://example.com/api", nil)
ctx := context.Background()
ctx, req, _ = chain.ProcessRequest(ctx, req)
// Simulate some work
time.Sleep(1 * time.Millisecond)
// Process response
resp := &http.Response{StatusCode: 200, Header: make(http.Header)}
_, _, _ = chain.ProcessResponse(ctx, req, resp)
}
Output:
Request started
Request completed
func (*DefaultMiddlewareChain) Add ¶
func (c *DefaultMiddlewareChain) Add(middleware Middleware) MiddlewareChain
Add appends middleware to the end of the chain
func (*DefaultMiddlewareChain) AddAfter ¶
func (c *DefaultMiddlewareChain) AddAfter(target Middleware, middleware Middleware) bool
AddAfter inserts middleware after another middleware in the chain. It returns false if the target middleware is not found.
func (*DefaultMiddlewareChain) AddBefore ¶
func (c *DefaultMiddlewareChain) AddBefore(target Middleware, middleware Middleware) bool
AddBefore inserts middleware before another middleware in the chain. It returns false if the target middleware is not found.
Example ¶
Example of adding middleware in specific positions
package main
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
// ExampleMiddleware demonstrates a simple middleware implementation
type ExampleMiddleware struct {
name string
}
func (m *ExampleMiddleware) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
fmt.Printf("Middleware %s\n", m.name)
return ctx, req, nil
}
func main() {
mw1 := &ExampleMiddleware{name: "1"}
mw2 := &ExampleMiddleware{name: "2"}
mw3 := &ExampleMiddleware{name: "3"}
chain := middleware.NewMiddlewareChain()
chain.Add(mw1).Add(mw3)
// Insert mw2 before mw3
chain.AddBefore(mw3, mw2)
req := httptest.NewRequest("GET", "http://example.com/api", nil)
ctx := context.Background()
_, _, err := chain.ProcessRequest(ctx, req)
if err != nil {
log.Fatal(err)
}
}
Output:
Middleware 1
Middleware 2
Middleware 3
func (*DefaultMiddlewareChain) Clear ¶
func (c *DefaultMiddlewareChain) Clear()
Clear removes all middleware from the chain
func (*DefaultMiddlewareChain) Len ¶
func (c *DefaultMiddlewareChain) Len() int
Len returns the number of middleware in the chain
func (*DefaultMiddlewareChain) ProcessRequest ¶
func (c *DefaultMiddlewareChain) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
ProcessRequest executes all request middleware in order
func (*DefaultMiddlewareChain) ProcessResponse ¶
func (c *DefaultMiddlewareChain) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
ProcessResponse executes all response middleware in reverse order
func (*DefaultMiddlewareChain) Remove ¶
func (c *DefaultMiddlewareChain) Remove(middleware Middleware) bool
Remove removes middleware from the chain. It returns false if the middleware is not found.
type Middleware ¶
type Middleware interface{}
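Middleware is a combined interface for both request and response processing. Middleware implementations can implement either or both of RequestMiddleware and ResponseMiddleware. Because Middleware is declared as interface{}, the compiler accepts any value; whether a value implements one of those interfaces is only known at run time.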
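Since any value satisfies interface{}, a chain has to discover at run time which hooks a value provides. A type-assertion sketch of that check, using local stand-in interfaces and a hypothetical describe function; how the package itself treats values that implement neither interface is not documented here:

```go
package main

import (
	"context"
	"fmt"
	"net/http"
)

// Local stand-ins mirroring the documented interfaces.
type RequestMiddleware interface {
	ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
}

type ResponseMiddleware interface {
	ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
}

// reqOnly implements only the request hook.
type reqOnly struct{}

func (reqOnly) ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
	return ctx, req, nil
}

// both embeds reqOnly (promoting ProcessRequest) and adds a response hook.
type both struct{ reqOnly }

func (both) ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
	return ctx, resp, nil
}

// describe reports which hooks m implements.
func describe(m any) string {
	_, isReq := m.(RequestMiddleware)
	_, isResp := m.(ResponseMiddleware)
	switch {
	case isReq && isResp:
		return "request+response"
	case isReq:
		return "request"
	case isResp:
		return "response"
	default:
		return "none"
	}
}

func main() {
	fmt.Println(describe(reqOnly{}), describe(both{}), describe("not middleware"))
	// request request+response none
}
```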
type MiddlewareChain ¶
type MiddlewareChain interface {
// Add appends middleware to the end of the chain
Add(middleware Middleware) MiddlewareChain
// AddBefore inserts middleware before another middleware in the chain
// Returns false if the target middleware is not found
AddBefore(target Middleware, middleware Middleware) bool
// AddAfter inserts middleware after another middleware in the chain
// Returns false if the target middleware is not found
AddAfter(target Middleware, middleware Middleware) bool
// Remove removes middleware from the chain
// Returns false if the middleware is not found
Remove(middleware Middleware) bool
// ProcessRequest executes all request middleware in order
ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
// ProcessResponse executes all response middleware in reverse order
ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
// Clear removes all middleware from the chain
Clear()
// Len returns the number of middleware in the chain
Len() int
}
MiddlewareChain manages an ordered collection of middleware
type RequestMiddleware ¶
type RequestMiddleware interface {
// ProcessRequest processes an HTTP request before sending
// It can modify the request, context, or return an error to abort the request
ProcessRequest(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
}
RequestMiddleware transforms requests before they are sent to the provider
type RequestMiddlewareFunc ¶
type RequestMiddlewareFunc func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error)
RequestMiddlewareFunc is a function adapter for RequestMiddleware
Example ¶
Example of creating a simple logging middleware
package main
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
func main() {
logMiddleware := middleware.RequestMiddlewareFunc(func(ctx context.Context, req *http.Request) (context.Context, *http.Request, error) {
fmt.Printf("Request: %s %s\n", req.Method, req.URL.Path)
return ctx, req, nil
})
chain := middleware.NewMiddlewareChain()
chain.Add(logMiddleware)
req := httptest.NewRequest("GET", "http://example.com/api/v1/test", nil)
ctx := context.Background()
_, _, err := chain.ProcessRequest(ctx, req)
if err != nil {
log.Fatal(err)
}
}
Output: Request: GET /api/v1/test
type ResponseMiddleware ¶
type ResponseMiddleware interface {
// ProcessResponse processes an HTTP response after receiving
// It can modify the response, context, or return an error
ProcessResponse(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
}
ResponseMiddleware transforms responses after they are received from the provider
type ResponseMiddlewareFunc ¶
type ResponseMiddlewareFunc func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error)
ResponseMiddlewareFunc is a function adapter for ResponseMiddleware
Example ¶
Example of creating a response status checking middleware
package main
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"github.com/cecil-the-coder/ai-provider-kit/internal/common/middleware"
)
func main() {
statusMiddleware := middleware.ResponseMiddlewareFunc(func(ctx context.Context, req *http.Request, resp *http.Response) (context.Context, *http.Response, error) {
fmt.Printf("Response status: %d\n", resp.StatusCode)
return ctx, resp, nil
})
chain := middleware.NewMiddlewareChain()
chain.Add(statusMiddleware)
req := httptest.NewRequest("GET", "http://example.com/api", nil)
resp := &http.Response{
StatusCode: 200,
Header: make(http.Header),
}
ctx := context.Background()
_, _, err := chain.ProcessResponse(ctx, req, resp)
if err != nil {
log.Fatal(err)
}
}
Output: Response status: 200