ctr

package
v0.5.4 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Feb 2, 2026 License: Apache-2.0 Imports: 30 Imported by: 0

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func AUC

func AUC(posPrediction, negPrediction []float32) float32

func Accuracy

func Accuracy(posPrediction, negPrediction []float32) float32

func LoadLibFMFile

func LoadLibFMFile(path string) (features [][]lo.Tuple2[int32, float32], targets []float32, maxLabel int32, err error)

LoadLibFMFile loads a file in libFM format.

func MarshalModel

func MarshalModel(w io.Writer, m FactorizationMachines) error

func Precision

func Precision(posPrediction, negPrediction []float32) float32

func Recall

func Recall(posPrediction, _ []float32) float32

Types

type AFM added in v0.5.2

type AFM struct {
	BaseFactorizationMachines

	// parameters
	B *nn.Tensor
	W nn.Layer
	V nn.Layer
	A []nn.Layer
	E []nn.Layer
	// contains filtered or unexported fields
}

func NewAFM added in v0.5.2

func NewAFM(params model.Params) *AFM

func (*AFM) BatchInternalPredict added in v0.5.2

func (fm *AFM) BatchInternalPredict(x []lo.Tuple2[[]int32, []float32], e [][][]float32, jobs int) []float32

func (*AFM) BatchPredict added in v0.5.2

func (fm *AFM) BatchPredict(inputs []lo.Tuple4[string, string, []Label, []Label], embeddings [][]Embedding, jobs int) []float32

func (*AFM) Clear added in v0.5.2

func (fm *AFM) Clear()

func (*AFM) Fit added in v0.5.2

func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, config *FitConfig) Score

func (*AFM) Forward added in v0.5.2

func (fm *AFM) Forward(indices, values *nn.Tensor, embeddings []*nn.Tensor, jobs int) *nn.Tensor

func (*AFM) Init added in v0.5.2

func (fm *AFM) Init(trainSet dataset.CTRSplit)

func (*AFM) InternalPredict added in v0.5.2

func (fm *AFM) InternalPredict(_ []int32, _ []float32) float32

func (*AFM) Invalid added in v0.5.2

func (fm *AFM) Invalid() bool

func (*AFM) Marshal added in v0.5.2

func (fm *AFM) Marshal(w io.Writer) error

func (*AFM) Parameters added in v0.5.2

func (fm *AFM) Parameters() []*nn.Tensor

func (*AFM) Predict added in v0.5.2

func (fm *AFM) Predict(_, _ string, _, _ []Label) float32

func (*AFM) SetParams added in v0.5.2

func (fm *AFM) SetParams(params model.Params)

func (*AFM) SuggestParams added in v0.5.2

func (fm *AFM) SuggestParams(trial goptuna.Trial) model.Params

func (*AFM) Unmarshal added in v0.5.2

func (fm *AFM) Unmarshal(r io.Reader) error

type BaseFactorizationMachines

type BaseFactorizationMachines struct {
	model.BaseModel
	Index dataset.UnifiedIndex
}

func (*BaseFactorizationMachines) Init

func (b *BaseFactorizationMachines) Init(trainSet dataset.CTRSplit)

type BatchInference

type BatchInference interface {
	BatchPredict(inputs []lo.Tuple4[string, string, []Label, []Label], e [][]Embedding, jobs int) []float32
	BatchInternalPredict(x []lo.Tuple2[[]int32, []float32], e [][][]float32, jobs int) []float32
}

type Dataset

type Dataset struct {
	Index                  dataset.UnifiedIndex
	UserLabels             [][]lo.Tuple2[int32, float32]
	ItemLabels             [][]lo.Tuple2[int32, float32]
	ContextLabels          [][]lo.Tuple2[int32, float32]
	Users                  []int32
	Items                  []int32
	Target                 []float32
	ItemEmbeddings         [][][]float32 // Index by row id, embedding id, embedding dimension
	ItemEmbeddingDimension []int
	ItemEmbeddingIndex     *dataset.Index
	PositiveCount          int
	NegativeCount          int
}

Dataset holds the data used by click-through-rate (CTR) models.

func LoadDataFromBuiltIn

func LoadDataFromBuiltIn(name string) (train, test *Dataset, err error)

LoadDataFromBuiltIn loads the named built-in dataset and returns its training and test sets.

func (*Dataset) Count

