package model

v0.5.0 (Latest)

Warning: This package is not in the latest version of its module.

Published: Aug 21, 2025 License: MIT Imports: 11 Imported by: 0

Documentation

Overview

Package model provides core abstractions and interfaces for machine learning models.

This package defines the fundamental building blocks for machine learning estimators in the SciGo library, including:

  • BaseEstimator: Core estimator with state management and serialization support
  • Model persistence: Save and load trained models using Go's encoding/gob
  • scikit-learn compatibility: Import/export models from Python scikit-learn
  • Streaming interfaces: Support for online learning and incremental training

The BaseEstimator provides a consistent foundation for all ML algorithms with:

  • Fitted state tracking to prevent usage of untrained models
  • Serialization support for model persistence
  • Thread-safe state management
  • Integration with preprocessing pipelines

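Example usage:

import (
	"github.com/YuminosukeSato/scigo/core/model"
	"gonum.org/v1/gonum/mat"
)

type MyModel struct {
	model.BaseEstimator
	// model-specific fields
}

// Fit trains the model and marks it as fitted once training succeeds.
func (m *MyModel) Fit(X, y mat.Matrix) error {
	// training logic
	m.SetFitted() // mark as trained
	return nil
}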

All models in SciGo embed BaseEstimator to ensure consistent behavior across the entire machine learning pipeline.

The package also defines additional interfaces and types for machine learning models, complementing the core interfaces declared in estimator.go and transformer.go, and provides thread-safe state management for models (see StateManager).


Constants

This section is empty.

Variables

This section is empty.

Functions

func ExportSKLearnModel

func ExportSKLearnModel(modelName string, params interface{}, w io.Writer) error

ExportSKLearnModel exports a model in scikit-learn-compatible JSON format.

Parameters:

  • modelName: The model name
  • params: The model parameters
  • w: The destination Writer

Returns:

  • error: Error if the export fails
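The exact JSON layout that ExportSKLearnModel writes is not shown on this page. As a minimal sketch, assuming it wraps the parameters in the same envelope that SKLearnModel reads (a `model_spec` object plus a raw `params` field), an export could look like this. The helper names `exportSKLearn`, `exportBytes`, and `specName`, and the `formatVersion` value, are hypothetical.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
)

// Local mirrors of SKLearnModelSpec and SKLearnModel (JSON tags copied from
// this page). Using them as the export envelope is an assumption.
type sklearnModelSpec struct {
	Name           string `json:"name"`
	FormatVersion  string `json:"format_version"`
	SKLearnVersion string `json:"sklearn_version,omitempty"`
}

type sklearnModel struct {
	ModelSpec sklearnModelSpec `json:"model_spec"`
	Params    json.RawMessage  `json:"params"`
}

// formatVersion is a hypothetical value, not taken from the library.
const formatVersion = "1.0"

// exportSKLearn marshals params, wraps them in the envelope, and writes to w.
func exportSKLearn(modelName string, params interface{}, w io.Writer) error {
	raw, err := json.Marshal(params)
	if err != nil {
		return fmt.Errorf("marshal params: %w", err)
	}
	env := sklearnModel{
		ModelSpec: sklearnModelSpec{Name: modelName, FormatVersion: formatVersion},
		Params:    raw,
	}
	enc := json.NewEncoder(w)
	enc.SetIndent("", "  ")
	return enc.Encode(env)
}

// exportBytes runs exportSKLearn against an in-memory buffer.
func exportBytes(modelName string, params interface{}) ([]byte, error) {
	var buf bytes.Buffer
	if err := exportSKLearn(modelName, params, &buf); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// specName decodes an exported document and returns its model name.
func specName(data []byte) (string, error) {
	var m sklearnModel
	if err := json.Unmarshal(data, &m); err != nil {
		return "", err
	}
	return m.ModelSpec.Name, nil
}

func main() {
	data, err := exportBytes("LinearRegression", map[string]interface{}{
		"coefficients": []float64{0.5, -1.5},
		"intercept":    2.0,
		"n_features":   2,
	})
	if err != nil {
		panic(err)
	}
	fmt.Print(string(data))
}
```

Keeping `params` as json.RawMessage lets the envelope stay model-agnostic: only code that knows the model name needs to understand the parameter shape.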

func LoadModel

func LoadModel(model interface{}, filename string) error

LoadModel loads a model from a file

Parameters:

  • model: The target model (pointer to struct with embedded BaseEstimator)
  • filename: The file path to load from

Returns:

  • error: Error if loading fails

Example:

var reg linear.Regression
if err := model.LoadModel(&reg, "model.gob"); err != nil {
    log.Fatal(err)
}

func LoadModelFromReader

func LoadModelFromReader(model interface{}, r io.Reader) error

LoadModelFromReader loads a model from an io.Reader

Parameters:

  • model: The target model (pointer)
  • r: The source Reader

Returns:

  • error: Error if loading fails

func SaveModel

func SaveModel(model interface{}, filename string) error

SaveModel saves a model to a file

Parameters:

  • model: The model to save (struct with embedded BaseEstimator)
  • filename: The file path to save to

Returns:

  • error: Error if saving fails

Example:

var reg linear.Regression
// ... train the model ...
if err := model.SaveModel(&reg, "model.gob"); err != nil {
    log.Fatal(err)
}

func SaveModelToWriter

func SaveModelToWriter(model interface{}, w io.Writer) error

SaveModelToWriter saves a model to an io.Writer

Parameters:

  • model: The model to save
  • w: The target Writer

Returns:

  • error: Error if saving fails
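The overview states that persistence uses encoding/gob. The sketch below shows the gob round trip that SaveModelToWriter and LoadModelFromReader are built around, using hypothetical local types (`BaseState`, `ToyRegressor`) rather than scigo's own. It also shows why BaseEstimator keeps State public: gob encodes only exported fields, so unexported ones are silently dropped.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"io"
)

// BaseState is a hypothetical stand-in for an embedded BaseEstimator.
type BaseState struct {
	Fitted    bool   // exported: encoded by gob
	ModelType string // exported: encoded by gob
	cache     []int  // unexported: skipped by gob
}

// ToyRegressor is a hypothetical model that embeds BaseState.
type ToyRegressor struct {
	BaseState
	Weights   []float64
	Intercept float64
}

// saveTo mirrors the shape of SaveModelToWriter: gob-encode the model.
func saveTo(model interface{}, w io.Writer) error {
	return gob.NewEncoder(w).Encode(model)
}

// loadFrom mirrors LoadModelFromReader: decode into the pointer it is given.
func loadFrom(model interface{}, r io.Reader) error {
	return gob.NewDecoder(r).Decode(model)
}

// roundTrip saves src into an in-memory buffer and loads it into a new value.
func roundTrip(src *ToyRegressor) (*ToyRegressor, error) {
	var buf bytes.Buffer
	if err := saveTo(src, &buf); err != nil {
		return nil, err
	}
	dst := &ToyRegressor{}
	if err := loadFrom(dst, &buf); err != nil {
		return nil, err
	}
	return dst, nil
}

func main() {
	src := &ToyRegressor{
		BaseState: BaseState{Fitted: true, ModelType: "ToyRegressor", cache: []int{1, 2}},
		Weights:   []float64{0.5, -1.5},
		Intercept: 2,
	}
	dst, err := roundTrip(src)
	if err != nil {
		panic(err)
	}
	// Exported fields survive the round trip; the unexported cache does not.
	fmt.Println(dst.Fitted, dst.ModelType, dst.Weights, dst.Intercept, dst.cache == nil)
}
```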

Types

type AdaptiveLearning

type AdaptiveLearning interface {
	// GetLearningRate returns the current learning rate
	GetLearningRate() float64

	// SetLearningRate sets the learning rate
	SetLearningRate(lr float64)

	// GetLearningRateSchedule returns the learning rate schedule
	// e.g., "constant", "optimal", "invscaling", "adaptive"
	GetLearningRateSchedule() string

	// SetLearningRateSchedule sets the learning rate schedule
	SetLearningRateSchedule(schedule string)
}

AdaptiveLearning is an interface for models that can dynamically adjust learning rates
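The schedule names match scikit-learn's SGD `learning_rate` options. As a hedged sketch, a hypothetical `sgdRate` type could satisfy AdaptiveLearning and compute step sizes for two of them: "constant" keeps η₀, and "invscaling" uses η = η₀ / t^power_t, as scikit-learn documents. The `powerT` field and the `RateAt` method are assumptions; "optimal" and "adaptive" are not modeled and fall back to η₀. This is not scigo's implementation.

```go
package main

import (
	"fmt"
	"math"
)

// AdaptiveLearning is copied from the interface documented above.
type AdaptiveLearning interface {
	GetLearningRate() float64
	SetLearningRate(lr float64)
	GetLearningRateSchedule() string
	SetLearningRateSchedule(schedule string)
}

// sgdRate is a hypothetical implementation; eta0 is the base learning rate.
type sgdRate struct {
	eta0     float64
	powerT   float64 // exponent used by "invscaling" (assumed field)
	schedule string
}

func (s *sgdRate) GetLearningRate() float64            { return s.eta0 }
func (s *sgdRate) SetLearningRate(lr float64)          { s.eta0 = lr }
func (s *sgdRate) GetLearningRateSchedule() string     { return s.schedule }
func (s *sgdRate) SetLearningRateSchedule(name string) { s.schedule = name }

// RateAt returns the step size for update t (t >= 1). Only "invscaling"
// changes over time here; every other schedule name returns eta0.
func (s *sgdRate) RateAt(t int) float64 {
	switch s.schedule {
	case "invscaling":
		return s.eta0 / math.Pow(float64(t), s.powerT)
	default:
		return s.eta0
	}
}

// Compile-time check that sgdRate satisfies AdaptiveLearning.
var _ AdaptiveLearning = (*sgdRate)(nil)

func main() {
	r := &sgdRate{eta0: 0.1, powerT: 0.5}
	r.SetLearningRateSchedule("invscaling")
	for _, t := range []int{1, 4, 100} {
		fmt.Printf("t=%d eta=%.4f\n", t, r.RateAt(t))
	}
}
```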

type BaseEstimator

type BaseEstimator struct {
	// State holds the model's learning state. Public for gob encoding.
	State EstimatorState

	// ModelType identifies the type of model
	ModelType string

	// Version is the model version
	Version string
	// contains filtered or unexported fields
}

BaseEstimator is the base structure for all models

Example

ExampleBaseEstimator demonstrates BaseEstimator state management

package main

import (
	"fmt"

	"github.com/YuminosukeSato/scigo/core/model"
)

func main() {
	// Create a BaseEstimator (typically embedded in actual models)
	estimator := &model.BaseEstimator{}

	// Check initial state
	fmt.Printf("Initially fitted: %t\n", estimator.IsFitted())

	// Mark as fitted
	estimator.SetFitted()
	fmt.Printf("After SetFitted: %t\n", estimator.IsFitted())

	// Reset to unfitted state
	estimator.Reset()
	fmt.Printf("After Reset: %t\n", estimator.IsFitted())

}
Output:

Initially fitted: false
After SetFitted: true
After Reset: false
Example (WorkflowPattern)

ExampleBaseEstimator_workflowPattern demonstrates typical usage pattern

package main

import (
	"fmt"

	"github.com/YuminosukeSato/scigo/core/model"
)

func main() {
	// This example shows how BaseEstimator is typically used in models
	type MyModel struct {
		model.BaseEstimator
		// model-specific fields would go here
	}

	myModel := &MyModel{}

	// Check if model needs training
	if !myModel.IsFitted() {
		fmt.Println("Model needs training")

		// Simulate training process
		// ... training logic would go here ...

		// Mark as fitted after successful training
		myModel.SetFitted()
		fmt.Println("Model trained successfully")
	}

	// Now model is ready for use
	if myModel.IsFitted() {
		fmt.Println("Model is ready for predictions")
	}

}
Output:

Model needs training
Model trained successfully
Model is ready for predictions

func (*BaseEstimator) Clone added in v0.3.0

func (e *BaseEstimator) Clone() *BaseEstimator

Clone creates a new instance of the model (basic implementation)

func (*BaseEstimator) ExportWeights added in v0.3.0

func (e *BaseEstimator) ExportWeights() (*ModelWeights, error)

ExportWeights exports model weights (basic implementation)

func (*BaseEstimator) GetLogger

func (e *BaseEstimator) GetLogger() interface{}

GetLogger returns the logger for this estimator. Returns nil if no logger has been set.

Returns:

  • interface{}: The logger instance; type-assert it to log.Logger before use

Example:

if logger := m.GetLogger(); logger != nil {
    if l, ok := logger.(log.Logger); ok {
        l.Info("Operation completed")
    }
}

func (*BaseEstimator) GetParams added in v0.3.0

func (e *BaseEstimator) GetParams(deep bool) map[string]interface{}

GetParams retrieves the model's hyperparameters (scikit-learn compatible)

func (*BaseEstimator) GetWeightHash added in v0.3.0

func (e *BaseEstimator) GetWeightHash() string

GetWeightHash calculates a hash of the model weights (for verification).
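The hashing scheme behind GetWeightHash is not documented here. One plausible approach, shown as a hypothetical sketch, is SHA-256 over the IEEE-754 bits of the coefficients followed by the intercept, so any bit-level change in the weights changes the hash. The function name `weightHash` and the byte layout are assumptions.

```go
package main

import (
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"math"
)

// weightHash hashes the coefficients, then the intercept, each written as
// 8 little-endian bytes of its IEEE-754 representation.
func weightHash(coefs []float64, intercept float64) string {
	h := sha256.New()
	buf := make([]byte, 8)
	for _, c := range append(append([]float64(nil), coefs...), intercept) {
		binary.LittleEndian.PutUint64(buf, math.Float64bits(c))
		h.Write(buf) // hash.Hash.Write never returns an error
	}
	return hex.EncodeToString(h.Sum(nil))
}

func main() {
	a := weightHash([]float64{0.5, -1.5}, 2)
	b := weightHash([]float64{0.5, -1.5}, 2.0000001)
	fmt.Println(a)
	fmt.Println(a == b) // a tiny change in the intercept changes the hash
}
```

Hashing raw bits rather than formatted strings avoids depending on float formatting, but it also means that mathematically equal values with different bit patterns (such as 0 and -0) hash differently.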

func (*BaseEstimator) ImportWeights added in v0.3.0

func (e *BaseEstimator) ImportWeights(weights *ModelWeights) error

ImportWeights imports model weights (basic implementation)

func (*BaseEstimator) IsFitted

func (e *BaseEstimator) IsFitted() bool

IsFitted returns whether the model has been fitted with training data.

This method checks the internal state to determine if the estimator has been trained and is ready for prediction or transformation. All models must be fitted before they can be used for predictions.

Returns:

  • bool: true if the model is fitted, false otherwise

Example (m is a model that embeds BaseEstimator):

if !m.IsFitted() {
    if err := m.Fit(X, y); err != nil {
        log.Fatal(err)
    }
}
predictions, err := m.Predict(XTest)
if err != nil {
    log.Fatal(err)
}

func (*BaseEstimator) LogDebug

func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})

LogDebug logs a debug-level message if a logger is configured. This is a convenience method for debug logging in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs

func (*BaseEstimator) LogError

func (e *BaseEstimator) LogError(msg string, fields ...interface{})

LogError logs an error-level message if a logger is configured. This is a convenience method for error logging in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs. If the first field is an error, it is handled specially.

func (*BaseEstimator) LogInfo

func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})

LogInfo logs an info-level message if a logger is configured. This is a convenience method to avoid repetitive nil checks in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs
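LogInfo, LogDebug, and LogError exist so that model code does not repeat the nil check and type assertion shown under GetLogger. A minimal sketch of that pattern, assuming a logger with an slog-style `Info(msg string, args ...any)` method. The `infoLogger` interface, `estimator` type, and `recorder` test double are hypothetical; `*slog.Logger` from the standard library happens to satisfy the interface.

```go
package main

import (
	"log/slog"
	"os"
)

// infoLogger is a hypothetical minimal interface; *slog.Logger satisfies it.
type infoLogger interface {
	Info(msg string, args ...any)
}

// estimator stores the logger as interface{}, like BaseEstimator.SetLogger.
type estimator struct {
	logger interface{}
}

func (e *estimator) SetLogger(l interface{}) { e.logger = l }

// LogInfo logs only when a compatible logger is set; otherwise it is a no-op.
// The comma-ok assertion is false for a nil logger, so no explicit nil check is needed.
func (e *estimator) LogInfo(msg string, fields ...interface{}) {
	if l, ok := e.logger.(infoLogger); ok {
		l.Info(msg, fields...)
	}
}

// recorder is a test double that remembers the messages it receives.
type recorder struct{ msgs []string }

func (r *recorder) Info(msg string, _ ...any) { r.msgs = append(r.msgs, msg) }

func main() {
	var e estimator
	e.LogInfo("dropped: no logger set") // safe no-op

	e.SetLogger(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	e.LogInfo("fit complete", "n_samples", 100, "n_features", 3)
}
```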

func (*BaseEstimator) Reset

func (e *BaseEstimator) Reset()

Reset returns the estimator to its initial untrained state.

This method clears the fitted state, effectively making the model untrained. Useful for reusing model instances with different data or resetting after errors. After reset, the model must be fitted again before use.

Example:

m.Reset()
if err := m.Fit(newTrainingData, newLabels); err != nil {
    log.Fatal(err)
}

func (*BaseEstimator) SetFitted

func (e *BaseEstimator) SetFitted()

SetFitted marks the estimator as fitted (trained).

Model implementations call this method internally after successful training to indicate that the model is ready for predictions or transformations. It should only be called by model implementations, not by end users.

Example (within a model's Fit method):

func (m *MyModel) Fit(X, y mat.Matrix) error {
    // ... training logic ...
    m.SetFitted() // Mark as trained
    return nil
}

func (*BaseEstimator) SetLogger

func (e *BaseEstimator) SetLogger(logger interface{})

SetLogger sets the logger for this estimator. This method is typically called during model initialization to provide structured logging capabilities for ML operations.

Parameters:

  • logger: Any logger implementation (typically log.Logger interface)

Example:

import "github.com/YuminosukeSato/scigo/pkg/log"

m.SetLogger(log.GetLoggerWithName("LinearRegression"))

func (*BaseEstimator) SetParams added in v0.3.0

func (e *BaseEstimator) SetParams(params map[string]interface{}) error

SetParams sets the model's hyperparameters (scikit-learn compatible)
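GetParams and SetParams follow scikit-learn's `get_params`/`set_params` convention, in which hyperparameters travel as a name-to-value map. A hedged sketch for a hypothetical model with `alpha` and `fit_intercept` hyperparameters. Rejecting unknown keys and wrong types, and applying nothing when any entry is invalid, are design choices of this sketch, not documented scigo behavior.

```go
package main

import "fmt"

// ridgeLike is a hypothetical model with two hyperparameters.
type ridgeLike struct {
	Alpha        float64
	FitIntercept bool
}

// GetParams returns hyperparameters keyed by scikit-learn-style names.
// deep is accepted for signature compatibility; there are no nested estimators.
func (m *ridgeLike) GetParams(deep bool) map[string]interface{} {
	return map[string]interface{}{"alpha": m.Alpha, "fit_intercept": m.FitIntercept}
}

// SetParams updates a copy and commits it only if every entry is valid,
// so a bad key or value leaves the model unchanged.
func (m *ridgeLike) SetParams(params map[string]interface{}) error {
	next := *m
	for k, v := range params {
		switch k {
		case "alpha":
			a, ok := v.(float64)
			if !ok || a < 0 {
				return fmt.Errorf("alpha must be a non-negative float64, got %v", v)
			}
			next.Alpha = a
		case "fit_intercept":
			b, ok := v.(bool)
			if !ok {
				return fmt.Errorf("fit_intercept must be a bool, got %v", v)
			}
			next.FitIntercept = b
		default:
			return fmt.Errorf("unknown parameter %q", k)
		}
	}
	*m = next
	return nil
}

func main() {
	m := &ridgeLike{Alpha: 1, FitIntercept: true}
	if err := m.SetParams(map[string]interface{}{"alpha": 0.1}); err != nil {
		panic(err)
	}
	fmt.Println(m.GetParams(true))
	fmt.Println(m.SetParams(map[string]interface{}{"max_iter": 100}))
}
```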

type Batch

type Batch struct {
	X mat.Matrix // Feature matrix
	Y mat.Matrix // Target matrix
}

Batch represents a data batch for streaming learning

type BufferedStreaming

type BufferedStreaming interface {
	// SetBufferSize sets the size of streaming buffer
	SetBufferSize(size int)

	// GetBufferSize returns current buffer size
	GetBufferSize() int

	// FlushBuffer forces buffer flush
	FlushBuffer() error
}

BufferedStreaming is a streaming interface with buffering capabilities

type Classifier added in v0.4.0

type Classifier interface {
	Estimator
	Predictor
	Scorer

	// PredictProba returns probability estimates for each class.
	PredictProba(X mat.Matrix) (mat.Matrix, error)

	// Classes returns the unique classes seen during fitting.
	Classes() []int
}

Classifier combines interfaces for classification models.

type ClassifierMixin added in v0.3.0

type ClassifierMixin interface {
	Estimator

	// PredictProba predicts the probability of each class
	PredictProba(X mat.Matrix) (mat.Matrix, error)

	// PredictLogProba predicts the log-probability of each class
	PredictLogProba(X mat.Matrix) (mat.Matrix, error)

	// DecisionFunction computes the decision function values
	DecisionFunction(X mat.Matrix) (mat.Matrix, error)

	// Classes returns the class labels learned during fitting
	Classes() []interface{}

	// NClasses returns the number of classes learned during fitting
	NClasses() int
}

ClassifierMixin is the mixin interface for classifiers.

type ClassifierWithPartialFit added in v0.4.0

type ClassifierWithPartialFit interface {
	Classifier
	IncrementalLearner
}

ClassifierWithPartialFit combines interfaces for online classification models.

type ClusterMixin added in v0.3.0

type ClusterMixin interface {
	Fitter

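	// FitPredict fits the model and predicts cluster labels in one step
	FitPredict(X mat.Matrix) ([]int, error)

	// PredictCluster predicts the cluster of each sample in new data
	PredictCluster(X mat.Matrix) ([]int, error)

	// NClusters returns the number of clusters
	NClusters() int
}

ClusterMixin is the mixin interface for clustering models.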

type CrossValidatable added in v0.3.0

type CrossValidatable interface {
	// GetCVSplits returns the number of cross-validation splits
	GetCVSplits() int

	// SetCVSplits sets the number of cross-validation splits
	SetCVSplits(n int)

	// GetCVScores returns the cross-validation scores
	GetCVScores() []float64
}

CrossValidatable is the interface for models that support cross-validation.
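CrossValidatable exposes only the split count. As a hedged illustration of what that count controls, the sketch partitions n samples into k contiguous test folds the way scikit-learn's unshuffled KFold does: the first n mod k folds receive one extra sample. The `kFoldBounds` helper is hypothetical and not part of this package.

```go
package main

import "fmt"

// kFoldBounds returns [start, end) index pairs for k contiguous folds over n samples.
func kFoldBounds(n, k int) ([][2]int, error) {
	if k < 2 || k > n {
		return nil, fmt.Errorf("need 2 <= k <= n, got k=%d, n=%d", k, n)
	}
	bounds := make([][2]int, 0, k)
	start := 0
	for i := 0; i < k; i++ {
		size := n / k
		if i < n%k {
			size++ // the first n%k folds absorb the remainder
		}
		bounds = append(bounds, [2]int{start, start + size})
		start += size
	}
	return bounds, nil
}

func main() {
	folds, err := kFoldBounds(10, 3)
	if err != nil {
		panic(err)
	}
	for i, f := range folds {
		fmt.Printf("fold %d: test indices [%d, %d)\n", i, f[0], f[1])
	}
}
```

Each fold serves once as the test set while the remaining folds form the training set, so GetCVScores would naturally hold one score per fold.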

type Estimator

type Estimator interface {
	Fitter
	Predictor
}

Estimator is an interface for models that can both learn and predict

type EstimatorState

type EstimatorState int

EstimatorState represents the learning state of a model

const (
	// NotFitted indicates the model is not yet trained
	NotFitted EstimatorState = iota
	// Fitted indicates the model has been trained
	Fitted
)

type Fitter

type Fitter interface {
	// Fit trains the model with training data
	Fit(X, y mat.Matrix) error
}

Fitter is an interface for trainable models

type IncrementalEstimator

type IncrementalEstimator interface {
	Estimator

	// PartialFit trains the model incrementally with mini-batches
	// classes specifies all class labels for classification problems (required only on first call)
	// Pass nil for regression problems
	PartialFit(X, y mat.Matrix, classes []int) error

	// NIterations returns the number of training iterations executed
	NIterations() int

	// IsWarmStart returns whether warm start is enabled
	// If true, continues learning from existing parameters when Fit is called
	IsWarmStart() bool

	// SetWarmStart enables/disables warm start
	SetWarmStart(warmStart bool)
}

IncrementalEstimator is an interface for models capable of online (incremental) learning. It is compatible with scikit-learn's partial_fit API.
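The PartialFit contract is easiest to see with a trivially simple learner. The hypothetical `meanRegressor` below predicts the running mean of every target it has seen. Each PartialFit call folds in one mini-batch, NIterations counts those calls, and Fit restarts from scratch unless warm start is enabled. It uses plain slices instead of mat.Matrix to stay self-contained and ignores `classes`, since it is a regressor.

```go
package main

import (
	"errors"
	"fmt"
)

// meanRegressor is a hypothetical incremental learner that predicts the mean target.
type meanRegressor struct {
	sum       float64
	count     int
	iters     int
	warmStart bool
}

// PartialFit folds one mini-batch into the running statistics.
// X is unused by this toy model; classes is nil for regression.
func (m *meanRegressor) PartialFit(X [][]float64, y []float64, classes []int) error {
	if len(X) != len(y) {
		return errors.New("X and y must have the same number of rows")
	}
	for _, v := range y {
		m.sum += v
		m.count++
	}
	m.iters++
	return nil
}

// Fit discards previous state unless warm start is enabled, then trains on one batch.
func (m *meanRegressor) Fit(X [][]float64, y []float64) error {
	if !m.warmStart {
		*m = meanRegressor{}
	}
	return m.PartialFit(X, y, nil)
}

func (m *meanRegressor) NIterations() int     { return m.iters }
func (m *meanRegressor) IsWarmStart() bool    { return m.warmStart }
func (m *meanRegressor) SetWarmStart(on bool) { m.warmStart = on }

// Predict returns the running mean for every row of X.
func (m *meanRegressor) Predict(X [][]float64) ([]float64, error) {
	if m.count == 0 {
		return nil, errors.New("model is not fitted")
	}
	out := make([]float64, len(X))
	for i := range out {
		out[i] = m.sum / float64(m.count)
	}
	return out, nil
}

func main() {
	m := &meanRegressor{}
	for _, y := range [][]float64{{1, 2}, {3, 4}, {5}} {
		X := make([][]float64, len(y)) // placeholder rows; features are unused
		if err := m.PartialFit(X, y, nil); err != nil {
			panic(err)
		}
	}
	pred, _ := m.Predict([][]float64{{0}})
	fmt.Println(m.NIterations(), pred[0])
}
```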

type IncrementalLearner added in v0.4.0

type IncrementalLearner interface {
	// PartialFit performs one epoch of stochastic gradient descent on given samples.
	PartialFit(X mat.Matrix, y mat.Matrix, classes []int) error
}

IncrementalLearner is the interface for models that support incremental learning.

type InverseTransformerMixin added in v0.3.0

type InverseTransformerMixin interface {
	TransformerMixin

	// InverseTransform applies the inverse of the transformation
	InverseTransform(X mat.Matrix) (mat.Matrix, error)
}

InverseTransformerMixin is the interface for transformers that support an inverse transformation.

type LinearModel

type LinearModel interface {
	// Weights returns the learned weights (coefficients)
	Weights() []float64
	// Intercept returns the learned intercept
	Intercept() float64
	// Score calculates the model's coefficient of determination (R²)
	Score(X, y mat.Matrix) (float64, error)
}

LinearModel is an interface for linear models
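Score on LinearModel (and on RegressorMixin and Scorer) is documented as the coefficient of determination, R² = 1 − SS_res / SS_tot, where SS_res = Σ(yᵢ − ŷᵢ)² and SS_tot = Σ(yᵢ − ȳ)². A self-contained sketch on plain slices follows. The `r2Score` name is hypothetical, and returning an error for constant targets is this sketch's choice; scikit-learn has its own convention for that case.

```go
package main

import (
	"errors"
	"fmt"
)

// r2Score returns 1 - SS_res/SS_tot for true targets y and predictions yHat.
func r2Score(y, yHat []float64) (float64, error) {
	if len(y) != len(yHat) || len(y) == 0 {
		return 0, errors.New("y and yHat must be non-empty and the same length")
	}
	mean := 0.0
	for _, v := range y {
		mean += v
	}
	mean /= float64(len(y))

	var ssRes, ssTot float64
	for i := range y {
		ssRes += (y[i] - yHat[i]) * (y[i] - yHat[i])
		ssTot += (y[i] - mean) * (y[i] - mean)
	}
	if ssTot == 0 {
		// Design choice for this sketch: R² is undefined for constant targets.
		return 0, errors.New("R² is undefined when y is constant")
	}
	return 1 - ssRes/ssTot, nil
}

func main() {
	y := []float64{1, 2, 3, 4}
	perfect, _ := r2Score(y, []float64{1, 2, 3, 4})
	meanOnly, _ := r2Score(y, []float64{2.5, 2.5, 2.5, 2.5})
	fmt.Println(perfect, meanOnly) // prints 1 0
}
```

A perfect fit scores 1, a model that always predicts the mean scores 0, and a model worse than the mean scores below 0.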

type ModelState added in v0.4.0

type ModelState struct {
	Fitted    bool                   `json:"fitted"`
	NFeatures int                    `json:"n_features,omitempty"`
	NSamples  int                    `json:"n_samples,omitempty"`
	Params    map[string]interface{} `json:"params,omitempty"`
}

ModelState represents the complete state of a model. This can be used for serialization and debugging.

type ModelValidation added in v0.3.0

type ModelValidation struct {
	// ValidateInput validates input data
	ValidateInput func(X mat.Matrix) error

	// ValidateOutput validates output data
	ValidateOutput func(y mat.Matrix) error

	// ValidateWeights validates model weights
	ValidateWeights func(weights *ModelWeights) error
}

ModelValidation provides validation hooks for a model.

type ModelWeights added in v0.3.0

type ModelWeights struct {
	// ModelType is the kind of model (e.g., LinearRegression, SGDRegressor)
	ModelType string `json:"model_type"`

	// Version is the model version (used for compatibility checks)
	Version string `json:"version"`

	// Coefficients are the weight coefficients
	Coefficients []float64 `json:"coefficients"`

	// Intercept is the intercept (bias) term
	Intercept float64 `json:"intercept"`

	// Features holds the feature names (optional)
	Features []string `json:"features,omitempty"`

	// Hyperparameters holds the model's hyperparameters
	Hyperparameters map[string]interface{} `json:"hyperparameters"`

	// Metadata holds additional metadata (e.g., training statistics)
	Metadata map[string]interface{} `json:"metadata,omitempty"`

	// IsFitted reports whether the model has been fitted
	IsFitted bool `json:"is_fitted"`
}

ModelWeights represents a model's weights for serialization.

func (*ModelWeights) Clone added in v0.3.0

func (mw *ModelWeights) Clone() *ModelWeights

Clone creates a deep copy of the ModelWeights.

func (*ModelWeights) FromJSON added in v0.3.0

func (mw *ModelWeights) FromJSON(data []byte) error

FromJSON deserializes ModelWeights from JSON.

func (*ModelWeights) ToJSON added in v0.3.0

func (mw *ModelWeights) ToJSON() ([]byte, error)

ToJSON serializes ModelWeights to JSON.

func (*ModelWeights) Validate added in v0.3.0

func (mw *ModelWeights) Validate() error

Validate checks whether the ModelWeights are valid.
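ToJSON, FromJSON, Clone, and Validate can be approximated with encoding/json on a struct carrying the same tags. The sketch uses a local copy of the ModelWeights fields. Its deep-copy strategy (a JSON round trip), its validation rules (non-empty ModelType, finite coefficients and intercept), and the error return on Clone (the documented Clone returns none) are all assumptions, not the library's behavior.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"math"
)

// modelWeights mirrors the fields and JSON tags of ModelWeights.
type modelWeights struct {
	ModelType       string                 `json:"model_type"`
	Version         string                 `json:"version"`
	Coefficients    []float64              `json:"coefficients"`
	Intercept       float64                `json:"intercept"`
	Features        []string               `json:"features,omitempty"`
	Hyperparameters map[string]interface{} `json:"hyperparameters"`
	Metadata        map[string]interface{} `json:"metadata,omitempty"`
	IsFitted        bool                   `json:"is_fitted"`
}

func (mw *modelWeights) ToJSON() ([]byte, error) { return json.Marshal(mw) }

// FromJSON resets mw first so stale map entries do not survive the decode.
func (mw *modelWeights) FromJSON(data []byte) error {
	*mw = modelWeights{}
	return json.Unmarshal(data, mw)
}

// Validate applies illustrative checks; the real rules may differ.
func (mw *modelWeights) Validate() error {
	if mw.ModelType == "" {
		return errors.New("model_type is empty")
	}
	if math.IsNaN(mw.Intercept) || math.IsInf(mw.Intercept, 0) {
		return errors.New("intercept is not finite")
	}
	for i, c := range mw.Coefficients {
		if math.IsNaN(c) || math.IsInf(c, 0) {
			return fmt.Errorf("coefficient %d is not finite", i)
		}
	}
	return nil
}

// Clone deep-copies via a JSON round trip. Numbers inside the interface{}
// maps come back as float64, because encoding/json decodes them that way.
func (mw *modelWeights) Clone() (*modelWeights, error) {
	data, err := mw.ToJSON()
	if err != nil {
		return nil, err
	}
	out := &modelWeights{}
	if err := out.FromJSON(data); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	w := &modelWeights{
		ModelType:       "LinearRegression",
		Coefficients:    []float64{0.5, -1.5},
		Intercept:       2,
		Hyperparameters: map[string]interface{}{"fit_intercept": true},
		IsFitted:        true,
	}
	if err := w.Validate(); err != nil {
		panic(err)
	}
	c, err := w.Clone()
	if err != nil {
		panic(err)
	}
	c.Coefficients[0] = 99 // does not affect w
	data, _ := w.ToJSON()
	fmt.Println(string(data))
}
```

Calling Validate before ToJSON is worthwhile in this design: encoding/json refuses to marshal NaN or infinite floats, so an unchecked bad weight would otherwise surface as a less specific encoding error.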

type OnlineEstimator added in v0.4.0

type OnlineEstimator interface {
	Estimator
	IncrementalLearner
}

OnlineEstimator is the interface for models that support online learning. It complements the StreamingEstimator interface defined in streaming.go.

type OnlineMetrics

type OnlineMetrics interface {
	// GetLoss returns the current loss value
	GetLoss() float64

	// GetLossHistory returns the history of loss values
	GetLossHistory() []float64

	// GetConverged returns whether the model has converged
	GetConverged() bool
}

OnlineMetrics is an interface for tracking metrics during online learning

type ParallelStreaming

type ParallelStreaming interface {
	// SetWorkers sets the number of workers
	SetWorkers(n int)

	// GetWorkers returns current number of workers
	GetWorkers() int

	// SetBatchParallelism enables/disables intra-batch parallelism
	SetBatchParallelism(enabled bool)
}

ParallelStreaming is an interface for parallel streaming processing

type ParameterGetter added in v0.4.0

type ParameterGetter interface {
	// GetParams returns the model's hyperparameters.
	GetParams() map[string]interface{}
}

ParameterGetter is the interface for models that expose their parameters.

type ParameterSetter added in v0.4.0

type ParameterSetter interface {
	// SetParams sets the model's hyperparameters.
	SetParams(params map[string]interface{}) error
}

ParameterSetter is the interface for models that allow parameter modification.

type PartialFitMixin added in v0.3.0

type PartialFitMixin interface {
	// PartialFit trains the model incrementally on a mini-batch
	PartialFit(X, y mat.Matrix, classes []int) error

	// NIterations returns the number of training iterations
	NIterations() int

	// IsWarmStart reports whether warm start is enabled
	IsWarmStart() bool

	// SetWarmStart enables or disables warm start
	SetWarmStart(warmStart bool)
}

PartialFitMixin is the interface for models that support incremental learning.

type Persistable added in v0.4.0

type Persistable interface {
	// Save saves the model to a file.
	Save(path string) error

	// Load loads the model from a file.
	Load(path string) error
}

Persistable is the interface for models that can be saved and loaded.

type PipelineCompatible added in v0.3.0

type PipelineCompatible interface {
	SKLearnCompatible

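	// GetInputDim returns the number of input dimensions
	GetInputDim() int

	// GetOutputDim returns the number of output dimensions
	GetOutputDim() int

	// RequiresFit reports whether the model must be fitted before use
	RequiresFit() bool
}

PipelineCompatible is the interface for models that can be used in a pipeline.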

type Predictor

type Predictor interface {
	// Predict performs predictions on input data
	Predict(X mat.Matrix) (mat.Matrix, error)
}

Predictor is an interface for predictive models

type Regressor added in v0.4.0

type Regressor interface {
	Estimator
	Predictor
	Scorer
}

Regressor combines interfaces for regression models.

type RegressorMixin added in v0.3.0

type RegressorMixin interface {
	Estimator

	// Score computes the coefficient of determination R²
	Score(X, y mat.Matrix) (float64, error)
}

RegressorMixin is the mixin interface for regressors.

type RegressorWithPartialFit added in v0.4.0

type RegressorWithPartialFit interface {
	Regressor
	IncrementalLearner
}

RegressorWithPartialFit combines interfaces for online regression models.

type SKLearnCompatible added in v0.3.0

type SKLearnCompatible interface {
	// GetParams returns the model's hyperparameters
	GetParams(deep bool) map[string]interface{}

	// SetParams sets the model's hyperparameters
	SetParams(params map[string]interface{}) error

	// Clone creates a new instance of the model with the same parameters
	Clone() SKLearnCompatible
}

SKLearnCompatible is the scikit-learn-compatible interface.

type SKLearnLinearRegressionParams

type SKLearnLinearRegressionParams struct {
	Coefficients []float64 `json:"coefficients"` // Coefficients (weights)
	Intercept    float64   `json:"intercept"`    // Intercept
	NFeatures    int       `json:"n_features"`   // Number of features
}

SKLearnLinearRegressionParams holds the parameters of a linear regression model.

func LoadLinearRegressionParams

func LoadLinearRegressionParams(model *SKLearnModel) (*SKLearnLinearRegressionParams, error)

LoadLinearRegressionParams parses the LinearRegression parameters of an SKLearnModel.

Parameters:

  • model: The SKLearnModel instance

Returns:

  • *SKLearnLinearRegressionParams: The parsed parameters
  • error: Error if parsing fails

type SKLearnModel

type SKLearnModel struct {
	ModelSpec SKLearnModelSpec `json:"model_spec"`
	Params    json.RawMessage  `json:"params"`
}

SKLearnModel is a model exported from scikit-learn.

func LoadSKLearnModelFromFile

func LoadSKLearnModelFromFile(filename string) (*SKLearnModel, error)

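LoadSKLearnModelFromFile loads a scikit-learn model from a JSON file.

Parameters:

  • filename: Path to the JSON file

Returns:

  • *SKLearnModel: The loaded model
  • error: Error if loading fails

Example:

skModel, err := model.LoadSKLearnModelFromFile("sklearn_model.json")
if err != nil {
    log.Fatal(err)
}
params, err := model.LoadLinearRegressionParams(skModel)
if err != nil {
    log.Fatal(err)
}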

func LoadSKLearnModelFromReader

func LoadSKLearnModelFromReader(r io.Reader) (*SKLearnModel, error)

LoadSKLearnModelFromReader loads a scikit-learn model from an io.Reader.

Parameters:

  • r: A Reader containing the JSON data

Returns:

  • *SKLearnModel: The loaded model
  • error: Error if loading fails

type SKLearnModelSpec

type SKLearnModelSpec struct {
	Name           string `json:"name"`                      // Model name (e.g., "LinearRegression")
	FormatVersion  string `json:"format_version"`            // Format version
	SKLearnVersion string `json:"sklearn_version,omitempty"` // scikit-learn version
}

SKLearnModelSpec holds the metadata of a scikit-learn model.
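Reading an exported scikit-learn model is naturally a two-stage decode. First the envelope is decoded, with `params` kept as json.RawMessage, and then the model-specific parameters are decoded once `model_spec.name` identifies the target struct. This hedged sketch uses local copies of the three types. The `sample` document is illustrative input, not a real export. The `n_features` consistency check is an assumption about what LoadLinearRegressionParams might verify, and the function names are hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

type skSpec struct {
	Name           string `json:"name"`
	FormatVersion  string `json:"format_version"`
	SKLearnVersion string `json:"sklearn_version,omitempty"`
}

type skModel struct {
	ModelSpec skSpec          `json:"model_spec"`
	Params    json.RawMessage `json:"params"`
}

type skLinRegParams struct {
	Coefficients []float64 `json:"coefficients"`
	Intercept    float64   `json:"intercept"`
	NFeatures    int       `json:"n_features"`
}

// readSKModel decodes only the envelope; params stay as raw JSON.
func readSKModel(r io.Reader) (*skModel, error) {
	var m skModel
	if err := json.NewDecoder(r).Decode(&m); err != nil {
		return nil, fmt.Errorf("decode model: %w", err)
	}
	return &m, nil
}

// linRegParams checks the model name, then decodes the raw params.
func linRegParams(m *skModel) (*skLinRegParams, error) {
	if m.ModelSpec.Name != "LinearRegression" {
		return nil, fmt.Errorf("expected LinearRegression, got %q", m.ModelSpec.Name)
	}
	var p skLinRegParams
	if err := json.Unmarshal(m.Params, &p); err != nil {
		return nil, fmt.Errorf("decode params: %w", err)
	}
	if p.NFeatures != len(p.Coefficients) {
		return nil, fmt.Errorf("n_features=%d but %d coefficients", p.NFeatures, len(p.Coefficients))
	}
	return &p, nil
}

// parseLinReg runs both stages on a JSON string.
func parseLinReg(doc string) (*skLinRegParams, error) {
	m, err := readSKModel(strings.NewReader(doc))
	if err != nil {
		return nil, err
	}
	return linRegParams(m)
}

// sample is illustrative input, not output from a real export tool.
const sample = `{
  "model_spec": {"name": "LinearRegression", "format_version": "1.0"},
  "params": {"coefficients": [0.5, -1.5], "intercept": 2.0, "n_features": 2}
}`

func main() {
	p, err := parseLinReg(sample)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", *p)
}
```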

type Scorer added in v0.4.0

type Scorer interface {
	// Score returns the coefficient of determination R^2 of the prediction.
	Score(X mat.Matrix, y mat.Matrix) (float64, error)
}

Scorer is the interface for models that can compute a score.

type StateManager added in v0.4.0

type StateManager struct {
	Fitted bool // Public for gob encoding

	// Optional metadata - Public for gob encoding
	NFeatures int
	NSamples  int
	// contains filtered or unexported fields
}

StateManager manages the fitted state of a model in a thread-safe manner. It replaces the BaseEstimator embedding pattern with composition.

func NewStateManager added in v0.4.0

func NewStateManager() *StateManager

NewStateManager creates a new StateManager instance.

func (*StateManager) GetDimensions added in v0.4.0

func (s *StateManager) GetDimensions() (nFeatures, nSamples int)

GetDimensions returns the number of features and samples seen during fitting.

func (*StateManager) GetState added in v0.4.0

func (s *StateManager) GetState() ModelState

GetState returns the current state as a ModelState struct.

func (*StateManager) IsFitted added in v0.4.0

func (s *StateManager) IsFitted() bool

IsFitted returns whether the model has been fitted.

func (*StateManager) RequireFitted added in v0.4.0

func (s *StateManager) RequireFitted() error

RequireFitted returns an error if the model has not been fitted.

func (*StateManager) Reset added in v0.4.0

func (s *StateManager) Reset()

Reset resets the fitted state.

func (*StateManager) SetDimensions added in v0.4.0

func (s *StateManager) SetDimensions(nFeatures, nSamples int)

SetDimensions sets the number of features and samples seen during fitting.

func (*StateManager) SetFitted added in v0.4.0

func (s *StateManager) SetFitted()

SetFitted marks the model as fitted.

func (*StateManager) SetState added in v0.4.0

func (s *StateManager) SetState(state ModelState)

SetState sets the state from a ModelState struct.

func (*StateManager) WithState added in v0.4.0

func (s *StateManager) WithState(fn func() error) error

WithState is a helper function that executes a function with the state locked for reading.

func (*StateManager) WithStateMut added in v0.4.0

func (s *StateManager) WithStateMut(fn func() error) error

WithStateMut is a helper function that executes a function with the state locked for writing.
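StateManager's documented behavior (a fitted flag and dimensions guarded for concurrent access, RequireFitted, and the WithState/WithStateMut helpers) maps naturally onto sync.RWMutex. A hedged sketch of that shape, assuming an RWMutex is the unexported field. The `stateManager` type and `errNotFitted` are hypothetical, and unlike the real StateManager, which exports Fitted, NFeatures, and NSamples for gob, the sketch keeps its fields unexported for brevity.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// errNotFitted is a hypothetical sentinel error.
var errNotFitted = errors.New("model is not fitted")

type stateManager struct {
	mu        sync.RWMutex
	fitted    bool
	nFeatures int
	nSamples  int
}

func (s *stateManager) SetFitted() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.fitted = true
}

func (s *stateManager) IsFitted() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.fitted
}

func (s *stateManager) RequireFitted() error {
	if !s.IsFitted() {
		return errNotFitted
	}
	return nil
}

func (s *stateManager) SetDimensions(nFeatures, nSamples int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.nFeatures, s.nSamples = nFeatures, nSamples
}

// WithState runs fn under the read lock. fn must read fields directly and
// must not call methods that take the lock again, or it can deadlock.
func (s *stateManager) WithState(fn func() error) error {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return fn()
}

// WithStateMut runs fn under the write lock, with the same caveat.
func (s *stateManager) WithStateMut(fn func() error) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	return fn()
}

func main() {
	var s stateManager
	fmt.Println(s.RequireFitted()) // not fitted yet

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = s.IsFitted() // concurrent readers are safe
		}()
	}
	_ = s.WithStateMut(func() error {
		s.fitted = true // direct field access while holding the write lock
		s.nFeatures, s.nSamples = 3, 100
		return nil
	})
	wg.Wait()
	fmt.Println(s.RequireFitted()) // prints <nil>
}
```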

type StreamingEstimator

type StreamingEstimator interface {
	IncrementalEstimator

	// FitStream trains the model from a data stream
	// Continues learning until the context is canceled or the channel is closed
	FitStream(ctx context.Context, dataChan <-chan *Batch) error

	// PredictStream performs real-time predictions on input stream
	// Output channel is closed when input channel is closed
	PredictStream(ctx context.Context, inputChan <-chan mat.Matrix) <-chan mat.Matrix

	// FitPredictStream performs learning and prediction simultaneously
	// Returns predictions while training on new data (test-then-train approach)
	FitPredictStream(ctx context.Context, dataChan <-chan *Batch) <-chan mat.Matrix
}

StreamingEstimator provides a channel-based streaming learning interface.
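FitStream is documented to keep learning until the context is canceled or the channel is closed. A minimal sketch of that loop follows. It uses a hypothetical `batch` type with plain slices in place of Batch's mat.Matrix fields and a hypothetical `learner` interface standing in for PartialFit. Returning ctx.Err() on cancellation and stopping at the first PartialFit error are assumptions.

```go
package main

import (
	"context"
	"fmt"
)

// batch is a simplified stand-in for model.Batch.
type batch struct {
	X [][]float64
	Y []float64
}

// learner is a hypothetical stand-in for an incremental model's PartialFit.
type learner interface {
	PartialFit(X [][]float64, y []float64) error
}

// fitStream consumes batches until data is closed (returns nil) or ctx is
// canceled (returns ctx.Err()). A PartialFit error stops the stream.
func fitStream(ctx context.Context, m learner, data <-chan *batch) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case b, ok := <-data:
			if !ok {
				return nil // producer closed the channel: normal end of stream
			}
			if err := m.PartialFit(b.X, b.Y); err != nil {
				return fmt.Errorf("partial fit: %w", err)
			}
		}
	}
}

// counter is a toy learner that counts the samples it has seen.
type counter struct{ seen int }

func (c *counter) PartialFit(_ [][]float64, y []float64) error {
	c.seen += len(y)
	return nil
}

// run streams the given target batches through a counter and returns its total.
func run(targets [][]float64) (int, error) {
	data := make(chan *batch)
	go func() {
		defer close(data)
		for _, y := range targets {
			data <- &batch{X: make([][]float64, len(y)), Y: y}
		}
	}()
	c := &counter{}
	err := fitStream(context.Background(), c, data)
	return c.seen, err
}

// runCanceled shows that a canceled context ends the stream with ctx.Err().
func runCanceled() error {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	never := make(chan *batch) // nothing is ever sent
	return fitStream(ctx, &counter{}, never)
}

func main() {
	n, err := run([][]float64{{1, 2}, {3}, {4, 5, 6}})
	fmt.Println(n, err) // prints 6 <nil>
	fmt.Println(runCanceled())
}
```

Selecting on both ctx.Done() and the data channel in one statement is what lets a blocked consumer notice cancellation even when no batch ever arrives.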

type StreamingMetrics

type StreamingMetrics interface {
	OnlineMetrics

	// GetThroughput returns current throughput (samples/second)
	GetThroughput() float64

	// GetProcessedSamples returns total number of processed samples
	GetProcessedSamples() int64

	// GetAverageLatency returns average latency in milliseconds
	GetAverageLatency() float64

	// GetMemoryUsage returns current memory usage in bytes
	GetMemoryUsage() int64
}

StreamingMetrics provides metrics during streaming learning

type Transformer

type Transformer interface {
	// Fit learns parameters necessary for transformation
	Fit(X mat.Matrix) error

	// Transform transforms data
	Transform(X mat.Matrix) (mat.Matrix, error)

	// FitTransform executes Fit and Transform simultaneously
	FitTransform(X mat.Matrix) (mat.Matrix, error)
}

Transformer is an interface for data transformation
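The Fit/Transform/FitTransform split (together with InverseTransform from InverseTransformerMixin) is the usual scaler pattern: Fit learns statistics and Transform applies them. A hedged sketch of per-column standardization z = (x − μ)/σ on row-major [][]float64 instead of mat.Matrix follows. The `standardScaler` type is hypothetical. It uses the population standard deviation and leaves zero-variance columns centered but unscaled, choices borrowed from scikit-learn's StandardScaler.

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

type standardScaler struct {
	mean, scale []float64
}

// Fit learns each column's mean and population standard deviation.
func (s *standardScaler) Fit(X [][]float64) error {
	if len(X) == 0 || len(X[0]) == 0 {
		return errors.New("empty input")
	}
	d := len(X[0])
	for _, row := range X {
		if len(row) != d {
			return errors.New("ragged input: rows differ in length")
		}
	}
	n := float64(len(X))
	mean, scale := make([]float64, d), make([]float64, d)
	for _, row := range X {
		for j, v := range row {
			mean[j] += v
		}
	}
	for j := range mean {
		mean[j] /= n
	}
	for _, row := range X {
		for j, v := range row {
			scale[j] += (v - mean[j]) * (v - mean[j])
		}
	}
	for j := range scale {
		scale[j] = math.Sqrt(scale[j] / n)
		if scale[j] == 0 {
			scale[j] = 1 // constant column: center only
		}
	}
	s.mean, s.scale = mean, scale
	return nil
}

// Transform applies z = (x - mean) / scale column by column.
func (s *standardScaler) Transform(X [][]float64) ([][]float64, error) {
	return s.apply(X, func(v float64, j int) float64 { return (v - s.mean[j]) / s.scale[j] })
}

// InverseTransform maps standardized values back with x = z*scale + mean.
func (s *standardScaler) InverseTransform(Z [][]float64) ([][]float64, error) {
	return s.apply(Z, func(v float64, j int) float64 { return v*s.scale[j] + s.mean[j] })
}

// FitTransform is Fit followed by Transform on the same data.
func (s *standardScaler) FitTransform(X [][]float64) ([][]float64, error) {
	if err := s.Fit(X); err != nil {
		return nil, err
	}
	return s.Transform(X)
}

// apply maps f over every element after checking the scaler is fitted.
func (s *standardScaler) apply(X [][]float64, f func(v float64, j int) float64) ([][]float64, error) {
	if s.mean == nil {
		return nil, errors.New("scaler is not fitted")
	}
	out := make([][]float64, len(X))
	for i, row := range X {
		if len(row) != len(s.mean) {
			return nil, fmt.Errorf("row %d has %d columns, want %d", i, len(row), len(s.mean))
		}
		out[i] = make([]float64, len(row))
		for j, v := range row {
			out[i][j] = f(v, j)
		}
	}
	return out, nil
}

func main() {
	var s standardScaler
	Z, err := s.FitTransform([][]float64{{1, 10}, {2, 10}, {3, 10}})
	if err != nil {
		panic(err)
	}
	fmt.Println(Z) // the constant second column is centered to 0
}
```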

type TransformerMixin added in v0.3.0

type TransformerMixin interface {
	// Transform transforms the data
	Transform(X mat.Matrix) (mat.Matrix, error)

	// FitTransform fits the transformer and transforms the data in one step
	FitTransform(X mat.Matrix) (mat.Matrix, error)
}

TransformerMixin is the mixin interface for transformers.

type WeightExporter added in v0.3.0

type WeightExporter interface {
	// ExportWeights exports the model's weights
	ExportWeights() (*ModelWeights, error)

	// ImportWeights imports weights into the model
	ImportWeights(weights *ModelWeights) error

	// GetWeightHash computes a hash of the weights (for verification)
	GetWeightHash() string
}

WeightExporter is the interface for models whose weights can be exported.
