package onnx

Version: v0.7.10 (Latest)
Warning: This package is not in the latest version of its module.
Published: Feb 18, 2026 License: Apache-2.0 Imports: 2 Imported by: 0

Documentation

Overview

Package onnx provides ONNX model import functionality for the Born ML framework.

This package loads ONNX (Open Neural Network Exchange) models exported from PyTorch, TensorFlow, and other ML frameworks, and runs inference on them.

Supported Features

  • ONNX format parsing (protobuf-based)
  • 30+ standard ONNX operators
  • Opset versions 1-21
  • Float32 tensor operations
  • Named input/output support
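The opset range above can be checked before running a model. The sketch below is a hypothetical helper (checkOpset is not part of the onnx package); it assumes the 1-21 range is inclusive and that you pass it the value from Model.OpsetVersion or ModelInfo.OpsetVersion.

```go
package main

import "fmt"

// Bounds of the opset range listed under Supported Features (assumed inclusive).
const (
	minOpset int64 = 1
	maxOpset int64 = 21
)

// checkOpset is a hypothetical helper: it returns an error when a model's
// opset version falls outside the documented supported range.
func checkOpset(version int64) error {
	if version < minOpset || version > maxOpset {
		return fmt.Errorf("opset %d outside supported range %d-%d", version, minOpset, maxOpset)
	}
	return nil
}

func main() {
	for _, v := range []int64{13, 22} {
		if err := checkOpset(v); err != nil {
			fmt.Println("unsupported:", err)
			continue
		}
		fmt.Println("ok:", v)
	}
}
```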

Example Usage

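package main

import (
    "fmt"
    "log"

    "github.com/born-ml/born/backend/cpu"
    "github.com/born-ml/born/onnx"
    "github.com/born-ml/born/tensor"
)

func main() {
    // Load the ONNX model on the CPU backend with default options.
    backend := cpu.New()
    model, err := onnx.Load("model.onnx", backend)
    if err != nil {
        log.Fatal(err)
    }

    // Create a 1x3 float32 input tensor.
    input := tensor.FromSlice([]float32{1.0, 2.0, 3.0}, tensor.Shape{1, 3}, backend)

    // Run inference. Forward requires exactly one model input and one output;
    // use ForwardNamed otherwise.
    output, err := model.Forward(input.Raw())
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println("output:", output)
}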

Supported Operators

The following ONNX operators are supported:

  • Arithmetic: Add, Sub, Mul, Div, Neg, Abs, Sqrt, Exp, Log, Pow
  • Activation: Relu, Sigmoid, Tanh, Softmax, LeakyRelu, Elu, Selu
  • Matrix: MatMul, Gemm, Transpose, Flatten
  • Reduction: ReduceSum, ReduceMean, ReduceMax, ReduceMin
  • Shape: Reshape, Squeeze, Unsqueeze, Concat, Split, Slice, Gather
  • Normalization: BatchNormalization, LayerNormalization
  • Pooling: MaxPool, AveragePool, GlobalAveragePool
  • Convolution: Conv (2D)
  • Other: Constant, Identity, Cast, Clip, Where, Shape

Use ListSupportedOps to get the complete list of supported operators.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func ListSupportedOps

func ListSupportedOps() []string

ListSupportedOps returns a list of all ONNX operators supported by Born.

Example:

ops := onnx.ListSupportedOps()
for _, op := range ops {
    fmt.Println(op)
}

Types

type LoadOptions

type LoadOptions = internalonnx.LoadOptions

LoadOptions configures ONNX model loading behavior.

func DefaultLoadOptions

func DefaultLoadOptions() LoadOptions

DefaultLoadOptions returns the default options for loading ONNX models.

Default configuration:

  • Strict mode: enabled (fails on unsupported operators)
  • Optimization: enabled
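The difference between strict and non-strict loading can be pictured with the sketch below. It models the documented behavior only and is not Born's loader: options and planNodes are hypothetical names, and the skip-and-report behavior in non-strict mode is an assumption based on the Strict comment in the Load example.

```go
package main

import (
	"fmt"
	"strings"
)

// options is a hypothetical stand-in for LoadOptions; only Strict is modeled.
type options struct {
	Strict bool
}

// planNodes walks node operator types in graph order. In strict mode the
// first unsupported operator aborts loading; otherwise unsupported nodes are
// dropped and reported back as skipped.
func planNodes(nodes []string, supported map[string]bool, opts options) (kept, skipped []string, err error) {
	for _, op := range nodes {
		if supported[op] {
			kept = append(kept, op)
			continue
		}
		if opts.Strict {
			return nil, nil, fmt.Errorf("unsupported operator %q", op)
		}
		skipped = append(skipped, op)
	}
	return kept, skipped, nil
}

func main() {
	supported := map[string]bool{"Conv": true, "Relu": true}
	nodes := []string{"Conv", "Resize", "Relu"}

	if _, _, err := planNodes(nodes, supported, options{Strict: true}); err != nil {
		fmt.Println("strict:", err)
	}
	kept, skipped, _ := planNodes(nodes, supported, options{Strict: false})
	fmt.Println("lenient kept:", strings.Join(kept, ","), "skipped:", strings.Join(skipped, ","))
}
```

In this sketch a skipped node simply disappears from the plan, so any later node that consumed its output would have nothing to read.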

type Model

type Model interface {
	// Forward runs inference with a single input tensor.
	// For models with multiple inputs, use ForwardNamed.
	//
	// Returns an error if the model does not have exactly one input
	// or one output. In such cases, use ForwardNamed instead.
	Forward(input *tensor.RawTensor) (*tensor.RawTensor, error)

	// ForwardNamed runs inference with named inputs.
	// Returns a map of output name to tensor.
	//
	// This method supports models with multiple inputs and outputs.
	// All input names from InputNames() must be provided.
	//
	// Example:
	//
	//	inputs := map[string]*tensor.RawTensor{
	//	    "input_ids": inputIDs,
	//	    "attention_mask": attentionMask,
	//	}
	//	outputs, err := model.ForwardNamed(inputs)
	//	if err != nil {
	//	    log.Fatal(err)
	//	}
	//	logits := outputs["logits"]
	ForwardNamed(inputs map[string]*tensor.RawTensor) (map[string]*tensor.RawTensor, error)

	// InputNames returns the names of model inputs.
	InputNames() []string

	// OutputNames returns the names of model outputs.
	OutputNames() []string

	// OpsetVersion returns the ONNX opset version used by the model.
	OpsetVersion() int64

	// Metadata returns model metadata as key-value pairs.
	//
	// Common metadata keys:
	//   - "producer_name": Framework that exported the model (e.g., "pytorch")
	//   - "producer_version": Version of the exporter
	//   - "domain": Domain of the model (usually "")
	//   - Custom keys from model.metadata_props
	Metadata() map[string]string
}

Model represents a loaded ONNX model ready for inference.

This interface hides the internal implementation and allows for:

  • Easy mocking in tests
  • Multiple implementations (e.g., optimized versions)
  • Decoupling from internal package structure

The model contains the computation graph, weights, and metadata from the original ONNX file. Use Forward or ForwardNamed to run inference.
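Because Model is an interface, code that depends on it can be tested with a hand-written fake instead of a real ONNX file. To keep the sketch self-contained, it declares a local RawTensor stand-in and a trimmed interface mirroring three of the real methods; in your own tests you would implement all six methods of onnx.Model against *tensor.RawTensor. fakeModel is a hypothetical name. The fake also enforces the documented ForwardNamed rule that every name from InputNames() must be provided.

```go
package main

import (
	"fmt"
	"sort"
)

// RawTensor is a local stand-in for *tensor.RawTensor so this sketch compiles
// without the Born module; it is not Born's type.
type RawTensor struct {
	Data []float32
}

// Model mirrors a subset of onnx.Model's methods.
type Model interface {
	ForwardNamed(inputs map[string]*RawTensor) (map[string]*RawTensor, error)
	InputNames() []string
	OutputNames() []string
}

// fakeModel returns canned outputs and rejects calls that omit a required input.
type fakeModel struct {
	inputs  []string
	outputs map[string]*RawTensor
}

func (f *fakeModel) InputNames() []string { return f.inputs }

func (f *fakeModel) OutputNames() []string {
	names := make([]string, 0, len(f.outputs))
	for name := range f.outputs {
		names = append(names, name)
	}
	sort.Strings(names) // map iteration order is random; sort for stable results
	return names
}

func (f *fakeModel) ForwardNamed(inputs map[string]*RawTensor) (map[string]*RawTensor, error) {
	for _, name := range f.inputs {
		if _, ok := inputs[name]; !ok {
			return nil, fmt.Errorf("missing input %q", name)
		}
	}
	return f.outputs, nil
}

// Compile-time check that fakeModel satisfies the interface.
var _ Model = (*fakeModel)(nil)

func main() {
	m := &fakeModel{
		inputs:  []string{"input_ids", "attention_mask"},
		outputs: map[string]*RawTensor{"logits": {Data: []float32{0.1, 0.9}}},
	}
	_, err := m.ForwardNamed(map[string]*RawTensor{"input_ids": {}})
	fmt.Println("partial inputs:", err)

	out, err := m.ForwardNamed(map[string]*RawTensor{"input_ids": {}, "attention_mask": {}})
	if err != nil {
		panic(err)
	}
	fmt.Println("logits:", out["logits"].Data)
}
```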

func Load

func Load(path string, backend tensor.Backend, opts ...LoadOptions) (Model, error)

Load loads an ONNX model from a file path.

The function parses the ONNX protobuf format, validates operators, and compiles the computation graph for efficient inference.

Returns a Model interface that can be used for inference. The actual implementation is hidden in the internal package.

Example:

backend := cpu.New()
model, err := onnx.Load("resnet18.onnx", backend)
if err != nil {
    log.Fatal(err)
}

// Get model info
fmt.Println("Inputs:", model.InputNames())
fmt.Println("Outputs:", model.OutputNames())
fmt.Println("Opset:", model.OpsetVersion())

For custom loading options, pass LoadOptions:

opts := onnx.DefaultLoadOptions()
opts.Strict = false // Skip unsupported operators instead of failing
model, err := onnx.Load("model.onnx", backend, opts)
if err != nil {
    log.Fatal(err)
}
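Load's trailing opts ...LoadOptions parameter is the usual Go idiom for an optional configuration value. The sketch below shows one common way such a parameter is resolved; it is an assumption, not Born's code. We have not checked what Load does when more than one LoadOptions value is passed, and the Optimize field name is invented (only Strict appears in this documentation).

```go
package main

import "fmt"

// LoadOptions is a hypothetical stand-in with the two documented settings.
type LoadOptions struct {
	Strict   bool
	Optimize bool // invented field name for the "Optimization" default
}

// DefaultLoadOptions mirrors the documented defaults: both settings enabled.
func DefaultLoadOptions() LoadOptions {
	return LoadOptions{Strict: true, Optimize: true}
}

// resolveOptions handles a trailing variadic options parameter:
// no argument means defaults, otherwise the first value is used.
func resolveOptions(opts ...LoadOptions) LoadOptions {
	if len(opts) == 0 {
		return DefaultLoadOptions()
	}
	return opts[0]
}

func main() {
	fmt.Printf("%+v\n", resolveOptions())

	lenient := DefaultLoadOptions()
	lenient.Strict = false
	fmt.Printf("%+v\n", resolveOptions(lenient))
}
```

Starting from DefaultLoadOptions() and changing one field, as the Load example does, keeps the remaining defaults; in this sketch a zero-value LoadOptions{} would disable both settings.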

func LoadFromBytes

func LoadFromBytes(data []byte, backend tensor.Backend, opts ...LoadOptions) (Model, error)

LoadFromBytes loads an ONNX model from raw bytes.

This is useful when the model is embedded in the binary or loaded from a network source.

Returns a Model interface that can be used for inference.

Example:

modelBytes, err := os.ReadFile("model.onnx")
if err != nil {
    log.Fatal(err)
}
model, err := onnx.LoadFromBytes(modelBytes, backend)
if err != nil {
    log.Fatal(err)
}

type ModelInfo

type ModelInfo = internalonnx.ModelInfo

ModelInfo contains metadata about an ONNX model without loading weights.

Use GetModelInfo to quickly inspect a model file before loading.

func GetModelInfo

func GetModelInfo(path string) (*ModelInfo, error)

GetModelInfo extracts metadata from an ONNX file without loading the full model.

This is useful for inspecting model structure, inputs/outputs, and operator requirements before committing to a full load.

Example:

info, err := onnx.GetModelInfo("model.onnx")
if err != nil {
    log.Fatal(err)
}

fmt.Printf("Producer: %s\n", info.ProducerName)
fmt.Printf("Opset: %d\n", info.OpsetVersion)
fmt.Printf("Inputs: %v\n", info.InputNames)
fmt.Printf("Outputs: %v\n", info.OutputNames)
fmt.Printf("Operators: %v\n", info.Operators)
