util

package
v0.2.1

This package is not in the latest version of its module.

Published: Dec 11, 2023 License: Apache-2.0 Imports: 11 Imported by: 0

Documentation

Constants

const (
	WeightName = "pytorch_model.gt"
	ConfigName = "config.json"

	// NOTE. URL form := `$HFpath/ModelName/resolve/main/WeightName`
	HFpath = "https://huggingface.co"
)
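The URL form noted in the comment above can be assembled from these constants. The following is a minimal sketch, assuming the model name is a plain Hugging Face repository id. `modelURL` is a hypothetical helper, not part of this package, and the constants are redeclared locally so that the example stands alone.

```go
package main

import "fmt"

// Local copies of the package constants so the sketch is self-contained.
const (
	weightName = "pytorch_model.gt"
	configName = "config.json"
	hfPath     = "https://huggingface.co"
)

// modelURL is a hypothetical helper that follows the documented URL form
// `$HFpath/ModelName/resolve/main/WeightName`.
func modelURL(modelName, fileName string) string {
	return fmt.Sprintf("%s/%s/resolve/main/%s", hfPath, modelName, fileName)
}

func main() {
	fmt.Println(modelURL("bert-base-uncased", weightName))
	fmt.Println(modelURL("bert-base-uncased", configName))
}
```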

Variables

var ActivationFnMap map[string]ActivationFn = map[string]ActivationFn{
	"gelu":  NewGelu(),
	"relu":  NewRelu(),
	"tanh":  NewTanh(),
	"swish": NewSwish(),
	"mish":  NewMish(),
}
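A map like this is typically used to turn an activation name from a model config into an implementation. The sketch below shows that lookup pattern under simplifying assumptions: `activationFn`, `relu`, `tanh`, `activations` and `lookup` are hypothetical stand-ins that operate on a single `float64` rather than a `*ts.Tensor`, so the example runs without libtorch. It is not the package's own code.

```go
package main

import (
	"fmt"
	"math"
)

// activationFn is a simplified stand-in for ActivationFn: it works on a
// single float64 instead of a *ts.Tensor.
type activationFn interface {
	Fwd(x float64) float64
	Name() string
}

type relu struct{}

func (relu) Fwd(x float64) float64 { return math.Max(0, x) }
func (relu) Name() string          { return "relu" }

type tanh struct{}

func (tanh) Fwd(x float64) float64 { return math.Tanh(x) }
func (tanh) Name() string          { return "tanh" }

// activations mirrors the shape of ActivationFnMap: name -> implementation.
var activations = map[string]activationFn{
	"relu": relu{},
	"tanh": tanh{},
}

// lookup returns the activation registered under name, or an error if the
// name is not in the map.
func lookup(name string) (activationFn, error) {
	fn, ok := activations[name]
	if !ok {
		return nil, fmt.Errorf("unsupported activation %q", name)
	}
	return fn, nil
}

func main() {
	fn, err := lookup("relu")
	if err != nil {
		panic(err)
	}
	fmt.Println(fn.Name(), fn.Fwd(-2), fn.Fwd(3))
}
```

Checking the `ok` result of the map lookup avoids calling a method on a nil interface when a config names an activation that is not registered.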
var (
	CachedDir string = "NOT_SETTING"
)
var (
	DUMMY_INPUT [][]int64 = [][]int64{
		{7, 6, 0, 0, 1},
		{1, 2, 3, 0, 0},
		{0, 0, 0, 4, 5},
	}
)

Functions

func CachedPath

func CachedPath(modelNameOrPath, fileName string) (resolvedPath string, err error)

CachedPath resolves the input string to a model or config file, caches it if necessary, and returns the full path to the cached file.

Parameters:
  - `modelNameOrPath`: a model name, e.g. "bert-base-uncased", or a path to a directory that contains the model and config files.
  - `fileName`: the model or config file name, e.g. "pytorch_model.gt" or "config.json".

CachedPath performs the following steps in order:
  1. Resolves the input string to a candidate full path for the cached file.
  2. Checks whether the candidate exists in `CachedDir`; if it does, returns it.
  3. Otherwise, retrieves the data, caches it in `CachedDir`, and returns the path to the cached file.

NOTE. The default `CachedDir` is "{$HOME}/.cache/transformer". To use a different directory, set the `GO_TRANSFORMER` environment variable.

func CleanCache

func CleanCache() error

CleanCache removes all files cached in the transformer cache directory `CachedDir`.

NOTE. To use a custom `CachedDir`, set the `GO_TRANSFORMER` environment variable.

func Equal

func Equal(tensorA, tensorB *ts.Tensor) bool

Equal reports whether two tensors have the same shape and the same value at every element.
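To make the comparison concrete, here is a minimal sketch of a shape-then-elements check. The `tensor` struct and `equal` function are hypothetical stand-ins (a shape plus flat row-major data) used so that the example runs without libtorch. How the real `Equal` inspects a `*ts.Tensor` is not shown in this documentation.

```go
package main

import "fmt"

// tensor is a minimal stand-in for *ts.Tensor: a shape plus flat data.
type tensor struct {
	shape []int64
	data  []float64
}

// equal reports whether a and b have identical shapes and identical elements.
func equal(a, b tensor) bool {
	if len(a.shape) != len(b.shape) || len(a.data) != len(b.data) {
		return false
	}
	for i := range a.shape {
		if a.shape[i] != b.shape[i] {
			return false
		}
	}
	for i := range a.data {
		if a.data[i] != b.data[i] {
			return false
		}
	}
	return true
}

func main() {
	a := tensor{shape: []int64{2, 3}, data: []float64{1, 2, 3, 4, 5, 6}}
	b := tensor{shape: []int64{3, 2}, data: []float64{1, 2, 3, 4, 5, 6}}
	fmt.Println(equal(a, a), equal(a, b)) // prints "true false"
}
```

The second comparison is false even though the flat data matches, because the shapes differ.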

Types

type ActivationFn

type ActivationFn interface {
	// Fwd is a forward pass through x.
	Fwd(x *ts.Tensor) *ts.Tensor
	Name() string
}

ActivationFn is an activation function.

type Dropout

type Dropout struct {
	// contains filtered or unexported fields
}

func NewDropout

func NewDropout(p float64) *Dropout

func (*Dropout) ForwardT

func (d *Dropout) ForwardT(input *ts.Tensor, train bool) (retVal *ts.Tensor)

type GeluActivation

type GeluActivation struct {
	// contains filtered or unexported fields
}

func NewGelu

func NewGelu() GeluActivation

func (GeluActivation) Fwd

func (g GeluActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor)

func (GeluActivation) Name

func (g GeluActivation) Name() (retVal string)

type LinearNoBias

type LinearNoBias struct {
	Ws *ts.Tensor
}

func NewLinearNoBias

func NewLinearNoBias(vs *nn.Path, inDim, outDim int64, config *LinearNoBiasConfig) (*LinearNoBias, error)

func (*LinearNoBias) Forward

func (lnb *LinearNoBias) Forward(xs *ts.Tensor) (retVal *ts.Tensor)

Forward implements the Module interface for LinearNoBias.

type LinearNoBiasConfig

type LinearNoBiasConfig struct {
	WsInit nn.Init // interface
}

func DefaultLinearNoBiasConfig

func DefaultLinearNoBiasConfig() *LinearNoBiasConfig

type MishActivation

type MishActivation struct {
	// contains filtered or unexported fields
}

func NewMish

func NewMish() MishActivation

func (MishActivation) Fwd

func (m MishActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor)

func (MishActivation) Name

func (m MishActivation) Name() (retVal string)

type ReluActivation

type ReluActivation struct {
	// contains filtered or unexported fields
}

func NewRelu

func NewRelu() ReluActivation

func (ReluActivation) Fwd

func (r ReluActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor)

func (ReluActivation) Name

func (r ReluActivation) Name() (retVal string)

type SwishActivation

type SwishActivation struct {
	// contains filtered or unexported fields
}

func NewSwish

func NewSwish() SwishActivation

func (SwishActivation) Fwd

func (s SwishActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor)

func (SwishActivation) Name

func (s SwishActivation) Name() (retVal string)

type TanhActivation

type TanhActivation struct {
	// contains filtered or unexported fields
}

func NewTanh

func NewTanh() TanhActivation

func (TanhActivation) Fwd

func (t TanhActivation) Fwd(x *ts.Tensor) (retVal *ts.Tensor)

func (TanhActivation) Name

func (t TanhActivation) Name() string
