Documentation
¶
Index ¶
- Constants
- Variables
- func CachedPath(modelNameOrPath, fileName string) (resolvedPath string, err error)
- func CleanCache() error
- func Equal(tensorA, tensorB *ts.Tensor) bool
- type ActivationFn
- type Dropout
- type GeluActivation
- type LinearNoBias
- type LinearNoBiasConfig
- type MishActivation
- type ReluActivation
- type SwishActivation
- type TanhActivation
Constants ¶
const (
	WeightName = "pytorch_model.gt"
	ConfigName = "config.json"

	// NOTE. URL form := `$HFpath/ModelName/resolve/main/WeightName`
	HFpath = "https://huggingface.co"
)
Variables ¶
var ActivationFnMap map[string]ActivationFn = map[string]ActivationFn{ "gelu": NewGelu(), "relu": NewRelu(), "tanh": NewTanh(), "swish": NewSwish(), "mish": NewMish(), }
var (
CachedDir string = "NOT_SETTING"
)
var (
DUMMY_INPUT [][]int64 = [][]int64{
{7, 6, 0, 0, 1},
{1, 2, 3, 0, 0},
{0, 0, 0, 4, 5},
}
)
var Gelu = GeluActivation{}
var Mish = MishActivation{}
var Relu = ReluActivation{}
var Swish = SwishActivation{}
var Tanh = TanhActivation{}
Functions ¶
func CachedPath ¶
CachedPath resolves and caches data based on the input string, then returns the full path to the cached data.
Parameters: - `modelNameOrPath`: model name, e.g., "bert-base-uncased", or path to a directory containing model/config files. - `fileName`: model or config file name, e.g., "pytorch_model.gt", "config.json".
CachedPath does several things in sequence: 1. Resolves the input string to a full-path cached filename candidate. 2. Checks for the candidate in `CachedDir`; if it exists, returns the candidate. 3. Otherwise, retrieves the data, caches it in `CachedDir`, and returns the path to the cached data.
NOTE. The default `CachedDir` is "{$HOME}/.cache/transformer". A custom `CachedDir` can be set with the environment variable `GO_TRANSFORMER`.
func CleanCache ¶
func CleanCache() error
CleanCache removes all files cached in transformer cache directory `CachedDir`.
NOTE. A custom `CachedDir` can be set with the environment variable `GO_TRANSFORMER`.
Types ¶
type ActivationFn ¶
type ActivationFn interface {
// Fwd is a forward pass through x.
Fwd(x *ts.Tensor) *ts.Tensor
Name() string
}
ActivationFn is an activation function.
type GeluActivation ¶
type GeluActivation struct {
// contains filtered or unexported fields
}
func NewGelu ¶
func NewGelu() GeluActivation
func (GeluActivation) Name ¶
func (g GeluActivation) Name() (retVal string)
type LinearNoBias ¶
func NewLinearNoBias ¶
func NewLinearNoBias(vs *nn.Path, inDim, outDim int64, config *LinearNoBiasConfig) (*LinearNoBias, error)
type LinearNoBiasConfig ¶
func DefaultLinearNoBiasConfig ¶
func DefaultLinearNoBiasConfig() *LinearNoBiasConfig
type MishActivation ¶
type MishActivation struct {
// contains filtered or unexported fields
}
func NewMish ¶
func NewMish() MishActivation
func (MishActivation) Name ¶
func (m MishActivation) Name() (retVal string)
type ReluActivation ¶
type ReluActivation struct {
// contains filtered or unexported fields
}
func NewRelu ¶
func NewRelu() ReluActivation
func (ReluActivation) Name ¶
func (r ReluActivation) Name() (retVal string)
type SwishActivation ¶
type SwishActivation struct {
// contains filtered or unexported fields
}
func NewSwish ¶
func NewSwish() SwishActivation
func (SwishActivation) Name ¶
func (s SwishActivation) Name() (retVal string)
type TanhActivation ¶
type TanhActivation struct {
// contains filtered or unexported fields
}
func NewTanh ¶
func NewTanh() TanhActivation
func (TanhActivation) Name ¶
func (t TanhActivation) Name() string