tensor

package
v0.7.7
Warning

This package is not in the latest version of its module.

Published: Jan 6, 2026 | License: Apache-2.0 | Imports: 1 | Imported by: 0

Documentation

Overview

Package tensor provides type-safe tensor operations for the Born ML framework.

Overview

Tensors are the fundamental data structure in Born. This package provides:

  • Generic type-safe tensors (Tensor[T, B])
  • NumPy-style broadcasting
  • Zero-copy operations where possible
  • Device abstraction (CPU, WebGPU; CUDA planned)

Basic Usage

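package main

import (
    "github.com/born-ml/born/tensor"
    "github.com/born-ml/born/backend/cpu"
)

func main() {
    backend := cpu.New()

    // Create 2×3 tensors filled with zeros and ones.
    x := tensor.Zeros[float32](tensor.Shape{2, 3}, backend)
    y := tensor.Ones[float32](tensor.Shape{2, 3}, backend)

    // Element-wise addition: (2, 3) + (2, 3) -> (2, 3).
    z := x.Add(y)

    // Matrix multiplication: (2, 3) × (3, 2) -> (2, 2).
    result := x.MatMul(y.Transpose())

    _, _ = z, result // avoid "declared and not used" compile errors
}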

Supported Data Types

The tensor package supports the following data types via the DType constraint:

  • float32, float64 (floating-point)
  • int32, int64 (signed integers)
  • uint8 (unsigned integers, useful for images)
  • bool (boolean masks)

Device Support

Tensors can reside on different devices:

  • CPU: Pure Go implementation (v0.1.0+)
  • WebGPU: Zero-CGO GPU acceleration (v0.2.0+, Windows)
  • CUDA: GPU support (planned for v0.5.0)

Broadcasting

Tensor operations follow NumPy broadcasting rules:

a := tensor.Zeros[float32](tensor.Shape{3, 1}, backend) // (3, 1)
b := tensor.Ones[float32](tensor.Shape{3, 4}, backend)  // (3, 4)
c := a.Add(b)                                           // (3, 4)
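To make the (3, 1) + (3, 4) case concrete, here is a minimal plain-Go sketch of the element-wise rule, independent of Born: the single column of the (3, 1) operand is reused across all four columns of the (3, 4) operand. The helper name `broadcastAdd` and the flat row-major slice layout are assumptions for illustration only, not part of the tensor package.

```go
package main

import "fmt"

// broadcastAdd is a hypothetical helper (not part of Born) that adds a
// (rows, 1) column a to a (rows, cols) matrix b, both stored as flat
// row-major slices. It follows the NumPy rule for this case: a has size 1
// along the column axis, so a[i] is reused for every column j.
func broadcastAdd(a, b []float32, rows, cols int) []float32 {
	out := make([]float32, rows*cols)
	for i := 0; i < rows; i++ {
		for j := 0; j < cols; j++ {
			out[i*cols+j] = a[i] + b[i*cols+j]
		}
	}
	return out
}

func main() {
	a := []float32{10, 20, 30} // shape (3, 1)
	b := make([]float32, 3*4)  // shape (3, 4)
	for k := range b {
		b[k] = float32(k % 4) // every row of b is [0, 1, 2, 3]
	}
	fmt.Println(broadcastAdd(a, b, 3, 4)) // [10 11 12 13 20 21 22 23 30 31 32 33]
}
```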

Memory Management

Tensors use zero-copy operations where possible. The underlying data is reference-counted and automatically freed when no longer needed.

Available Operations (v0.3.0+)

Tensor[T, B] provides 31 type-safe operations:

Scalar operations:

y := x.MulScalar(2.0)    // Multiply by scalar
y := x.AddScalar(1.0)    // Add scalar
y := x.SubScalar(0.5)    // Subtract scalar
y := x.DivScalar(2.0)    // Divide by scalar

Math operations:

y := x.Exp()             // Exponential
y := x.Log()             // Natural logarithm
y := x.Sqrt()            // Square root
y := x.Rsqrt()           // Reciprocal square root
y := x.Cos()             // Cosine
y := x.Sin()             // Sine

Comparison operations (return Tensor[bool, B]):

mask := x.Greater(y)     // or x.Gt(y)
mask := x.Lower(y)       // or x.Lt(y)
mask := x.Equal(y)       // or x.Eq(y)

Type conversion:

i := x.Int32()           // Convert to int32
f := x.Float64()         // Convert to float64

See the method documentation for the full list of operations.

Package tensor provides the public API for tensor operations in the Born ML framework.

The package defines core interfaces and types for type-safe tensor operations:

  • Tensor[T, B]: High-level generic tensor with type safety
  • RawTensor: Low-level tensor interface for advanced use cases
  • Backend: Interface for device-specific compute implementations
  • Shape, DataType, Device: Core type definitions

Example:

backend := cpu.New()
x := tensor.Zeros[float32](tensor.Shape{2, 3}, backend)
y := tensor.Ones[float32](tensor.Shape{2, 3}, backend)
z := x.Add(y)  // Element-wise addition

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Backend

type Backend interface {
	// Element-wise binary operations.
	Add(a, b *RawTensor) *RawTensor // Element-wise addition.
	Sub(a, b *RawTensor) *RawTensor // Element-wise subtraction.
	Mul(a, b *RawTensor) *RawTensor // Element-wise multiplication.
	Div(a, b *RawTensor) *RawTensor // Element-wise division.

	// Matrix operations.
	MatMul(a, b *RawTensor) *RawTensor      // Matrix multiplication.
	BatchMatMul(a, b *RawTensor) *RawTensor // Batched matrix multiplication for 3D/4D tensors.

	// Convolutional operations.
	Conv2D(input, kernel *RawTensor, stride, padding int) *RawTensor                               // 2D convolution.
	MaxPool2D(input *RawTensor, kernelSize, stride int) *RawTensor                                 // 2D max pooling.
	Conv2DInputBackward(input, kernel, grad *RawTensor, stride, padding int) *RawTensor            // Conv2D input gradient.
	Conv2DKernelBackward(input, kernel, grad *RawTensor, stride, padding int) *RawTensor           // Conv2D kernel gradient.
	MaxPool2DBackward(input, grad *RawTensor, maxIndices []int, kernelSize, stride int) *RawTensor // MaxPool2D gradient.

	// Shape operations.
	Reshape(t *RawTensor, newShape Shape) *RawTensor // Reshape tensor.
	Transpose(t *RawTensor, axes ...int) *RawTensor  // Transpose dimensions.

	// Scalar operations (element-wise with scalar).
	MulScalar(x *RawTensor, scalar any) *RawTensor // Multiply by scalar.
	AddScalar(x *RawTensor, scalar any) *RawTensor // Add scalar.
	SubScalar(x *RawTensor, scalar any) *RawTensor // Subtract scalar.
	DivScalar(x *RawTensor, scalar any) *RawTensor // Divide by scalar.

	// Math operations (element-wise).
	Exp(x *RawTensor) *RawTensor   // Exponential.
	Log(x *RawTensor) *RawTensor   // Natural logarithm.
	Sqrt(x *RawTensor) *RawTensor  // Square root.
	Rsqrt(x *RawTensor) *RawTensor // Reciprocal square root (1/sqrt(x)).
	Cos(x *RawTensor) *RawTensor   // Cosine.
	Sin(x *RawTensor) *RawTensor   // Sine.

	// Activation functions.
	Softmax(x *RawTensor, dim int) *RawTensor // Softmax along dimension.

	// Comparison operations (element-wise, return bool tensor).
	Greater(a, b *RawTensor) *RawTensor      // a > b.
	Lower(a, b *RawTensor) *RawTensor        // a < b.
	GreaterEqual(a, b *RawTensor) *RawTensor // a >= b.
	LowerEqual(a, b *RawTensor) *RawTensor   // a <= b.
	Equal(a, b *RawTensor) *RawTensor        // a == b.
	NotEqual(a, b *RawTensor) *RawTensor     // a != b.

	// Boolean operations (element-wise on bool tensors).
	Or(a, b *RawTensor) *RawTensor  // Logical OR.
	And(a, b *RawTensor) *RawTensor // Logical AND.
	Not(x *RawTensor) *RawTensor    // Logical NOT.

	// Reduction operations.
	Sum(x *RawTensor) *RawTensor                            // Total sum (scalar result).
	SumDim(x *RawTensor, dim int, keepDim bool) *RawTensor  // Sum along dimension.
	MeanDim(x *RawTensor, dim int, keepDim bool) *RawTensor // Mean along dimension.
	Argmax(x *RawTensor, dim int) *RawTensor                // Index of maximum value along dimension.

	// Manipulation operations.
	Cat(tensors []*RawTensor, dim int) *RawTensor // Concatenate along dimension.
	Chunk(x *RawTensor, n, dim int) []*RawTensor  // Split into n equal parts.
	Unsqueeze(x *RawTensor, dim int) *RawTensor   // Add dimension of size 1.
	Squeeze(x *RawTensor, dim int) *RawTensor     // Remove dimension of size 1.

	// Indexing operations.
	Gather(x *RawTensor, dim int, index *RawTensor) *RawTensor // Select elements along dim using index tensor.
	Where(condition, x, y *RawTensor) *RawTensor               // Conditional element selection.
	Embedding(weight, indices *RawTensor) *RawTensor           // Lookup embeddings by indices.

	// Shape operations (broadcast).
	Expand(x *RawTensor, shape Shape) *RawTensor // Broadcast to shape.

	// Type conversion.
	Cast(x *RawTensor, dtype DataType) *RawTensor // Cast to different data type.

	// Metadata.
	Name() string   // Backend name (e.g., "CPU", "WebGPU").
	Device() Device // Device type.
}

Backend defines the interface that all compute backends must implement. Backends handle the actual computation for tensor operations.

Implementations:

  • backend/cpu: Pure Go with SIMD optimizations
  • backend/webgpu: Cross-platform GPU compute via WebGPU
  • backend/cuda: NVIDIA GPU via CUDA (planned)
  • backend/vulkan: Cross-platform GPU via Vulkan (planned)
  • backend/metal: Apple GPU via Metal (planned)

Decorator backends for additional functionality:

  • autodiff: Automatic differentiation (wraps any backend)
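The decorator idea can be illustrated with a small sketch. The `miniBackend` interface below is a hypothetical two-method reduction of Backend that works on plain slices instead of *RawTensor; the wrapper records each call and forwards the computation to the backend it wraps. This is the general shape a wrapping backend such as autodiff could take (for example, by recording a tape for later gradient computation), not Born's actual autodiff implementation.

```go
package main

import "fmt"

// miniBackend is a hypothetical, heavily reduced stand-in for Backend.
type miniBackend interface {
	Add(a, b []float32) []float32
	Name() string
}

// cpuBackend performs the actual computation.
type cpuBackend struct{}

func (cpuBackend) Add(a, b []float32) []float32 {
	out := make([]float32, len(a))
	for i := range a {
		out[i] = a[i] + b[i]
	}
	return out
}

func (cpuBackend) Name() string { return "CPU" }

// tracingBackend is a decorator: it satisfies the same interface, records
// the name of every operation it receives, and delegates the work.
type tracingBackend struct {
	inner miniBackend
	tape  []string
}

func (t *tracingBackend) Add(a, b []float32) []float32 {
	t.tape = append(t.tape, "Add")
	return t.inner.Add(a, b)
}

func (t *tracingBackend) Name() string { return "Tracing(" + t.inner.Name() + ")" }

func main() {
	tb := &tracingBackend{inner: cpuBackend{}}
	var b miniBackend = tb // the wrapper is usable wherever the interface is
	fmt.Println(b.Add([]float32{1, 2}, []float32{3, 4}), b.Name(), tb.tape)
}
```

Because the wrapper and the wrapped backend share one interface, decorators can be stacked without callers noticing.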

Example:

import (
    "github.com/born-ml/born/tensor"
    "github.com/born-ml/born/backend/cpu"
)

backend := cpu.New()
x := tensor.Zeros[float32](tensor.Shape{2, 3}, backend)
y := tensor.Ones[float32](tensor.Shape{2, 3}, backend)
z := x.Add(y)  // Uses backend.Add under the hood

type DType

type DType = tensor.DType

DType is a constraint for tensor data types. Supported types: float32, float64, int32, int64, uint8, bool.
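A type-set constraint of this kind can be sketched in plain Go. The `dtype` constraint and `fill` function below are hypothetical and only list the same six element types; Born's actual definition may differ (for example, it could use `~T` approximation elements). One consequence worth noting: because `bool` is in the set, a generic function over such a constraint cannot use arithmetic like `v + v`.

```go
package main

import "fmt"

// dtype is a hypothetical constraint with the same element types as DType.
type dtype interface {
	float32 | float64 | int32 | int64 | uint8 | bool
}

// fill returns n copies of v. Only operations supported by every type in
// the set are allowed here; assignment is fine, but + would not compile
// because bool does not support it.
func fill[T dtype](n int, v T) []T {
	out := make([]T, n)
	for i := range out {
		out[i] = v
	}
	return out
}

func main() {
	fmt.Println(fill[float32](3, 1.5), fill(2, true), fill[uint8](2, 255))
}
```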

type DataType

type DataType = tensor.DataType

DataType represents the underlying data type of a tensor.

const (
	Float32 DataType = tensor.Float32
	Float64 DataType = tensor.Float64
	Int32   DataType = tensor.Int32
	Int64   DataType = tensor.Int64
	Uint8   DataType = tensor.Uint8
	Bool    DataType = tensor.Bool
)

Data type constants.
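The generic type parameter T and the runtime DataType tag describe the same thing at two levels. A minimal sketch of bridging them uses a type switch on the zero value of T. The local `DataType` type, its constant values, and `dataTypeOf` are hypothetical stand-ins, not Born's definitions.

```go
package main

import "fmt"

// DataType is a hypothetical stand-in for tensor.DataType; Born's real
// constant values are not assumed.
type DataType int

const (
	Float32 DataType = iota
	Float64
	Int32
	Int64
	Uint8
	Bool
)

func (d DataType) String() string {
	return [...]string{"float32", "float64", "int32", "int64", "uint8", "bool"}[d]
}

// dataTypeOf maps a Go element type to its runtime tag by switching on the
// dynamic type of T's zero value.
func dataTypeOf[T float32 | float64 | int32 | int64 | uint8 | bool]() DataType {
	var zero T
	switch any(zero).(type) {
	case float32:
		return Float32
	case float64:
		return Float64
	case int32:
		return Int32
	case int64:
		return Int64
	case uint8:
		return Uint8
	default: // bool is the only type left in the constraint
		return Bool
	}
}

func main() {
	fmt.Println(dataTypeOf[float32](), dataTypeOf[uint8](), dataTypeOf[bool]())
}
```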

type Device

type Device = tensor.Device

Device represents the device where tensor data resides.

const (
	CPU    Device = tensor.CPU
	CUDA   Device = tensor.CUDA
	Vulkan Device = tensor.Vulkan
	Metal  Device = tensor.Metal
	WebGPU Device = tensor.WebGPU
)

Device constants.

type RawTensor

type RawTensor = tensor.RawTensor

RawTensor is the low-level tensor representation.

RawTensor provides:

  • Shape and type information via Shape(), DType(), Device()
  • Type-safe data access via AsFloat32(), AsInt64(), etc.
  • Copy-on-Write semantics via Clone()
  • Lazy GPU evaluation support via IsLazy()
  • Reference counting for efficient memory management

Most users should use the high-level Tensor[T, B] type instead.

Example:

raw, _ := tensor.NewRaw(tensor.Shape{2, 3}, tensor.Float32, tensor.CPU)
data := raw.AsFloat32()  // Type-safe access
clone := raw.Clone()     // Shares buffer via reference counting
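The copy-on-write behaviour described for Clone() can be sketched with a reference-counted buffer. Everything below (`buffer`, `cowTensor`, `Set`) is hypothetical and only illustrates the technique, not RawTensor's internals. In Go the garbage collector reclaims memory; here the count only decides whether a write must copy first. The sketch is not safe for concurrent writes through the same handle.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// buffer is shared storage plus a count of handles referring to it.
type buffer struct {
	data []float32
	refs atomic.Int32
}

// cowTensor is a hypothetical handle that may share its buffer with clones.
type cowTensor struct {
	buf *buffer
}

func newCOW(data []float32) *cowTensor {
	b := &buffer{data: data}
	b.refs.Store(1)
	return &cowTensor{buf: b}
}

// Clone shares the buffer and increments the count; no data is copied.
func (t *cowTensor) Clone() *cowTensor {
	t.buf.refs.Add(1)
	return &cowTensor{buf: t.buf}
}

// Set writes one element. If the buffer is shared, it first takes a private
// copy and releases its share of the old buffer.
func (t *cowTensor) Set(i int, v float32) {
	if t.buf.refs.Load() > 1 {
		private := &buffer{data: append([]float32(nil), t.buf.data...)}
		private.refs.Store(1)
		t.buf.refs.Add(-1)
		t.buf = private
	}
	t.buf.data[i] = v
}

func main() {
	a := newCOW([]float32{1, 2, 3})
	b := a.Clone() // shares a's buffer
	b.Set(0, 99)   // b copies before writing, so a is unchanged
	fmt.Println(a.buf.data, b.buf.data) // [1 2 3] [99 2 3]
}
```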

func NewRaw

func NewRaw(shape Shape, dtype DataType, device Device) (*RawTensor, error)

NewRaw creates a new raw tensor with the given shape, dtype, and device.

This is a low-level function. Most users should use high-level creation functions such as Zeros, Ones, or FromSlice instead.

type Shape

type Shape = tensor.Shape

Shape represents the dimensions of a tensor. Example: Shape{2, 3, 4} represents a 3D tensor with dimensions 2×3×4.
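A shape determines how many elements a tensor holds and, for a given memory layout, how a multi-dimensional index maps to a flat offset. The sketch below assumes row-major (C-order) layout, which the documentation does not state for Born; `numElements` and `rowMajorStrides` are hypothetical helpers on a local `Shape` type.

```go
package main

import "fmt"

// Shape is a local stand-in for tensor.Shape: a list of dimension sizes.
type Shape []int

// numElements returns the product of all dimensions (1 for a scalar shape).
func (s Shape) numElements() int {
	n := 1
	for _, d := range s {
		n *= d
	}
	return n
}

// rowMajorStrides returns, per axis, how many flat elements to skip to move
// one step along that axis. The last axis is contiguous (stride 1).
func (s Shape) rowMajorStrides() []int {
	strides := make([]int, len(s))
	acc := 1
	for i := len(s) - 1; i >= 0; i-- {
		strides[i] = acc
		acc *= s[i]
	}
	return strides
}

func main() {
	s := Shape{2, 3, 4}
	fmt.Println(s.numElements(), s.rowMajorStrides()) // 24 [12 4 1]
}
```

Under this assumption, element (i, j, k) of a 2×3×4 tensor sits at flat offset 12i + 4j + k.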

func BroadcastShapes

func BroadcastShapes(a, b Shape) (Shape, bool, error)

BroadcastShapes computes the broadcast shape of two shapes following NumPy broadcasting rules. It returns the resulting shape, a boolean flag indicating whether broadcasting is needed, and an error if the shapes are incompatible.

Example:

resultShape, needsBroadcast, err := tensor.BroadcastShapes(
    tensor.Shape{3, 1},
    tensor.Shape{3, 4},
)
if err != nil {
    panic(err) // shapes are not broadcast-compatible
}
// resultShape = [3, 4], needsBroadcast = true
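The NumPy rule itself is short: align both shapes from the right, treat missing leading dimensions as 1, and require each aligned pair to be equal or to contain a 1. The sketch below is a hypothetical re-implementation for illustration, not Born's code; in particular, the meaning given to the bool result here (the inputs differ, so some broadcasting happens) is an assumption.

```go
package main

import "fmt"

// broadcastShapes is a hypothetical sketch of NumPy-style shape broadcasting.
func broadcastShapes(a, b []int) ([]int, bool, error) {
	n := len(a)
	if len(b) > n {
		n = len(b)
	}
	out := make([]int, n)
	needs := len(a) != len(b)
	for i := 1; i <= n; i++ { // walk both shapes from the rightmost axis
		da, db := 1, 1 // missing leading dimensions behave like size 1
		if i <= len(a) {
			da = a[len(a)-i]
		}
		if i <= len(b) {
			db = b[len(b)-i]
		}
		switch {
		case da == db:
			out[n-i] = da
		case da == 1:
			out[n-i], needs = db, true
		case db == 1:
			out[n-i], needs = da, true
		default:
			return nil, false, fmt.Errorf("incompatible shapes %v and %v at axis %d", a, b, n-i)
		}
	}
	return out, needs, nil
}

func main() {
	shape, needs, err := broadcastShapes([]int{3, 1}, []int{3, 4})
	fmt.Println(shape, needs, err) // [3 4] true <nil>
}
```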

type Tensor

type Tensor[T DType, B Backend] = tensor.Tensor[T, B]

Tensor is a generic type-safe tensor.

T is the data type (float32, float64, int32, int64, uint8, bool). B is the backend implementation (CPU, WebGPU, etc.).

Tensor provides a high-level API for tensor operations with:

  • Type safety via Go generics
  • Automatic differentiation support (via autodiff.Backend)
  • Multiple backend support (CPU, GPU)
  • Efficient memory management with copy-on-write

Example:

backend := cpu.New()
x := tensor.Zeros[float32](tensor.Shape{2, 3}, backend)
y := tensor.Ones[float32](tensor.Shape{2, 3}, backend)
z := x.Add(y)  // Element-wise addition

func Arange

func Arange[T DType, B Backend](start, end T, b B) *Tensor[T, B]

Arange creates a 1D tensor with values from start (inclusive) to end (exclusive) in steps of 1.

Example:

backend := cpu.New()
x := tensor.Arange[float32](0, 10, backend)  // [0, 1, 2, ..., 9]

func Cat added in v0.3.0

func Cat[T DType, B Backend](tensors []*Tensor[T, B], dim int) *Tensor[T, B]

Cat concatenates tensors along a dimension.

Example:

backend := cpu.New()
a := tensor.Ones[float32](tensor.Shape{2, 3}, backend)
b := tensor.Zeros[float32](tensor.Shape{2, 3}, backend)
// B stands for the concrete backend type returned by cpu.New().
c := tensor.Cat([]*tensor.Tensor[float32, B]{a, b}, 0)  // Shape: [4, 3]
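The shape rule behind concatenation is that every dimension except dim must match, and the output size along dim is the sum of the input sizes. Below is a minimal sketch for 2D row-major data; `cat2D` is a hypothetical helper that only handles two inputs and dims 0 and 1, and is not how Born implements Cat.

```go
package main

import "fmt"

// cat2D concatenates two row-major 2D matrices (flat slices) along dim 0
// (stack rows) or dim 1 (join columns), returning data, rows, and cols.
func cat2D(a []float32, aRows, aCols int, b []float32, bRows, bCols int, dim int) ([]float32, int, int, error) {
	switch dim {
	case 0:
		if aCols != bCols {
			return nil, 0, 0, fmt.Errorf("column mismatch: %d vs %d", aCols, bCols)
		}
		out := append(append([]float32{}, a...), b...) // rows of a, then rows of b
		return out, aRows + bRows, aCols, nil
	case 1:
		if aRows != bRows {
			return nil, 0, 0, fmt.Errorf("row mismatch: %d vs %d", aRows, bRows)
		}
		out := make([]float32, 0, aRows*(aCols+bCols))
		for r := 0; r < aRows; r++ { // interleave row r of a with row r of b
			out = append(out, a[r*aCols:(r+1)*aCols]...)
			out = append(out, b[r*bCols:(r+1)*bCols]...)
		}
		return out, aRows, aCols + bCols, nil
	default:
		return nil, 0, 0, fmt.Errorf("dim %d out of range for 2D input", dim)
	}
}

func main() {
	ones := []float32{1, 1, 1, 1, 1, 1}  // shape (2, 3)
	zeros := []float32{0, 0, 0, 0, 0, 0} // shape (2, 3)
	_, rows, cols, _ := cat2D(ones, 2, 3, zeros, 2, 3, 0)
	fmt.Println(rows, cols) // 4 3
}
```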

func Eye

func Eye[T DType, B Backend](n int, b B) *Tensor[T, B]

Eye creates an n×n identity matrix.

Example:

backend := cpu.New()
identity := tensor.Eye[float32](3, backend)  // 3x3 identity matrix

func FromSlice

func FromSlice[T DType, B Backend](data []T, shape Shape, b B) (*Tensor[T, B], error)

FromSlice creates a tensor from a Go slice.

Example:

backend := cpu.New()
data := []float32{1, 2, 3, 4, 5, 6}
x, err := tensor.FromSlice(data, tensor.Shape{2, 3}, backend)
if err != nil {
    panic(err)
}

func Full

func Full[T DType, B Backend](shape Shape, value T, b B) *Tensor[T, B]

Full creates a tensor filled with a specific value.

Example:

backend := cpu.New()
x := tensor.Full[float32](tensor.Shape{2, 3}, 3.14, backend)

func New

func New[T DType, B Backend](raw *RawTensor, b B) *Tensor[T, B]

New creates a tensor from a raw tensor.

This is a low-level function. Most users should use creation functions like Zeros, Ones, or FromSlice instead.

func Ones

func Ones[T DType, B Backend](shape Shape, b B) *Tensor[T, B]

Ones creates a tensor filled with ones.

Example:

backend := cpu.New()
x := tensor.Ones[float32](tensor.Shape{2, 3}, backend)

func Rand

func Rand[T DType, B Backend](shape Shape, b B) *Tensor[T, B]

Rand creates a tensor filled with random values drawn from the uniform distribution U(0, 1).

Example:

backend := cpu.New()
x := tensor.Rand[float32](tensor.Shape{2, 3}, backend)

func Randn

func Randn[T DType, B Backend](shape Shape, b B) *Tensor[T, B]

Randn creates a tensor filled with random values drawn from the standard normal distribution N(0, 1).

Example:

backend := cpu.New()
x := tensor.Randn[float32](tensor.Shape{2, 3}, backend)

func Where added in v0.3.0

func Where[T DType, B Backend](cond *Tensor[bool, B], x, y *Tensor[T, B]) *Tensor[T, B]

Where selects elements element-wise: from x where cond is true, and from y where it is false.

Example:

backend := cpu.New()
cond := tensor.Full[bool](tensor.Shape{3}, true, backend)
x := tensor.Full[float32](tensor.Shape{3}, 1.0, backend)
y := tensor.Full[float32](tensor.Shape{3}, 0.0, backend)
result := tensor.Where(cond, x, y)  // [1.0, 1.0, 1.0]
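Where pairs naturally with the comparison operations, which produce boolean masks. The sketch below shows the element-wise semantics on plain slices and uses a "greater than zero" mask to keep positive values, the same pattern as a ReLU. The generic `where` function is hypothetical and does not handle broadcasting.

```go
package main

import "fmt"

// where returns out[i] = x[i] when cond[i] is true, otherwise y[i].
// All three slices are assumed to have the same length.
func where[T any](cond []bool, x, y []T) []T {
	out := make([]T, len(cond))
	for i, c := range cond {
		if c {
			out[i] = x[i]
		} else {
			out[i] = y[i]
		}
	}
	return out
}

func main() {
	x := []float32{-2, 0.5, 3}
	zero := []float32{0, 0, 0}
	mask := make([]bool, len(x)) // element-wise x > 0
	for i := range x {
		mask[i] = x[i] > zero[i]
	}
	fmt.Println(where(mask, x, zero)) // [0 0.5 3]
}
```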

func Zeros

func Zeros[T DType, B Backend](shape Shape, b B) *Tensor[T, B]

Zeros creates a tensor filled with zeros.

Example:

backend := cpu.New()
x := tensor.Zeros[float32](tensor.Shape{2, 3}, backend)
