tensor

package
v0.0.0-...-5b57bab Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Sep 28, 2025 License: AGPL-3.0 Imports: 16 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func EnableTensorTrace

func EnableTensorTrace()

func EndTensorTrace

func EndTensorTrace()

func SetDefaultDevice

func SetDefaultDevice(device Device)

func ShapeEqual

func ShapeEqual(shape1, shape2 []int) bool

func ShapeSum

func ShapeSum(shape []int) (result int)

Types

type CpuMemoryDevice

type CpuMemoryDevice struct {
	Type DeviceType
	Num  DeviceNum
}

func (*CpuMemoryDevice) GetDeviceType

func (d *CpuMemoryDevice) GetDeviceType() DeviceType

func (*CpuMemoryDevice) GetIndex

func (d *CpuMemoryDevice) GetIndex() DeviceNum

type DataTrace

type DataTrace struct {
	TensorData []*Tensor
	// contains filtered or unexported fields
}

func GetDataTrace

func GetDataTrace() *DataTrace

func (*DataTrace) AppendTensorData

func (m *DataTrace) AppendTensorData(tensor *Tensor)

func (*DataTrace) LoadDataTraceLog

func (m *DataTrace) LoadDataTraceLog()

func (*DataTrace) WriteDataTraceLog

func (m *DataTrace) WriteDataTraceLog()

type Device

type Device interface {
	GetDeviceType() DeviceType
	GetIndex() DeviceNum
}

func GetCpuMemoryDevice

func GetCpuMemoryDevice() Device

func GetDefaultDevice

func GetDefaultDevice() Device

type DeviceNum

type DeviceNum int
const (
	DefaultDeviceNum DeviceNum = 1
)

type DeviceType

type DeviceType int
const (
	DeviceCPU DeviceType = 1 + iota
	DeviceGPU
)

type MemoryPool

type MemoryPool struct {
	// contains filtered or unexported fields
}

func NewFloat32Pool

func NewFloat32Pool() *MemoryPool

func (*MemoryPool) Clear

func (p *MemoryPool) Clear()

func (*MemoryPool) Delete

func (p *MemoryPool) Delete(key *Tensor) []float32

func (*MemoryPool) Get

func (p *MemoryPool) Get(key *Tensor) ([]float32, bool)

func (*MemoryPool) Put

func (p *MemoryPool) Put(key *Tensor, data []float32)

type Tensor

type Tensor struct {
	Data []float32

	Device Device

	//Backward
	Grad         []float32
	RequiresGrad bool
	Parents      []*Tensor
	GradFn       func()
	IsLeaf       bool
	// contains filtered or unexported fields
}

func Add

func Add(a, b *Tensor) *Tensor

func Copy

func Copy(t *Tensor) *Tensor

func HadamardProduct

func HadamardProduct(a, b *Tensor) *Tensor

func Identity

func Identity() *Tensor

func LoadFromCSV

func LoadFromCSV(filename string) (*Tensor, error)

func LoadTensorFromGobFile

func LoadTensorFromGobFile(filename string) (*Tensor, error)

func LookAt

func LookAt(eye, center, up *Tensor) *Tensor

func Multiply

func Multiply(a, b *Tensor) *Tensor

func NewEmptyTensor

func NewEmptyTensor() *Tensor

func NewRandomTensor

func NewRandomTensor(shape []int) *Tensor

func NewTensor

func NewTensor(data []float32, shape []int) *Tensor

func NewTensorFromSlice

func NewTensorFromSlice(data [][]float32) *Tensor

func NewTensorWithShape

func NewTensorWithShape(shape []int) *Tensor

func NewVec3

func NewVec3(x, y, z float32) *Tensor

func Ones

func Ones(shape []int) *Tensor

func Perspective

func Perspective(fovy, aspect, near, far float32) *Tensor

func Random

func Random(shape []int, min, max float32) *Tensor

func RandomNormal

func RandomNormal(shape []int) *Tensor

func Rotate

func Rotate(axis *Tensor, angle float32) *Tensor

func RotateTensor

func RotateTensor(axis *Tensor, angle float32) *Tensor

func StackTensors

func StackTensors(tensors []*Tensor, dim int) (*Tensor, error)

func Subtract

func Subtract(a, b *Tensor) *Tensor

func TranslateMatrix

func TranslateMatrix(v *Tensor) *Tensor

func Transpose

func Transpose(t *Tensor, dims ...int) *Tensor

func Viewport

func Viewport(x, y, w, h float32) *Tensor

func Zeros

func Zeros(shape []int) *Tensor

func ZerosLike

func ZerosLike(t *Tensor) *Tensor

func (*Tensor) Add

func (t *Tensor) Add(other *Tensor) *Tensor

func (*Tensor) AddScalar

func (t *Tensor) AddScalar(scalar float32) *Tensor

func (*Tensor) Add_bak

func (t *Tensor) Add_bak(other *Tensor) *Tensor

func (*Tensor) And

func (t *Tensor) And(other *Tensor) *Tensor

And performs an element-wise logical AND operation.

func (*Tensor) Apply

func (t *Tensor) Apply(fn func(float32) float32) *Tensor

func (*Tensor) Apply1

func (t *Tensor) Apply1(f func(float32) float32) *Tensor

func (*Tensor) ArgMax

func (t *Tensor) ArgMax() *Tensor

func (*Tensor) At

func (t *Tensor) At(indices ...int) float32

func (*Tensor) Backward

func (t *Tensor) Backward()

func (*Tensor) BakTranspose

func (t *Tensor) BakTranspose() *Tensor

func (*Tensor) Clamp

func (t *Tensor) Clamp(min, max float32) *Tensor

func (*Tensor) Clone

func (t *Tensor) Clone() *Tensor

func (*Tensor) Clone1

func (t *Tensor) Clone1() *Tensor

func (*Tensor) Concat

func (t *Tensor) Concat(other *Tensor, dim int) *Tensor

func (*Tensor) Contain

func (t *Tensor) Contain(other *Tensor) bool

func (*Tensor) ContainPanic

func (t *Tensor) ContainPanic() bool

func (*Tensor) Conv2D

func (t *Tensor) Conv2D(weights *Tensor, kernelSize, stride, padH, padW int) *Tensor

func (*Tensor) Conv2DGradientWeight

func (t *Tensor) Conv2DGradientWeight(input *Tensor, kernelH, kernelW, strideH, strideW, padH, padW int) *Tensor

func (*Tensor) Conv2DTranspose

func (t *Tensor) Conv2DTranspose(weight *Tensor, kernelH, kernelW, strideH, strideW, padH, padW int) *Tensor

func (*Tensor) Copy

func (t *Tensor) Copy() *Tensor

func (*Tensor) Crop

func (t *Tensor) Crop(padding int) *Tensor

func (*Tensor) Cross

func (t *Tensor) Cross(other *Tensor) *Tensor

func (*Tensor) Determinant

func (t *Tensor) Determinant() float32

func (*Tensor) DimSize

func (t *Tensor) DimSize(dim int) int

func (*Tensor) Dimensions

func (t *Tensor) Dimensions() int

func (*Tensor) Div

func (t *Tensor) Div(other *Tensor) *Tensor

func (*Tensor) Div1

func (t *Tensor) Div1(other *Tensor) *Tensor

func (*Tensor) DivScalar

func (t *Tensor) DivScalar(scalar float32) *Tensor

func (*Tensor) Div_bak

func (t *Tensor) Div_bak(other *Tensor) *Tensor

func (*Tensor) Dot

func (t *Tensor) Dot(other *Tensor) float32

func (*Tensor) EnableGrad

func (t *Tensor) EnableGrad() *Tensor

func (*Tensor) Equal

func (t *Tensor) Equal(other *Tensor) bool

func (*Tensor) EqualFloat16

func (t *Tensor) EqualFloat16(other *Tensor) bool

func (*Tensor) EqualFloat32

func (t *Tensor) EqualFloat32(other *Tensor) bool

func (*Tensor) EqualFloat32WithShape

func (t *Tensor) EqualFloat32WithShape(other *Tensor) bool

func (*Tensor) EqualFloat5

func (t *Tensor) EqualFloat5(other *Tensor) bool

func (*Tensor) EqualWithTolerance

func (t *Tensor) EqualWithTolerance(other *Tensor, epsilon float32) bool

func (*Tensor) Exp

func (t *Tensor) Exp() *Tensor

Exp TODO: TestCaseCheck (test-case verification appears to be still pending).

func (*Tensor) Expand

func (t *Tensor) Expand(targetShape []int) *Tensor

func (*Tensor) Fill

func (t *Tensor) Fill(value float32)

func (*Tensor) Flatten

func (t *Tensor) Flatten() *Tensor

func (*Tensor) FlattenByDim

func (t *Tensor) FlattenByDim(startDim, endDim int) *Tensor

func (*Tensor) Gather

