nn

package
v1.9.12 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Nov 9, 2025 | License: MIT | Imports: 3 | Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Attention added in v1.4.1

type Attention struct {
	// contains filtered or unexported fields
}

func NewAttention added in v1.4.1

func NewAttention(embedDim, numHeads int64, dropout float64) *Attention

func (*Attention) Forward added in v1.4.1

func (a *Attention) Forward(q, k, v, mask *tensor.Tensor, isCausal bool) (*tensor.Tensor, *tensor.Tensor)

func (*Attention) Parameters added in v1.4.1

func (m *Attention) Parameters() []*tensor.Tensor

func (*Attention) ToDevice added in v1.4.1

func (m *Attention) ToDevice(device consts.DeviceType)

func (*Attention) ToScalarType added in v1.4.1

func (m *Attention) ToScalarType(t consts.ScalarType)

type Embedding added in v1.4.1

type Embedding struct {
	// contains filtered or unexported fields
}

func NewEmbedding added in v1.4.1

func NewEmbedding(numEmbeddings, embeddingDim, paddingIdx int64) *Embedding

func (*Embedding) Forward added in v1.4.1

func (m *Embedding) Forward(x *tensor.Tensor) *tensor.Tensor

func (*Embedding) Parameters added in v1.4.1

func (m *Embedding) Parameters() []*tensor.Tensor

func (*Embedding) ToDevice added in v1.4.1

func (m *Embedding) ToDevice(device consts.DeviceType)

func (*Embedding) ToScalarType added in v1.4.1

func (m *Embedding) ToScalarType(t consts.ScalarType)

type LayerNorm

type LayerNorm struct {
	// contains filtered or unexported fields
}

func NewLayerNorm

func NewLayerNorm(shapes ...int64) *LayerNorm

func (*LayerNorm) Forward

func (m *LayerNorm) Forward(x *tensor.Tensor) *tensor.Tensor

func (*LayerNorm) Parameters

func (m *LayerNorm) Parameters() []*tensor.Tensor

func (*LayerNorm) ToDevice

func (m *LayerNorm) ToDevice(device consts.DeviceType)

func (*LayerNorm) ToScalarType

func (m *LayerNorm) ToScalarType(t consts.ScalarType)

type Linear

type Linear struct {
	// contains filtered or unexported fields
}

func NewLinear

func NewLinear(inFeatures, outFeatures int64) *Linear

func (*Linear) Forward

func (m *Linear) Forward(x *tensor.Tensor) *tensor.Tensor

func (*Linear) Parameters

func (m *Linear) Parameters() []*tensor.Tensor

func (*Linear) ToDevice

func (m *Linear) ToDevice(device consts.DeviceType)

func (*Linear) ToScalarType

func (m *Linear) ToScalarType(t consts.ScalarType)

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL