tensor

package
v0.1.0 Latest
Warning

This package is not in the latest version of its module.

Published: Aug 4, 2025 License: Apache-2.0 Imports: 4 Imported by: 2

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func BroadcastIndex

func BroadcastIndex(index int, shape, outputShape []int, broadcast bool) int

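BroadcastIndex maps index, a flat index into the output of a broadcasting operation (whose shape is outputShape), to the corresponding flat index into an input tensor with the given shape.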
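A minimal sketch of the kind of mapping this performs, assuming row-major layout and NumPy-style broadcasting (shapes aligned from the right, size-1 input dimensions repeated). The function broadcastIndexSketch is hypothetical, not the package's implementation, and it omits the broadcast flag because the docs do not spell out its role.

```go
package main

import "fmt"

// broadcastIndexSketch is a hypothetical helper, not the package's BroadcastIndex.
// It converts a flat row-major index into a tensor of shape outShape into the
// flat index of the element read from an input tensor of shape inShape,
// assuming NumPy-style broadcasting rules.
func broadcastIndexSketch(index int, inShape, outShape []int) int {
	lead := len(outShape) - len(inShape) // leading output dims absent from the input
	inIndex, inStride := 0, 1
	// Peel off output coordinates from the last dimension to the first.
	for d := len(outShape) - 1; d >= 0; d-- {
		coord := index % outShape[d]
		index /= outShape[d]
		i := d - lead
		if i < 0 {
			continue // no matching input dimension: the input is repeated along it
		}
		if inShape[i] != 1 {
			inIndex += coord * inStride // size-1 dims always read coordinate 0
		}
		inStride *= inShape[i]
	}
	return inIndex
}

func main() {
	// A [3] vector broadcast across a [2 3] output: each row reads the same values.
	for i := 0; i < 6; i++ {
		fmt.Print(broadcastIndexSketch(i, []int{3}, []int{2, 3}), " ")
	}
	fmt.Println()
}
```

Running it prints 0 1 2 0 1 2, showing that both output rows read the same three input elements.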

func BroadcastShapes

func BroadcastShapes(a, b []int) (shape []int, broadcastA, broadcastB bool, err error)

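BroadcastShapes computes the output shape of a broadcasting operation between tensors with shapes a and b. It also reports, through broadcastA and broadcastB, whether each input must be broadcast, and returns a non-nil err if the shapes are incompatible.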

func SameShape

func SameShape[T Numeric](a, b *Tensor[T]) bool

SameShape checks if two tensors have the same shape.

Types

type Numeric

type Numeric interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		float8.Float8 |
		float16.Float16
}

Numeric defines the constraint for numeric types that can be used in Tensors.
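The ~ prefix in Numeric admits any type whose underlying type is listed, not only the listed types themselves. The sketch below illustrates this with a hypothetical local constraint, number, which copies Numeric without the float8.Float8 and float16.Float16 members (their import paths are not shown on this page), and a hypothetical generic sum.

```go
package main

import "fmt"

// number is a hypothetical, trimmed copy of Numeric without the float8 and
// float16 members, so this example needs no external packages.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint32 | ~uint64 |
		~float32 | ~float64
}

// Celsius has underlying type float64, so the ~float64 term admits it.
type Celsius float64

// sum adds up a slice of any type in the constraint set.
func sum[T number](xs []T) T {
	var total T
	for _, x := range xs {
		total += x
	}
	return total
}

func main() {
	fmt.Println(sum([]int{1, 2, 3}))       // 6
	fmt.Println(sum([]Celsius{20.5, 1.5})) // 22
}
```

Because ~uint8 and ~uint16 do not appear in the listed set, calling sum with a []byte would not compile; the same holds for Numeric as listed above.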

type Tensor

type Tensor[T Numeric] struct {
	// contains filtered or unexported fields
}

Tensor represents an n-dimensional array of a generic numeric type T.

func New

func New[T Numeric](shape []int, data []T) (*Tensor[T], error)

New creates a Tensor with the given shape, initialized with the provided data. If data is nil, New allocates a new slice of the appropriate size; otherwise, len(data) must equal the total number of elements implied by shape (the product of its dimensions).

func (*Tensor[T]) At

func (t *Tensor[T]) At(indices ...int) (T, error)

At retrieves the value at the specified indices. It returns an error if the number of indices does not match the tensor's dimensions or if any index is out of bounds.
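A minimal sketch of how a multi-dimensional lookup like At can be resolved to a position in the flat data slice: validate the index count and bounds (the two error cases At documents), then take the dot product of the indices with the strides. offsetOf is hypothetical; rejecting negative indices and ignoring any base offset a view might carry are assumptions of this sketch.

```go
package main

import "fmt"

// offsetOf is a hypothetical helper that turns indices into a flat offset,
// returning an error for a wrong index count or an out-of-bounds index.
func offsetOf(shape, strides []int, indices ...int) (int, error) {
	if len(indices) != len(shape) {
		return 0, fmt.Errorf("got %d indices for a %d-dimensional tensor", len(indices), len(shape))
	}
	off := 0
	for i, idx := range indices {
		if idx < 0 || idx >= shape[i] {
			return 0, fmt.Errorf("index %d out of bounds for dimension %d of size %d", idx, i, shape[i])
		}
		off += idx * strides[i]
	}
	return off, nil
}

func main() {
	data := []float64{0, 1, 2, 3, 4, 5} // a 2x3 matrix stored row-major
	shape, strides := []int{2, 3}, []int{3, 1}
	off, _ := offsetOf(shape, strides, 1, 2)
	fmt.Println(data[off]) // 5
	_, err := offsetOf(shape, strides, 2, 0)
	fmt.Println(err)
}
```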

func (*Tensor[T]) Copy

func (t *Tensor[T]) Copy() *Tensor[T]

Copy creates a deep copy of the tensor.

func (*Tensor[T]) Data

func (t *Tensor[T]) Data() []T

Data returns a slice representing the underlying data of the tensor.

func (*Tensor[T]) Dims

func (t *Tensor[T]) Dims() int

Dims returns the number of dimensions of the tensor.

func (*Tensor[T]) Each

func (t *Tensor[T]) Each(f func(val T))

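Each calls f with the value of every element of the tensor. It traverses the elements using the tensor's shape and strides rather than the raw data slice, which makes it suitable for operations that need to read every value, including those of strided views.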

func (*Tensor[T]) Reshape

func (t *Tensor[T]) Reshape(newShape []int) (*Tensor[T], error)

Reshape returns a new Tensor with a different shape that shares the same underlying data. The new shape must have the same total number of elements as the original tensor. This operation is a "view" and does not copy the data.

func (*Tensor[T]) Set

func (t *Tensor[T]) Set(value T, indices ...int) error

Set updates the value at the specified indices. It returns an error if the number of indices does not match the tensor's dimensions, if any index is out of bounds, or if the tensor is a read-only view.

func (*Tensor[T]) SetData

func (t *Tensor[T]) SetData(data []T)

SetData sets the underlying data of the tensor.

func (*Tensor[T]) SetShape

func (t *Tensor[T]) SetShape(shape []int)

SetShape sets the tensor's shape.

func (*Tensor[T]) SetStrides

func (t *Tensor[T]) SetStrides(strides []int)

SetStrides sets the tensor's strides.

func (*Tensor[T]) Shape

func (t *Tensor[T]) Shape() []int

Shape returns a copy of the tensor's shape.

func (*Tensor[T]) ShapeEquals

func (t *Tensor[T]) ShapeEquals(other *Tensor[T]) bool

ShapeEquals returns true if the shapes of two tensors are identical.

func (*Tensor[T]) Size

func (t *Tensor[T]) Size() int

Size returns the total number of elements in the tensor.

func (*Tensor[T]) Slice

func (t *Tensor[T]) Slice(ranges ...[2]int) (*Tensor[T], error)

Slice returns a new Tensor view over the specified ranges, one per dimension, where each range is a [2]int pair holding a start and an end index. The returned tensor shares the same underlying data as t, so no elements are copied.

func (*Tensor[T]) Strides

func (t *Tensor[T]) Strides() []int

Strides returns a copy of the tensor's strides.

func (*Tensor[T]) String

func (t *Tensor[T]) String() string

String returns a human-readable representation of the tensor.