func (dataset *Dataset) Count() int

Count returns the number of samples.

func (*Dataset) CountContextLabels

func (dataset *Dataset) CountContextLabels() int

func (*Dataset) CountItemLabels

func (dataset *Dataset) CountItemLabels() int

func (*Dataset) CountItems

func (dataset *Dataset) CountItems() int

CountItems returns the number of items.

func (*Dataset) CountNegative

func (dataset *Dataset) CountNegative() int

func (*Dataset) CountPositive

func (dataset *Dataset) CountPositive() int

func (*Dataset) CountUserLabels

func (dataset *Dataset) CountUserLabels() int

func (*Dataset) CountUsers

func (dataset *Dataset) CountUsers() int

CountUsers returns the number of users.

func (*Dataset) Get

func (dataset *Dataset) Get(i int) ([]int32, []float32, [][]float32, float32)

Get returns the i-th sample.

func (*Dataset) GetIndex

func (dataset *Dataset) GetIndex() dataset.UnifiedIndex

func (*Dataset) GetItemEmbeddingDim added in v0.5.2

func (dataset *Dataset) GetItemEmbeddingDim() []int

func (*Dataset) GetItemEmbeddingIndex added in v0.5.2

func (dataset *Dataset) GetItemEmbeddingIndex() *dataset.Index

func (*Dataset) GetTarget

func (dataset *Dataset) GetTarget(i int) float32

func (*Dataset) Split

func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset)

Split splits a dataset into a training set and a test set, using the given ratio and random seed.

type Embedding

type Embedding struct {
	Name  string
	Value []float32
}

func ConvertEmbeddings

func ConvertEmbeddings(o any) []Embedding

type FactorizationMachines

type FactorizationMachines interface {
	model.Model
	Predict(userId, itemId string, userFeatures, itemFeatures []Label) float32
	InternalPredict(x []int32, values []float32) float32
	Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, config *FitConfig) Score
	Marshal(w io.Writer) error
}

func UnmarshalModel

func UnmarshalModel(r io.Reader) (FactorizationMachines, error)

type FitConfig

type FitConfig struct {
	Jobs     int
	Verbose  int
	Patience int
}

func NewFitConfig

func NewFitConfig() *FitConfig

func (*FitConfig) LoadDefaultIfNil

func (config *FitConfig) LoadDefaultIfNil() *FitConfig

func (*FitConfig) SetJobs

func (config *FitConfig) SetJobs(jobs int) *FitConfig

func (*FitConfig) SetPatience

func (config *FitConfig) SetPatience(patience int) *FitConfig

func (*FitConfig) SetVerbose

func (config *FitConfig) SetVerbose(verbose int) *FitConfig

type Label

type Label struct {
	Name  string
	Value float32
}

func ConvertLabels

func ConvertLabels(o any) []Label

type ModelCreator

type ModelCreator func() FactorizationMachines

type ModelSearch

type ModelSearch struct {
	// contains filtered or unexported fields
}

func NewModelSearch

func NewModelSearch(models map[string]ModelCreator, trainSet, testSet dataset.CTRSplit, config *FitConfig) *ModelSearch

func (*ModelSearch) Objective

func (ms *ModelSearch) Objective(trial goptuna.Trial) (float64, error)

func (*ModelSearch) Result

func (ms *ModelSearch) Result() meta.Model[Score]

func (*ModelSearch) WithContext

func (ms *ModelSearch) WithContext(ctx context.Context) *ModelSearch

func (*ModelSearch) WithSpan

func (ms *ModelSearch) WithSpan(span *monitor.Span) *ModelSearch

type Score

type Score struct {
	RMSE      float32
	Precision float32
	Recall    float32
	Accuracy  float32
	AUC       float32
}

func EvaluateClassification

func EvaluateClassification(estimator FactorizationMachines, testSet dataset.CTRSplit, jobs int) Score

EvaluateClassification evaluates a factorization machine on a classification task.

func EvaluateRegression

func EvaluateRegression(estimator FactorizationMachines, testSet *Dataset) Score

EvaluateRegression evaluates a factorization machine on a regression task.

func (Score) BetterThan

func (score Score) BetterThan(s Score) bool

func (Score) GetValue

func (score Score) GetValue() float32

func (Score) ZapFields

func (score Score) ZapFields() []zap.Field

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL