Documentation ¶
Overview ¶
Package onnx provides ONNX model import functionality for the Born ML framework.
This package enables loading and running inference on ONNX (Open Neural Network Exchange) models exported from PyTorch, TensorFlow, and other ML frameworks.
Supported Features ¶
- ONNX format parsing (protobuf-based)
- 30+ standard ONNX operators
- Opset versions 1-21
- Float32 tensor operations
- Named input/output support
Example Usage ¶
import (
	"log"

	"github.com/born-ml/born/backend/cpu"
	"github.com/born-ml/born/onnx"
	"github.com/born-ml/born/tensor"
)

// Load the ONNX model
backend := cpu.New()
model, err := onnx.Load("model.onnx", backend)
if err != nil {
	log.Fatal(err)
}

// Create an input tensor
input := tensor.FromSlice([]float32{1.0, 2.0, 3.0}, tensor.Shape{1, 3}, backend)

// Run inference
output, err := model.Forward(input.Raw())
if err != nil {
	log.Fatal(err)
}
Supported Operators ¶
The following ONNX operators are supported:
- Arithmetic: Add, Sub, Mul, Div, Sqrt, Exp, Log, Sum, Erf
- Logical: Not, And, Or, Xor
- Comparison: Equal, Greater, GreaterOrEqual, Less, LessOrEqual
- Activation: Relu, LeakyRelu, PRelu, Gelu, Silu, Sigmoid, Tanh, Softmax, LogSoftmax
- Matrix: MatMul, Gemm, Transpose, Flatten
- Shape: Reshape, Squeeze, Unsqueeze, Concat, Split, Slice, Gather, Expand
- Normalization: LayerNormalization
- Other: Identity, Dropout, Constant, Cast, ConstantOfShape, Shape, Size, Where, Clip
Use ListSupportedOps to get the complete list of supported operators.
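Before loading, the list from ListSupportedOps can be diffed against the operators a model actually uses (for example, the Operators field reported by GetModelInfo). A minimal, self-contained sketch of such a check; `unsupportedOps` is a hypothetical helper, not part of this package, and the operator lists below are illustrative:

```go
package main

import "fmt"

// unsupportedOps returns the operators in modelOps that are missing from
// supported. In practice, supported would come from onnx.ListSupportedOps()
// and modelOps from a model's metadata.
func unsupportedOps(modelOps, supported []string) []string {
	set := make(map[string]struct{}, len(supported))
	for _, op := range supported {
		set[op] = struct{}{}
	}
	var missing []string
	for _, op := range modelOps {
		if _, ok := set[op]; !ok {
			missing = append(missing, op)
		}
	}
	return missing
}

func main() {
	supported := []string{"Add", "MatMul", "Relu"}
	modelOps := []string{"Add", "Conv", "Relu"}
	fmt.Println(unsupportedOps(modelOps, supported)) // [Conv]
}
```

An empty result means every operator the model needs is available; a non-empty result tells you which operators would trip strict-mode loading.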
Index ¶
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func ListSupportedOps ¶
func ListSupportedOps() []string
ListSupportedOps returns a list of all ONNX operators supported by Born.
Example:
ops := onnx.ListSupportedOps()
for _, op := range ops {
	fmt.Println(op)
}
Types ¶
type LoadOptions ¶
type LoadOptions = internalonnx.LoadOptions
LoadOptions configures ONNX model loading behavior.
func DefaultLoadOptions ¶
func DefaultLoadOptions() LoadOptions
DefaultLoadOptions returns the default options for loading ONNX models.
Default configuration:
- Strict mode: enabled (fails on unsupported operators)
- Optimization: enabled
type Model ¶
type Model interface {
	// Forward runs inference with a single input tensor.
	// For models with multiple inputs, use ForwardNamed.
	//
	// Returns an error if the model does not have exactly one input
	// or one output. In such cases, use ForwardNamed instead.
	Forward(input *tensor.RawTensor) (*tensor.RawTensor, error)

	// ForwardNamed runs inference with named inputs.
	// Returns a map of output name to tensor.
	//
	// This method supports models with multiple inputs and outputs.
	// All input names from InputNames() must be provided.
	//
	// Example:
	//
	//	inputs := map[string]*tensor.RawTensor{
	//		"input_ids":      inputIDs,
	//		"attention_mask": attentionMask,
	//	}
	//	outputs, err := model.ForwardNamed(inputs)
	//	if err != nil {
	//		log.Fatal(err)
	//	}
	//	logits := outputs["logits"]
	ForwardNamed(inputs map[string]*tensor.RawTensor) (map[string]*tensor.RawTensor, error)

	// InputNames returns the names of model inputs.
	InputNames() []string

	// OutputNames returns the names of model outputs.
	OutputNames() []string

	// OpsetVersion returns the ONNX opset version used by the model.
	OpsetVersion() int64

	// Metadata returns model metadata as key-value pairs.
	//
	// Common metadata keys:
	//   - "producer_name": Framework that exported the model (e.g., "pytorch")
	//   - "producer_version": Version of the exporter
	//   - "domain": Domain of the model (usually "")
	//   - Custom keys from model.metadata_props
	Metadata() map[string]string
}
Model represents a loaded ONNX model ready for inference.
This interface hides the internal implementation and allows for:
- Easy mocking in tests
- Multiple implementations (e.g., optimized versions)
- Decoupling from internal package structure
The model contains the computation graph, weights, and metadata from the original ONNX file. Use Forward or ForwardNamed to run inference.
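Because Model is an interface, a fixed-output test double is straightforward. A self-contained sketch: the RawTensor and Model types below are local stand-ins mirroring the interface above so the snippet compiles on its own, and mockModel and its canned values are hypothetical:

```go
package main

import "fmt"

// RawTensor and Model are stand-ins for tensor.RawTensor and onnx.Model,
// so this sketch is self-contained; the method set mirrors the interface above.
type RawTensor struct{}

type Model interface {
	Forward(input *RawTensor) (*RawTensor, error)
	ForwardNamed(inputs map[string]*RawTensor) (map[string]*RawTensor, error)
	InputNames() []string
	OutputNames() []string
	OpsetVersion() int64
	Metadata() map[string]string
}

// mockModel is a fixed-output test double: every call returns the canned tensor.
type mockModel struct {
	out *RawTensor
}

func (m *mockModel) Forward(_ *RawTensor) (*RawTensor, error) { return m.out, nil }
func (m *mockModel) ForwardNamed(_ map[string]*RawTensor) (map[string]*RawTensor, error) {
	return map[string]*RawTensor{"output": m.out}, nil
}
func (m *mockModel) InputNames() []string        { return []string{"input"} }
func (m *mockModel) OutputNames() []string       { return []string{"output"} }
func (m *mockModel) OpsetVersion() int64         { return 21 }
func (m *mockModel) Metadata() map[string]string { return map[string]string{} }

func main() {
	// The compiler enforces that mockModel satisfies the interface.
	var model Model = &mockModel{out: &RawTensor{}}
	fmt.Println(model.InputNames(), model.OpsetVersion()) // [input] 21
}
```

Code that accepts the Model interface can then be exercised without loading a real ONNX file.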
func Load ¶
Load loads an ONNX model from a file path.
The function parses the ONNX protobuf format, validates operators, and compiles the computation graph for efficient inference.
Returns a Model interface that can be used for inference. The actual implementation is hidden in the internal package.
Example:
backend := cpu.New()
model, err := onnx.Load("resnet18.onnx", backend)
if err != nil {
	log.Fatal(err)
}

// Get model info
fmt.Println("Inputs:", model.InputNames())
fmt.Println("Outputs:", model.OutputNames())
fmt.Println("Opset:", model.OpsetVersion())
For custom loading options, pass LoadOptions:
opts := onnx.DefaultLoadOptions()
opts.Strict = false // Allow unsupported ops (will skip them)
model, err := onnx.Load("model.onnx", backend, opts)
func LoadFromBytes ¶
LoadFromBytes loads an ONNX model from raw bytes.
This is useful when the model is embedded in the binary or loaded from a network source.
Returns a Model interface that can be used for inference.
Example:
modelBytes, _ := os.ReadFile("model.onnx")
model, err := onnx.LoadFromBytes(modelBytes, backend)
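For the embedded-in-the-binary case, LoadFromBytes pairs naturally with go:embed. A sketch, assuming a model.onnx file is present in the package directory at build time (the file and variable names are illustrative):

```go
import (
	_ "embed"

	"github.com/born-ml/born/backend/cpu"
	"github.com/born-ml/born/onnx"
)

//go:embed model.onnx
var modelBytes []byte

func loadEmbedded() (onnx.Model, error) {
	// No file I/O at runtime: the model travels inside the executable.
	backend := cpu.New()
	return onnx.LoadFromBytes(modelBytes, backend)
}
```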
type ModelInfo ¶
type ModelInfo = internalonnx.ModelInfo
ModelInfo contains metadata about an ONNX model without loading weights.
Use GetModelInfo to quickly inspect a model file before loading.
func GetModelInfo ¶
GetModelInfo extracts metadata from an ONNX file without loading the full model.
This is useful for inspecting model structure, inputs/outputs, and operator requirements before committing to a full load.
Example:
info, err := onnx.GetModelInfo("model.onnx")
if err != nil {
	log.Fatal(err)
}

fmt.Printf("Producer: %s\n", info.ProducerName)
fmt.Printf("Opset: %d\n", info.OpsetVersion)
fmt.Printf("Inputs: %v\n", info.InputNames)
fmt.Printf("Outputs: %v\n", info.OutputNames)
fmt.Printf("Operators: %v\n", info.Operators)
fmt.Printf("Operators: %v\n", info.Operators)