func (t *Tensor) Gather(indices *Tensor) *Tensor

func (*Tensor) Get

func (t *Tensor) Get(indices []int) float32

func (*Tensor) GetCol

func (t *Tensor) GetCol(colIdx int) *Tensor

func (*Tensor) GetCols

func (t *Tensor) GetCols(start, end int) *Tensor

func (*Tensor) GetRow

func (t *Tensor) GetRow(row int) *Tensor

func (*Tensor) GetSample

func (t *Tensor) GetSample(batchIdx int) *Tensor

func (*Tensor) GetShape

func (t *Tensor) GetShape() []int

func (*Tensor) GetShapeByNum

func (t *Tensor) GetShapeByNum(num int) int

func (*Tensor) GetValue

func (t *Tensor) GetValue(indices []int) float32

func (*Tensor) Homogeneous

func (t *Tensor) Homogeneous() *Tensor

func (*Tensor) Indices

func (t *Tensor) Indices(i int) []int

func (*Tensor) Inverse

func (t *Tensor) Inverse() *Tensor

func (*Tensor) IsMatrix

func (t *Tensor) IsMatrix() bool

func (*Tensor) IsVector

func (t *Tensor) IsVector() bool

func (*Tensor) Log

func (t *Tensor) Log() *Tensor

Log TODO: TestCaseCheck (test-case verification appears to be still pending).

func (*Tensor) LossMSE

func (pred *Tensor) LossMSE(target *Tensor) *Tensor

func (*Tensor) MaskedFill

func (t *Tensor) MaskedFill(mask *Tensor, value float32) *Tensor

func (*Tensor) MatMul

func (t *Tensor) MatMul(other *Tensor) *Tensor

func (*Tensor) MatMulMatrix

func (a *Tensor) MatMulMatrix(b *Tensor) *Tensor

func (*Tensor) Match

func (t *Tensor) Match() bool

func (*Tensor) MatchPanic

func (t *Tensor) MatchPanic() bool

func (*Tensor) Max

func (t *Tensor) Max() float32

func (*Tensor) Max1

func (t *Tensor) Max1() float32

func (*Tensor) MaxByDim

func (t *Tensor) MaxByDim(dim int, keepdim bool) *Tensor

func (*Tensor) MaxPool

func (t *Tensor) MaxPool(poolSize, stride int) (*Tensor, *Tensor)

func (*Tensor) Mean

func (t *Tensor) Mean() float32

func (*Tensor) MeanTensor

func (t *Tensor) MeanTensor() *Tensor

func (*Tensor) Min

func (t *Tensor) Min() float32

func (*Tensor) Mul

func (t *Tensor) Mul(other *Tensor) *Tensor

func (*Tensor) MulPosition

func (m *Tensor) MulPosition(v *Tensor) *Tensor

func (*Tensor) MulScalar

func (t *Tensor) MulScalar(scalar float32) *Tensor

func (*Tensor) Mul_bak

func (t *Tensor) Mul_bak(other *Tensor) *Tensor

func (*Tensor) Multiply

func (t *Tensor) Multiply(other *Tensor) *Tensor

func (*Tensor) Multiply1

func (t *Tensor) Multiply1(other *Tensor) *Tensor

func (*Tensor) Negate

func (t *Tensor) Negate() *Tensor

func (*Tensor) Normalize

func (t *Tensor) Normalize() *Tensor

func (*Tensor) Not

func (t *Tensor) Not() *Tensor

Not performs an element-wise logical NOT operation.

func (*Tensor) Or

func (t *Tensor) Or(other *Tensor) *Tensor

Or performs an element-wise logical OR operation.

func (*Tensor) Pad

func (t *Tensor) Pad(padding int) *Tensor

func (*Tensor) Pad2D

func (t *Tensor) Pad2D(padH, padW int) *Tensor

func (*Tensor) Permute

func (t *Tensor) Permute(perm []int) *Tensor

func (*Tensor) Pow

func (t *Tensor) Pow(exponent float32) *Tensor

func (*Tensor) ReLU

func (t *Tensor) ReLU() *Tensor

func (*Tensor) Reciprocal

func (t *Tensor) Reciprocal() *Tensor

func (*Tensor) ReduceSum

func (t *Tensor) ReduceSum() *Tensor

func (*Tensor) Repeat

func (t *Tensor) Repeat(dim int, repeats int) *Tensor

func (*Tensor) RepeatInterleave

func (t *Tensor) RepeatInterleave(dim int, repeats int) *Tensor

func (*Tensor) RequireGrad

func (t *Tensor) RequireGrad() bool

func (*Tensor) Reshape

func (t *Tensor) Reshape(shape []int) *Tensor

func (*Tensor) Rotate

func (a *Tensor) Rotate(axis *Tensor, angle float32) *Tensor

func (*Tensor) RoundTo

func (t *Tensor) RoundTo(decimals int) *Tensor

func (*Tensor) Save

func (t *Tensor) Save(filename string) error

func (*Tensor) SaveToCSV

func (t *Tensor) SaveToCSV(filename string) error

func (*Tensor) SaveToCSVWithoutShape

func (t *Tensor) SaveToCSVWithoutShape(filename string) error

func (*Tensor) ScatterAdd

func (t *Tensor) ScatterAdd(indices *Tensor, source *Tensor)

func (*Tensor) Set

func (t *Tensor) Set(value float32, indices ...int)

func (*Tensor) Set1

func (t *Tensor) Set1(indices []int, value float32)

func (*Tensor) SetCol

func (t *Tensor) SetCol(colIdx int, data *Tensor)

func (*Tensor) ShapeCopy

func (m *Tensor) ShapeCopy() []int

func (*Tensor) ShapesMatch

func (t *Tensor) ShapesMatch(other *Tensor) bool

func (*Tensor) Sigmoid

func (t *Tensor) Sigmoid() *Tensor

func (*Tensor) Size

func (t *Tensor) Size() int

func (*Tensor) Slice

func (t *Tensor) Slice(start, end, dim int) *Tensor

func (*Tensor) Softmax

func (t *Tensor) Softmax() *Tensor

func (*Tensor) SoftmaxByDim

func (t *Tensor) SoftmaxByDim(dim int) *Tensor

func (*Tensor) SplitLastDim

func (t *Tensor) SplitLastDim(splitPoint, part int) *Tensor

func (*Tensor) Sqrt

func (t *Tensor) Sqrt() *Tensor

func (*Tensor) Squeeze

func (t *Tensor) Squeeze() *Tensor

func (*Tensor) SqueezeSpecific

func (t *Tensor) SqueezeSpecific(dims []int) *Tensor

func (*Tensor) String

func (t *Tensor) String() string

func (*Tensor) Sub

func (t *Tensor) Sub(other *Tensor) *Tensor

func (*Tensor) Sub1

func (t *Tensor) Sub1(other *Tensor) *Tensor

func (*Tensor) SubScalar

func (t *Tensor) SubScalar(scalar float32) *Tensor

func (*Tensor) Sub_bak

func (t *Tensor) Sub_bak(other *Tensor) *Tensor

func (*Tensor) Sum

func (t *Tensor) Sum() float32

func (*Tensor) Sum1

func (t *Tensor) Sum1() float32

func (*Tensor) Sum111

func (t *Tensor) Sum111() *Tensor

func (*Tensor) SumByDim

func (t *Tensor) SumByDim(dim int) *Tensor

func (*Tensor) SumByDim1

func (t *Tensor) SumByDim1(dims []int, keepDims bool) *Tensor

func (*Tensor) SumByDim2

func (t *Tensor) SumByDim2(dim int, keepdim bool) *Tensor

func (*Tensor) Tanh

func (t *Tensor) Tanh() *Tensor

func (*Tensor) TensorData

func (t *Tensor) TensorData() []float32

func (*Tensor) ToDevice

func (t *Tensor) ToDevice(device Device)

func (*Tensor) TraceLogToggle

func (t *Tensor) TraceLogToggle()

func (*Tensor) Transpose

func (t *Tensor) Transpose() *Tensor

func (*Tensor) TransposeByDim

func (t *Tensor) TransposeByDim(dim1, dim2 int) *Tensor

func (*Tensor) Trilu

func (t *Tensor) Trilu(k int, upper bool) *Tensor

func (*Tensor) TriluMask

func (t *Tensor) TriluMask(k int, upper bool) *Tensor

func (*Tensor) X

func (m *Tensor) X() float32

func (*Tensor) Xor

func (t *Tensor) Xor(other *Tensor) *Tensor

Xor performs an element-wise logical XOR operation.

func (*Tensor) Y

func (m *Tensor) Y() float32

func (*Tensor) Z

func (m *Tensor) Z() float32

func (*Tensor) ZeroGrad

func (t *Tensor) ZeroGrad()

func (*Tensor) ZerosLike

func (t *Tensor) ZerosLike() *Tensor

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL