Documentation
Overview ¶
Package model provides core abstractions and interfaces for machine learning models.
This package defines the fundamental building blocks for machine learning estimators in the SciGo library, including:
- BaseEstimator: Core estimator with state management and serialization support
- Model persistence: Save and load trained models using Go's encoding/gob
- scikit-learn compatibility: Import/export models from Python scikit-learn
- Streaming interfaces: Support for online learning and incremental training
The BaseEstimator provides a consistent foundation for all ML algorithms with:
- Fitted state tracking to prevent usage of untrained models
- Serialization support for model persistence
- Thread-safe state management
- Integration with preprocessing pipelines
Example usage:
import (
"github.com/YuminosukeSato/scigo/core/model"
"gonum.org/v1/gonum/mat"
)
type MyModel struct {
model.BaseEstimator
// model-specific fields
}
func (m *MyModel) Fit(X, y mat.Matrix) error {
// training logic
m.SetFitted() // mark as trained
return nil
}
All models in SciGo embed BaseEstimator to ensure consistent behavior across the entire machine learning pipeline.
The package also provides additional interfaces and types for machine learning models, which complement the interfaces defined in estimator.go and transformer.go. It also provides thread-safe state management for models (see StateManager).
Index ¶
- func ExportSKLearnModel(modelName string, params interface{}, w io.Writer) error
- func LoadModel(model interface{}, filename string) error
- func LoadModelFromReader(model interface{}, r io.Reader) error
- func SaveModel(model interface{}, filename string) error
- func SaveModelToWriter(model interface{}, w io.Writer) error
- type AdaptiveLearning
- type BaseEstimator
- func (e *BaseEstimator) Clone() *BaseEstimator
- func (e *BaseEstimator) ExportWeights() (*ModelWeights, error)
- func (e *BaseEstimator) GetLogger() interface{}
- func (e *BaseEstimator) GetParams(deep bool) map[string]interface{}
- func (e *BaseEstimator) GetWeightHash() string
- func (e *BaseEstimator) ImportWeights(weights *ModelWeights) error
- func (e *BaseEstimator) IsFitted() bool
- func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})
- func (e *BaseEstimator) LogError(msg string, fields ...interface{})
- func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})
- func (e *BaseEstimator) Reset()
- func (e *BaseEstimator) SetFitted()
- func (e *BaseEstimator) SetLogger(logger interface{})
- func (e *BaseEstimator) SetParams(params map[string]interface{}) error
- type Batch
- type BufferedStreaming
- type Classifier
- type ClassifierMixin
- type ClassifierWithPartialFit
- type ClusterMixin
- type CrossValidatable
- type Estimator
- type EstimatorState
- type Fitter
- type IncrementalEstimator
- type IncrementalLearner
- type InverseTransformerMixin
- type LinearModel
- type ModelState
- type ModelValidation
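- type ModelWeights
- func (mw *ModelWeights) Clone() *ModelWeights
- func (mw *ModelWeights) FromJSON(data []byte) error
- func (mw *ModelWeights) ToJSON() ([]byte, error)
- func (mw *ModelWeights) Validate() error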
- type OnlineEstimator
- type OnlineMetrics
- type ParallelStreaming
- type ParameterGetter
- type ParameterSetter
- type PartialFitMixin
- type Persistable
- type PipelineCompatible
- type Predictor
- type Regressor
- type RegressorMixin
- type RegressorWithPartialFit
- type SKLearnCompatible
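- type SKLearnLinearRegressionParams
- func LoadLinearRegressionParams(model *SKLearnModel) (*SKLearnLinearRegressionParams, error)
- type SKLearnModel
- func LoadSKLearnModelFromFile(filename string) (*SKLearnModel, error)
- func LoadSKLearnModelFromReader(r io.Reader) (*SKLearnModel, error)
- type SKLearnModelSpec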
- type Scorer
- type StateManager
- func NewStateManager() *StateManager
- func (s *StateManager) GetDimensions() (nFeatures, nSamples int)
- func (s *StateManager) GetState() ModelState
- func (s *StateManager) IsFitted() bool
- func (s *StateManager) RequireFitted() error
- func (s *StateManager) Reset()
- func (s *StateManager) SetDimensions(nFeatures, nSamples int)
- func (s *StateManager) SetFitted()
- func (s *StateManager) SetState(state ModelState)
- func (s *StateManager) WithState(fn func() error) error
- func (s *StateManager) WithStateMut(fn func() error) error
- type StreamingEstimator
- type StreamingMetrics
- type Transformer
- type TransformerMixin
- type WeightExporter
Examples ¶
- BaseEstimator
- BaseEstimator (WorkflowPattern)
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func ExportSKLearnModel ¶
ExportSKLearnModel exports a model in scikit-learn-compatible JSON format.
Parameters:
- modelName: The model name
- params: The model parameters
- w: The destination Writer
Returns:
- error: Error if the export fails
func LoadModel ¶
LoadModel loads a model from a file
Parameters:
- model: The target model (pointer to struct with embedded BaseEstimator)
- filename: The file path to load from
Returns:
- error: Error if loading fails
Example:
var reg linear.Regression
err := model.LoadModel(&reg, "model.gob")
func LoadModelFromReader ¶
LoadModelFromReader loads a model from an io.Reader
Parameters:
- model: The target model (pointer)
- r: The source Reader
Returns:
- error: Error if loading fails
func SaveModel ¶
SaveModel saves a model to a file
Parameters:
- model: The model to save (struct with embedded BaseEstimator)
- filename: The file path to save to
Returns:
- error: Error if saving fails
Example:
var reg linear.Regression
// ... train the model ...
err := model.SaveModel(&reg, "model.gob")
func SaveModelToWriter ¶
SaveModelToWriter saves a model to an io.Writer
Parameters:
- model: The model to save
- w: The target Writer
Returns:
- error: Error if saving fails
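SaveModel, LoadModel, and their io.Writer/io.Reader variants are described as using encoding/gob. A round trip through an in-memory buffer shows what this kind of persistence keeps. This is a minimal standalone sketch and does not import scigo. The names toyModel and roundTrip are hypothetical, and the struct only imitates BaseEstimator's exported State, ModelType, and Version fields. gob encodes exported fields only. Unless a type implements its own GobEncoder, unexported state (such as the fields BaseEstimator hides) is not written.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"log"
)

// EstimatorState mirrors the documented type: NotFitted = 0, Fitted = 1.
type EstimatorState int

const (
	NotFitted EstimatorState = iota
	Fitted
)

// toyModel is hypothetical: its exported fields imitate BaseEstimator's
// public fields plus a slice of learned coefficients.
type toyModel struct {
	State     EstimatorState
	ModelType string
	Version   string
	Coef      []float64
}

// roundTrip gob-encodes src into a buffer and decodes it into a new value.
func roundTrip(src *toyModel) (*toyModel, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(src); err != nil {
		return nil, err
	}
	dst := &toyModel{}
	if err := gob.NewDecoder(&buf).Decode(dst); err != nil {
		return nil, err
	}
	return dst, nil
}

func main() {
	m := &toyModel{State: Fitted, ModelType: "LinearRegression", Version: "1.0", Coef: []float64{0.5, -1.25}}
	got, err := roundTrip(m)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(got.State == Fitted, got.ModelType, got.Coef)
}
```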
Types ¶
type AdaptiveLearning ¶
type AdaptiveLearning interface {
// GetLearningRate returns the current learning rate
GetLearningRate() float64
// SetLearningRate sets the learning rate
SetLearningRate(lr float64)
// GetLearningRateSchedule returns the learning rate schedule
// e.g., "constant", "optimal", "invscaling", "adaptive"
GetLearningRateSchedule() string
// SetLearningRateSchedule sets the learning rate schedule
SetLearningRateSchedule(schedule string)
}
AdaptiveLearning is an interface for models that can dynamically adjust learning rates
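The schedule names listed above come from scikit-learn's SGD estimators. As a rough illustration, the sketch below implements two of them using scikit-learn's documented formulas:

- "constant" keeps η = η₀.
- "invscaling" uses η = η₀ / t^power_t.

The type scheduledLR, its powerT field, and the rateAt helper are hypothetical. Whether SciGo's models use exactly these formulas is an assumption.

```go
package main

import (
	"fmt"
	"math"
)

// scheduledLR is a hypothetical AdaptiveLearning implementation. Only the
// "constant" and "invscaling" formulas (as documented by scikit-learn) are
// modeled; any other schedule name falls back to the constant rate here.
type scheduledLR struct {
	eta0     float64 // base learning rate
	powerT   float64 // exponent used by "invscaling"
	schedule string
}

// GetLearningRate returns the base rate in this sketch; see rateAt for the
// effective rate at a given step.
func (s *scheduledLR) GetLearningRate() float64           { return s.eta0 }
func (s *scheduledLR) SetLearningRate(lr float64)         { s.eta0 = lr }
func (s *scheduledLR) GetLearningRateSchedule() string    { return s.schedule }
func (s *scheduledLR) SetLearningRateSchedule(sch string) { s.schedule = sch }

// rateAt returns the effective learning rate at update step t (t >= 1).
func (s *scheduledLR) rateAt(t int) float64 {
	switch s.schedule {
	case "invscaling":
		return s.eta0 / math.Pow(float64(t), s.powerT)
	default: // "constant" and, in this sketch, anything unrecognized
		return s.eta0
	}
}

func main() {
	lr := &scheduledLR{eta0: 0.1, powerT: 0.5, schedule: "invscaling"}
	for _, t := range []int{1, 4, 100} {
		fmt.Printf("t=%d eta=%.4f\n", t, lr.rateAt(t))
	}
}
```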
type BaseEstimator ¶
type BaseEstimator struct {
// State holds the model's learning state. Public for gob encoding.
State EstimatorState
// ModelType identifies the type of model
ModelType string
// Version is the model version
Version string
// contains filtered or unexported fields
}
BaseEstimator is the base structure for all models
Example ¶
ExampleBaseEstimator demonstrates BaseEstimator state management
package main
import (
"fmt"
"github.com/YuminosukeSato/scigo/core/model"
)
func main() {
// Create a BaseEstimator (typically embedded in actual models)
estimator := &model.BaseEstimator{}
// Check initial state
fmt.Printf("Initially fitted: %t\n", estimator.IsFitted())
// Mark as fitted
estimator.SetFitted()
fmt.Printf("After SetFitted: %t\n", estimator.IsFitted())
// Reset to unfitted state
estimator.Reset()
fmt.Printf("After Reset: %t\n", estimator.IsFitted())
}
Output:
Initially fitted: false
After SetFitted: true
After Reset: false
Example (WorkflowPattern) ¶
ExampleBaseEstimator_workflowPattern demonstrates typical usage pattern
package main
import (
"fmt"
"github.com/YuminosukeSato/scigo/core/model"
)
func main() {
// This example shows how BaseEstimator is typically used in models
type MyModel struct {
model.BaseEstimator
// model-specific fields would go here
}
myModel := &MyModel{}
// Check if model needs training
if !myModel.IsFitted() {
fmt.Println("Model needs training")
// Simulate training process
// ... training logic would go here ...
// Mark as fitted after successful training
myModel.SetFitted()
fmt.Println("Model trained successfully")
}
// Now model is ready for use
if myModel.IsFitted() {
fmt.Println("Model is ready for predictions")
}
}
Output:
Model needs training
Model trained successfully
Model is ready for predictions
func (*BaseEstimator) Clone ¶ added in v0.3.0
func (e *BaseEstimator) Clone() *BaseEstimator
Clone creates a new instance of the model (basic implementation)
func (*BaseEstimator) ExportWeights ¶ added in v0.3.0
func (e *BaseEstimator) ExportWeights() (*ModelWeights, error)
ExportWeights exports model weights (basic implementation)
func (*BaseEstimator) GetLogger ¶
func (e *BaseEstimator) GetLogger() interface{}
GetLogger returns the logger for this estimator. Returns nil if no logger has been set.
Returns:
- interface{}: The logger instance, should be type-asserted to log.Logger
Example:
if logger := model.GetLogger(); logger != nil {
if l, ok := logger.(log.Logger); ok {
l.Info("Operation completed")
}
}
func (*BaseEstimator) GetParams ¶ added in v0.3.0
func (e *BaseEstimator) GetParams(deep bool) map[string]interface{}
GetParams retrieves the model's hyperparameters (scikit-learn compatible)
func (*BaseEstimator) GetWeightHash ¶ added in v0.3.0
func (e *BaseEstimator) GetWeightHash() string
GetWeightHash calculates hash value of weights (for verification)
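GetWeightHash is only described as hashing the weights for verification. The hash function and byte layout are not documented here. As a hedged illustration, the sketch assumes SHA-256 over the IEEE-754 bit patterns of the coefficients, followed by the intercept. weightHash is a hypothetical helper, and its digests are not expected to match the library's.

```go
package main

import (
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"math"
)

// weightHash (hypothetical) hashes the exact bit patterns of each coefficient
// and then the intercept, so any change to a weight alters the digest.
func weightHash(coef []float64, intercept float64) string {
	h := sha256.New()
	var b [8]byte
	for _, c := range coef {
		binary.LittleEndian.PutUint64(b[:], math.Float64bits(c))
		h.Write(b[:])
	}
	binary.LittleEndian.PutUint64(b[:], math.Float64bits(intercept))
	h.Write(b[:])
	return hex.EncodeToString(h.Sum(nil))
}

func main() {
	a := weightHash([]float64{1.5, -2}, 0.25)
	b := weightHash([]float64{1.5, -2}, 0.25)
	c := weightHash([]float64{1.5, -2.0000001}, 0.25)
	fmt.Println(a == b, a == c, len(a))
}
```

Hashing raw bits, rather than formatted strings, means that two weight vectors that print identically at default precision still produce different digests.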
func (*BaseEstimator) ImportWeights ¶ added in v0.3.0
func (e *BaseEstimator) ImportWeights(weights *ModelWeights) error
ImportWeights imports model weights (basic implementation)
func (*BaseEstimator) IsFitted ¶
func (e *BaseEstimator) IsFitted() bool
IsFitted returns whether the model has been fitted with training data.
This method checks the internal state to determine if the estimator has been trained and is ready for prediction or transformation. All models must be fitted before they can be used for predictions.
Returns:
- bool: true if the model is fitted, false otherwise
Example:
if !model.IsFitted() {
err := model.Fit(X, y)
if err != nil {
log.Fatal(err)
}
}
predictions, err := model.Predict(X_test)
func (*BaseEstimator) LogDebug ¶
func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})
LogDebug logs a debug-level message if a logger is configured. This is a convenience method for debug logging in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs
func (*BaseEstimator) LogError ¶
func (e *BaseEstimator) LogError(msg string, fields ...interface{})
LogError logs an error-level message if a logger is configured. This is a convenience method for error logging in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs. If the first field is an error, it is handled specially.
func (*BaseEstimator) LogInfo ¶
func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})
LogInfo logs an info-level message if a logger is configured. This is a convenience method to avoid repetitive nil checks in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs
func (*BaseEstimator) Reset ¶
func (e *BaseEstimator) Reset()
Reset returns the estimator to its initial untrained state.
This method clears the fitted state, effectively making the model untrained. Useful for reusing model instances with different data or resetting after errors. After reset, the model must be fitted again before use.
Example:
model.Reset()
err := model.Fit(newTrainingData, newLabels)
if err != nil {
log.Fatal(err)
}
func (*BaseEstimator) SetFitted ¶
func (e *BaseEstimator) SetFitted()
SetFitted marks the estimator as fitted (trained).
This method is called internally by model implementations after successful training, to indicate that the model is ready for predictions or transformations. It should only be called by model implementations, not by end users.
Example (within a model's Fit method):
func (m *MyModel) Fit(X, y mat.Matrix) error {
// ... training logic ...
m.SetFitted() // Mark as trained
return nil
}
func (*BaseEstimator) SetLogger ¶
func (e *BaseEstimator) SetLogger(logger interface{})
SetLogger sets the logger for this estimator. This method is typically called during model initialization to provide structured logging capabilities for ML operations.
Parameters:
- logger: Any logger implementation (typically log.Logger interface)
Example:
import "github.com/YuminosukeSato/scigo/pkg/log"
model.SetLogger(log.GetLoggerWithName("LinearRegression"))
func (*BaseEstimator) SetParams ¶ added in v0.3.0
func (e *BaseEstimator) SetParams(params map[string]interface{}) error
SetParams sets the model's hyperparameters (scikit-learn compatible)
type BufferedStreaming ¶
type BufferedStreaming interface {
// SetBufferSize sets the size of streaming buffer
SetBufferSize(size int)
// GetBufferSize returns current buffer size
GetBufferSize() int
// FlushBuffer forces buffer flush
FlushBuffer() error
}
BufferedStreaming is a streaming interface with buffering capabilities
type Classifier ¶ added in v0.4.0
type Classifier interface {
Estimator
Predictor
Scorer
// PredictProba returns probability estimates for each class.
PredictProba(X mat.Matrix) (mat.Matrix, error)
// Classes returns the unique classes seen during fitting.
Classes() []int
}
Classifier combines interfaces for classification models.
type ClassifierMixin ¶ added in v0.3.0
type ClassifierMixin interface {
Estimator
// PredictProba predicts the probability of each class
PredictProba(X mat.Matrix) (mat.Matrix, error)
// PredictLogProba predicts the log-probability of each class
PredictLogProba(X mat.Matrix) (mat.Matrix, error)
// DecisionFunction computes the decision function values
DecisionFunction(X mat.Matrix) (mat.Matrix, error)
// Classes returns the learned class labels
Classes() []interface{}
// NClasses returns the number of learned classes
NClasses() int
}
ClassifierMixin is the mixin interface for classifiers.
type ClassifierWithPartialFit ¶ added in v0.4.0
type ClassifierWithPartialFit interface {
Classifier
IncrementalLearner
}
ClassifierWithPartialFit combines interfaces for online classification models.
type ClusterMixin ¶ added in v0.3.0
type ClusterMixin interface {
Fitter
// FitPredict fits the model and predicts cluster labels in one call
FitPredict(X mat.Matrix) ([]int, error)
// PredictCluster predicts the clusters of new data
PredictCluster(X mat.Matrix) ([]int, error)
// NClusters returns the number of clusters
NClusters() int
}
ClusterMixin is the mixin interface for clustering models.
type CrossValidatable ¶ added in v0.3.0
type CrossValidatable interface {
// GetCVSplits returns the number of cross-validation splits
GetCVSplits() int
// SetCVSplits sets the number of cross-validation splits
SetCVSplits(n int)
// GetCVScores returns the cross-validation scores
GetCVScores() []float64
}
CrossValidatable is the interface for models that support cross-validation.
type EstimatorState ¶
type EstimatorState int
EstimatorState represents the learning state of a model
const (
// NotFitted indicates the model is not yet trained
NotFitted EstimatorState = iota
// Fitted indicates the model has been trained
Fitted
)
type IncrementalEstimator ¶
type IncrementalEstimator interface {
Estimator
// PartialFit trains the model incrementally with mini-batches
// classes specifies all class labels for classification problems (required only on first call)
// Pass nil for regression problems
PartialFit(X, y mat.Matrix, classes []int) error
// NIterations returns the number of training iterations executed
NIterations() int
// IsWarmStart returns whether warm start is enabled
// If true, continues learning from existing parameters when Fit is called
IsWarmStart() bool
// SetWarmStart enables/disables warm start
SetWarmStart(warmStart bool)
}
IncrementalEstimator is an interface for models capable of online (incremental) learning. It is compatible with scikit-learn's partial_fit API.
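The PartialFit contract can be illustrated with a small model that updates its parameters on each mini-batch and counts the passes. The sketch below is an assumption-laden stand-in, not SciGo's SGD code:

- sgdRegressor is hypothetical.
- It uses [][]float64 and []float64 instead of mat.Matrix.
- It runs plain SGD on squared loss with a constant learning rate.
- It only stores the warm-start flag, because the interaction with Fit is not shown here.

```go
package main

import "fmt"

// sgdRegressor is a hypothetical, dependency-free incremental linear model.
type sgdRegressor struct {
	w         []float64 // weights, sized lazily on the first batch
	b         float64   // intercept
	eta       float64   // constant learning rate
	nIter     int       // number of PartialFit calls so far
	warmStart bool      // stored only; a real Fit would consult it
}

// PartialFit runs one SGD pass over a mini-batch. classes is ignored because
// this is a regressor (the interface says to pass nil for regression).
func (m *sgdRegressor) PartialFit(X [][]float64, y []float64, classes []int) error {
	if len(X) != len(y) {
		return fmt.Errorf("X has %d rows but y has %d", len(X), len(y))
	}
	for i, row := range X {
		if m.w == nil {
			m.w = make([]float64, len(row))
		}
		if len(row) != len(m.w) {
			return fmt.Errorf("row %d has %d features, want %d", i, len(row), len(m.w))
		}
		pred := m.b
		for j, v := range row {
			pred += m.w[j] * v
		}
		grad := pred - y[i] // d/dpred of 0.5*(pred-y)^2
		for j, v := range row {
			m.w[j] -= m.eta * grad * v
		}
		m.b -= m.eta * grad
	}
	m.nIter++
	return nil
}

func (m *sgdRegressor) NIterations() int     { return m.nIter }
func (m *sgdRegressor) IsWarmStart() bool    { return m.warmStart }
func (m *sgdRegressor) SetWarmStart(ws bool) { m.warmStart = ws }

func main() {
	m := &sgdRegressor{eta: 0.05}
	// Feed the same small batch repeatedly, as if it arrived from a stream.
	X := [][]float64{{0}, {1}, {2}, {3}}
	y := []float64{1, 3, 5, 7} // y = 2x + 1
	for i := 0; i < 2000; i++ {
		if err := m.PartialFit(X, y, nil); err != nil {
			panic(err)
		}
	}
	fmt.Printf("w=%.3f b=%.3f iterations=%d\n", m.w[0], m.b, m.NIterations())
}
```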
type IncrementalLearner ¶ added in v0.4.0
type IncrementalLearner interface {
// PartialFit performs one epoch of stochastic gradient descent on given samples.
PartialFit(X mat.Matrix, y mat.Matrix, classes []int) error
}
IncrementalLearner is the interface for models that support incremental learning.
type InverseTransformerMixin ¶ added in v0.3.0
type InverseTransformerMixin interface {
TransformerMixin
// InverseTransform applies the inverse of the transformation
InverseTransform(X mat.Matrix) (mat.Matrix, error)
}
InverseTransformerMixin is the interface for transformers whose transformation can be inverted.
type LinearModel ¶
type LinearModel interface {
// Weights returns the learned weights (coefficients)
Weights() []float64
// Intercept returns the learned intercept
Intercept() float64
// Score calculates the model's coefficient of determination (R²)
Score(X, y mat.Matrix) (float64, error)
}
LinearModel is an interface for linear models
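Score on LinearModel (and on the Scorer and RegressorMixin interfaces) reports the coefficient of determination, R² = 1 - SS_res / SS_tot. Here SS_res = Σ(yᵢ - ŷᵢ)² and SS_tot = Σ(yᵢ - ȳ)². The helper r2Score below is a hypothetical standalone sketch on plain slices. Returning an error when y is constant is this sketch's own choice: scikit-learn handles that case differently, and SciGo's behavior is not documented here.

```go
package main

import (
	"errors"
	"fmt"
)

// r2Score computes R² = 1 - SS_res/SS_tot for a set of predictions.
func r2Score(yTrue, yPred []float64) (float64, error) {
	if len(yTrue) == 0 || len(yTrue) != len(yPred) {
		return 0, errors.New("yTrue and yPred must be non-empty and of equal length")
	}
	mean := 0.0
	for _, v := range yTrue {
		mean += v
	}
	mean /= float64(len(yTrue))
	var ssRes, ssTot float64
	for i, v := range yTrue {
		ssRes += (v - yPred[i]) * (v - yPred[i]) // residual sum of squares
		ssTot += (v - mean) * (v - mean)         // total sum of squares
	}
	if ssTot == 0 {
		return 0, errors.New("R² is undefined when y is constant")
	}
	return 1 - ssRes/ssTot, nil
}

func main() {
	r2, err := r2Score([]float64{3, -0.5, 2, 7}, []float64{2.5, 0.0, 2, 8})
	if err != nil {
		panic(err)
	}
	fmt.Printf("R² = %.4f\n", r2) // SS_res = 1.5, SS_tot = 29.1875
}
```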
type ModelState ¶ added in v0.4.0
type ModelState struct {
Fitted bool `json:"fitted"`
NFeatures int `json:"n_features,omitempty"`
NSamples int `json:"n_samples,omitempty"`
Params map[string]interface{} `json:"params,omitempty"`
}
ModelState represents the complete state of a model. This can be used for serialization and debugging.
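Because ModelState's JSON tags use omitempty on everything except fitted, a freshly reset state serializes compactly. The sketch copies the documented struct into a standalone program (it does not import scigo) to show how the tags behave with encoding/json. mustJSON is a hypothetical helper.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// ModelState repeats the documented fields and tags.
type ModelState struct {
	Fitted    bool                   `json:"fitted"`
	NFeatures int                    `json:"n_features,omitempty"`
	NSamples  int                    `json:"n_samples,omitempty"`
	Params    map[string]interface{} `json:"params,omitempty"`
}

// mustJSON marshals v and panics on error (acceptable for a demo).
func mustJSON(v interface{}) string {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return string(b)
}

func main() {
	// Zero dimensions and nil Params are dropped; "fitted" is always present.
	fmt.Println(mustJSON(ModelState{})) // {"fitted":false}

	trained := ModelState{Fitted: true, NFeatures: 3, NSamples: 100,
		Params: map[string]interface{}{"alpha": 0.01}}
	data := mustJSON(trained)
	fmt.Println(data)

	var back ModelState
	if err := json.Unmarshal([]byte(data), &back); err != nil {
		panic(err)
	}
	// JSON numbers decode into interface{} values as float64.
	fmt.Println(back.NFeatures, back.Params["alpha"].(float64))
}
```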
type ModelValidation ¶ added in v0.3.0
type ModelValidation struct {
// ValidateInput validates the input data
ValidateInput func(X mat.Matrix) error
// ValidateOutput validates the output data
ValidateOutput func(y mat.Matrix) error
// ValidateWeights validates the model weights
ValidateWeights func(weights *ModelWeights) error
}
ModelValidation provides validation hooks for a model.
type ModelWeights ¶ added in v0.3.0
type ModelWeights struct {
// ModelType is the kind of model (LinearRegression, SGDRegressor, etc.)
ModelType string `json:"model_type"`
// Version is the model version (used for compatibility checks)
Version string `json:"version"`
// Coefficients are the weight coefficients
Coefficients []float64 `json:"coefficients"`
// Intercept is the intercept (bias) term
Intercept float64 `json:"intercept"`
// Features are the feature names (optional)
Features []string `json:"features,omitempty"`
// Hyperparameters are the model's hyperparameters
Hyperparameters map[string]interface{} `json:"hyperparameters"`
// Metadata holds additional metadata (e.g., training statistics)
Metadata map[string]interface{} `json:"metadata,omitempty"`
// IsFitted reports whether the model has been trained
IsFitted bool `json:"is_fitted"`
}
ModelWeights represents a model's weights for serialization.
func (*ModelWeights) Clone ¶ added in v0.3.0
func (mw *ModelWeights) Clone() *ModelWeights
Clone creates a deep copy of the ModelWeights.
func (*ModelWeights) FromJSON ¶ added in v0.3.0
func (mw *ModelWeights) FromJSON(data []byte) error
FromJSON deserializes ModelWeights from JSON.
func (*ModelWeights) ToJSON ¶ added in v0.3.0
func (mw *ModelWeights) ToJSON() ([]byte, error)
ToJSON serializes ModelWeights to JSON.
func (*ModelWeights) Validate ¶ added in v0.3.0
func (mw *ModelWeights) Validate() error
Validate checks that the ModelWeights are valid.
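The methods above are documented only briefly: Clone makes a deep copy, ToJSON/FromJSON handle JSON, and Validate checks validity. The sketch reproduces the documented fields and tags and shows one plausible standard-library approach to each step. The names cloneWeights, copyMap, and validateWeights are hypothetical. The validation rules (non-empty model type, finite coefficients and intercept) are assumptions, not the library's actual checks.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"math"
)

// ModelWeights repeats the documented fields and JSON tags.
type ModelWeights struct {
	ModelType       string                 `json:"model_type"`
	Version         string                 `json:"version"`
	Coefficients    []float64              `json:"coefficients"`
	Intercept       float64                `json:"intercept"`
	Features        []string               `json:"features,omitempty"`
	Hyperparameters map[string]interface{} `json:"hyperparameters"`
	Metadata        map[string]interface{} `json:"metadata,omitempty"`
	IsFitted        bool                   `json:"is_fitted"`
}

// cloneWeights copies slices and maps so the clone can be mutated freely.
// Map values are copied shallowly, which suffices for scalar hyperparameters.
func cloneWeights(mw *ModelWeights) *ModelWeights {
	c := *mw
	c.Coefficients = append([]float64(nil), mw.Coefficients...)
	c.Features = append([]string(nil), mw.Features...)
	c.Hyperparameters = copyMap(mw.Hyperparameters)
	c.Metadata = copyMap(mw.Metadata)
	return &c
}

func copyMap(m map[string]interface{}) map[string]interface{} {
	if m == nil {
		return nil
	}
	out := make(map[string]interface{}, len(m))
	for k, v := range m {
		out[k] = v
	}
	return out
}

// validateWeights applies assumed sanity checks.
func validateWeights(mw *ModelWeights) error {
	if mw.ModelType == "" {
		return errors.New("model_type is empty")
	}
	for i, c := range mw.Coefficients {
		if math.IsNaN(c) || math.IsInf(c, 0) {
			return fmt.Errorf("coefficient %d is not finite", i)
		}
	}
	if math.IsNaN(mw.Intercept) || math.IsInf(mw.Intercept, 0) {
		return errors.New("intercept is not finite")
	}
	return nil
}

func main() {
	w := &ModelWeights{ModelType: "LinearRegression", Version: "1.0",
		Coefficients: []float64{0.5, 2}, Intercept: 1, IsFitted: true,
		Hyperparameters: map[string]interface{}{"fit_intercept": true}}
	c := cloneWeights(w)
	c.Coefficients[0] = 99 // does not affect w

	data, err := json.Marshal(w)
	if err != nil {
		panic(err)
	}
	var back ModelWeights
	if err := json.Unmarshal(data, &back); err != nil {
		panic(err)
	}
	fmt.Println(w.Coefficients[0], back.Coefficients, validateWeights(&back))
}
```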
type OnlineEstimator ¶ added in v0.4.0
type OnlineEstimator interface {
Estimator
IncrementalLearner
}
OnlineEstimator is the interface for models that support online learning. This interface complements the existing StreamingEstimator in streaming.go.
type OnlineMetrics ¶
type OnlineMetrics interface {
// GetLoss returns the current loss value
GetLoss() float64
// GetLossHistory returns the history of loss values
GetLossHistory() []float64
// GetConverged returns whether the model has converged
GetConverged() bool
}
OnlineMetrics is an interface for tracking metrics during online learning
type ParallelStreaming ¶
type ParallelStreaming interface {
// SetWorkers sets the number of workers
SetWorkers(n int)
// GetWorkers returns current number of workers
GetWorkers() int
// SetBatchParallelism enables/disables intra-batch parallelism
SetBatchParallelism(enabled bool)
}
ParallelStreaming is an interface for parallel streaming processing
type ParameterGetter ¶ added in v0.4.0
type ParameterGetter interface {
// GetParams returns the model's hyperparameters.
GetParams() map[string]interface{}
}
ParameterGetter is the interface for models that expose their parameters.
type ParameterSetter ¶ added in v0.4.0
type ParameterSetter interface {
// SetParams sets the model's hyperparameters.
SetParams(params map[string]interface{}) error
}
ParameterSetter is the interface for models that allow parameter modification.
type PartialFitMixin ¶ added in v0.3.0
type PartialFitMixin interface {
// PartialFit trains the model incrementally on a mini-batch
PartialFit(X, y mat.Matrix, classes []int) error
// NIterations returns the number of training iterations
NIterations() int
// IsWarmStart reports whether warm start is enabled
IsWarmStart() bool
// SetWarmStart enables or disables warm start
SetWarmStart(warmStart bool)
}
PartialFitMixin is the interface for models that support incremental learning.
type Persistable ¶ added in v0.4.0
type Persistable interface {
// Save saves the model to a file.
Save(path string) error
// Load loads the model from a file.
Load(path string) error
}
Persistable is the interface for models that can be saved and loaded.
type PipelineCompatible ¶ added in v0.3.0
type PipelineCompatible interface {
SKLearnCompatible
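// GetInputDim returns the number of input dimensions
GetInputDim() int
// GetOutputDim returns the number of output dimensions
GetOutputDim() int
// RequiresFit reports whether the model must be fitted before use
RequiresFit() bool
}
PipelineCompatible is the interface for models that can be used in a pipeline.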
type Predictor ¶
type Predictor interface {
// Predict performs predictions on input data
Predict(X mat.Matrix) (mat.Matrix, error)
}
Predictor is an interface for predictive models
type RegressorMixin ¶ added in v0.3.0
type RegressorMixin interface {
Estimator
// Score computes the coefficient of determination R²
Score(X, y mat.Matrix) (float64, error)
}
RegressorMixin is the mixin interface for regressors.
type RegressorWithPartialFit ¶ added in v0.4.0
type RegressorWithPartialFit interface {
Regressor
IncrementalLearner
}
RegressorWithPartialFit combines interfaces for online regression models.
type SKLearnCompatible ¶ added in v0.3.0
type SKLearnCompatible interface {
// GetParams returns the model's hyperparameters
GetParams(deep bool) map[string]interface{}
// SetParams sets the model's hyperparameters
SetParams(params map[string]interface{}) error
// Clone creates a new model instance with the same parameters
Clone() SKLearnCompatible
}
SKLearnCompatible is the scikit-learn-compatible interface.
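With hyperparameters passed as map[string]interface{}, SetParams has to type-assert every value. The sketch shows one way to do that and keeps an update all-or-nothing. ridgeParams and its alpha/fit_intercept keys are hypothetical. Clone returns a concrete pointer instead of SKLearnCompatible so the program stands alone.

```go
package main

import "fmt"

// ridgeParams is a hypothetical estimator holding only hyperparameters.
type ridgeParams struct {
	Alpha        float64
	FitIntercept bool
}

// GetParams returns the hyperparameters; deep would matter only for nested
// estimators, so this flat model ignores it.
func (r *ridgeParams) GetParams(deep bool) map[string]interface{} {
	return map[string]interface{}{"alpha": r.Alpha, "fit_intercept": r.FitIntercept}
}

// SetParams applies changes to a copy first, so an invalid key or type
// leaves r untouched.
func (r *ridgeParams) SetParams(params map[string]interface{}) error {
	next := *r
	for k, v := range params {
		switch k {
		case "alpha":
			a, ok := v.(float64)
			if !ok {
				return fmt.Errorf("alpha: want float64, got %T", v)
			}
			next.Alpha = a
		case "fit_intercept":
			b, ok := v.(bool)
			if !ok {
				return fmt.Errorf("fit_intercept: want bool, got %T", v)
			}
			next.FitIntercept = b
		default:
			return fmt.Errorf("unknown parameter %q", k)
		}
	}
	*r = next
	return nil
}

// Clone returns a new instance with the same hyperparameters; since only
// hyperparameters are stored, a value copy suffices.
func (r *ridgeParams) Clone() *ridgeParams {
	c := *r
	return &c
}

func main() {
	r := &ridgeParams{Alpha: 1.0, FitIntercept: true}
	if err := r.SetParams(map[string]interface{}{"alpha": 0.1}); err != nil {
		panic(err)
	}
	fmt.Println(r.GetParams(false)["alpha"])
	// An untyped constant 1 is stored in the map as int, so it is rejected.
	fmt.Println(r.SetParams(map[string]interface{}{"alpha": 1}))
}
```

The strict assertion means callers must write 1.0 rather than 1 for float parameters. A looser design could also accept int and convert it.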
type SKLearnLinearRegressionParams ¶
type SKLearnLinearRegressionParams struct {
Coefficients []float64 `json:"coefficients"` // coefficients (weights)
Intercept float64 `json:"intercept"` // intercept
NFeatures int `json:"n_features"` // number of features
}
SKLearnLinearRegressionParams holds the parameters of a linear regression model.
func LoadLinearRegressionParams ¶
func LoadLinearRegressionParams(model *SKLearnModel) (*SKLearnLinearRegressionParams, error)
LoadLinearRegressionParams reads the parameters of a LinearRegression model.
Parameters:
- model: The SKLearnModel instance
Returns:
- *SKLearnLinearRegressionParams: The parameters
- error: Error if parsing fails
type SKLearnModel ¶
type SKLearnModel struct {
ModelSpec SKLearnModelSpec `json:"model_spec"`
Params json.RawMessage `json:"params"`
}
SKLearnModel is a model exported from scikit-learn.
func LoadSKLearnModelFromFile ¶
func LoadSKLearnModelFromFile(filename string) (*SKLearnModel, error)
LoadSKLearnModelFromFile loads a scikit-learn model from a file.
Parameters:
- filename: Path to the JSON file
Returns:
- *SKLearnModel: The loaded model
- error: Error if loading fails
Example:
skModel, err := model.LoadSKLearnModelFromFile("sklearn_model.json")
if err != nil {
log.Fatal(err)
}
func LoadSKLearnModelFromReader ¶
func LoadSKLearnModelFromReader(r io.Reader) (*SKLearnModel, error)
LoadSKLearnModelFromReader loads a scikit-learn model from a Reader.
Parameters:
- r: A Reader containing the JSON data
Returns:
- *SKLearnModel: The loaded model
- error: Error if loading fails
type SKLearnModelSpec ¶
type SKLearnModelSpec struct {
Name string `json:"name"` // model name (e.g., "LinearRegression")
FormatVersion string `json:"format_version"` // format version
SKLearnVersion string `json:"sklearn_version,omitempty"` // scikit-learn version
}
SKLearnModelSpec holds the metadata of a scikit-learn model.
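Because SKLearnModel stores Params as json.RawMessage, decoding naturally happens in two stages:

1. Read the envelope and model_spec.
2. Decode params into the type that matches the model name.

The sketch copies the three documented structs and infers the JSON layout from their tags. decodeLinearRegression is a hypothetical helper, and the "1.0" format version is an invented sample value. The name check and the coefficient-count check are assumptions about what such a loader might verify.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// The three types below repeat the documented structs and tags.
type SKLearnModelSpec struct {
	Name           string `json:"name"`
	FormatVersion  string `json:"format_version"`
	SKLearnVersion string `json:"sklearn_version,omitempty"`
}

type SKLearnModel struct {
	ModelSpec SKLearnModelSpec `json:"model_spec"`
	Params    json.RawMessage  `json:"params"`
}

type SKLearnLinearRegressionParams struct {
	Coefficients []float64 `json:"coefficients"`
	Intercept    float64   `json:"intercept"`
	NFeatures    int       `json:"n_features"`
}

// decodeLinearRegression parses the envelope, checks the model name, and
// only then decodes the raw params into the linear-regression struct.
func decodeLinearRegression(doc string) (*SKLearnLinearRegressionParams, error) {
	var m SKLearnModel
	if err := json.NewDecoder(strings.NewReader(doc)).Decode(&m); err != nil {
		return nil, err
	}
	if m.ModelSpec.Name != "LinearRegression" {
		return nil, fmt.Errorf("unexpected model %q", m.ModelSpec.Name)
	}
	var p SKLearnLinearRegressionParams
	if err := json.Unmarshal(m.Params, &p); err != nil {
		return nil, err
	}
	if len(p.Coefficients) != p.NFeatures {
		return nil, fmt.Errorf("%d coefficients for %d features", len(p.Coefficients), p.NFeatures)
	}
	return &p, nil
}

func main() {
	doc := `{"model_spec": {"name": "LinearRegression", "format_version": "1.0"},
	         "params": {"coefficients": [1.5, -0.5], "intercept": 0.25, "n_features": 2}}`
	p, err := decodeLinearRegression(doc)
	if err != nil {
		panic(err)
	}
	fmt.Println(p.Coefficients, p.Intercept, p.NFeatures)
}
```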
type Scorer ¶ added in v0.4.0
type Scorer interface {
// Score returns the coefficient of determination R^2 of the prediction.
Score(X mat.Matrix, y mat.Matrix) (float64, error)
}
Scorer is the interface for models that can compute a score.
type StateManager ¶ added in v0.4.0
type StateManager struct {
Fitted bool // Public for gob encoding
// Optional metadata - Public for gob encoding
NFeatures int
NSamples int
// contains filtered or unexported fields
}
StateManager manages the fitted state of a model in a thread-safe manner. It replaces the BaseEstimator embedding pattern with composition.
func NewStateManager ¶ added in v0.4.0
func NewStateManager() *StateManager
NewStateManager creates a new StateManager instance.
func (*StateManager) GetDimensions ¶ added in v0.4.0
func (s *StateManager) GetDimensions() (nFeatures, nSamples int)
GetDimensions returns the number of features and samples seen during fitting.
func (*StateManager) GetState ¶ added in v0.4.0
func (s *StateManager) GetState() ModelState
GetState returns the current state as a ModelState struct.
func (*StateManager) IsFitted ¶ added in v0.4.0
func (s *StateManager) IsFitted() bool
IsFitted returns whether the model has been fitted.
func (*StateManager) RequireFitted ¶ added in v0.4.0
func (s *StateManager) RequireFitted() error
RequireFitted returns an error if the model has not been fitted.
func (*StateManager) Reset ¶ added in v0.4.0
func (s *StateManager) Reset()
Reset resets the fitted state.
func (*StateManager) SetDimensions ¶ added in v0.4.0
func (s *StateManager) SetDimensions(nFeatures, nSamples int)
SetDimensions sets the number of features and samples seen during fitting.
func (*StateManager) SetFitted ¶ added in v0.4.0
func (s *StateManager) SetFitted()
SetFitted marks the model as fitted.
func (*StateManager) SetState ¶ added in v0.4.0
func (s *StateManager) SetState(state ModelState)
SetState sets the state from a ModelState struct.
func (*StateManager) WithState ¶ added in v0.4.0
func (s *StateManager) WithState(fn func() error) error
WithState is a helper function that executes a function with the state locked for reading.
func (*StateManager) WithStateMut ¶ added in v0.4.0
func (s *StateManager) WithStateMut(fn func() error) error
WithStateMut is a helper function that executes a function with the state locked for writing.
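StateManager is described as thread-safe, with read-locked and write-locked helpers. A sync.RWMutex is the usual way to build that in Go, but it is an assumption here; the unexported fields of the real type are not shown. In the sketch, stateManager and errNotFitted are hypothetical names, and only the write-locked helper is reproduced. Go's mutexes are not reentrant. If the real type uses a single lock like this one, a callback passed to WithStateMut must not call IsFitted or SetFitted, because that would deadlock.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// stateManager is a hypothetical re-creation guarded by one RWMutex:
// readers share the lock, writers take it exclusively.
type stateManager struct {
	mu        sync.RWMutex
	Fitted    bool
	NFeatures int
	NSamples  int
}

var errNotFitted = errors.New("model is not fitted")

func (s *stateManager) IsFitted() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.Fitted
}

func (s *stateManager) RequireFitted() error {
	if !s.IsFitted() {
		return errNotFitted
	}
	return nil
}

// WithStateMut runs fn under the write lock. fn should modify the fields
// directly instead of calling other locking methods.
func (s *stateManager) WithStateMut(fn func() error) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	return fn()
}

func main() {
	s := &stateManager{}
	fmt.Println(s.RequireFitted()) // model is not fitted

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ { // concurrent readers are safe
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = s.IsFitted()
		}()
	}
	wg.Wait()

	_ = s.WithStateMut(func() error {
		s.Fitted, s.NFeatures, s.NSamples = true, 4, 150
		return nil
	})
	fmt.Println(s.RequireFitted()) // <nil>
}
```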
type StreamingEstimator ¶
type StreamingEstimator interface {
IncrementalEstimator
// FitStream trains the model from a data stream
// Continues learning until the context is canceled or the channel is closed
FitStream(ctx context.Context, dataChan <-chan *Batch) error
// PredictStream performs real-time predictions on input stream
// Output channel is closed when input channel is closed
PredictStream(ctx context.Context, inputChan <-chan mat.Matrix) <-chan mat.Matrix
// FitPredictStream performs learning and prediction simultaneously
// Returns predictions while training on new data (test-then-train approach)
FitPredictStream(ctx context.Context, dataChan <-chan *Batch) <-chan mat.Matrix
}
StreamingEstimator provides channel-based streaming learning interface
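The PredictStream contract has three parts: read until the input channel closes, stop early if the context is canceled, and close the output channel when done. It maps onto a single goroutine with nested selects. The sketch uses []float64 rows and a predict function as hypothetical stand-ins for mat.Matrix and a fitted model. predictStream and predictAll are illustrative names, not library functions.

```go
package main

import (
	"context"
	"fmt"
)

// predictStream reads inputs until in is closed or ctx is canceled, and
// always closes the returned channel when it stops.
func predictStream(ctx context.Context, in <-chan []float64, predict func([]float64) float64) <-chan float64 {
	out := make(chan float64)
	go func() {
		defer close(out)
		for {
			select {
			case <-ctx.Done():
				return
			case x, ok := <-in:
				if !ok {
					return // input closed: stop and close output
				}
				select {
				case out <- predict(x):
				case <-ctx.Done(): // don't block forever on a gone consumer
					return
				}
			}
		}
	}()
	return out
}

// predictAll feeds xs through predictStream and collects the predictions.
func predictAll(xs [][]float64, predict func([]float64) float64) []float64 {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	in := make(chan []float64)
	go func() {
		defer close(in)
		for _, x := range xs {
			select {
			case in <- x:
			case <-ctx.Done():
				return
			}
		}
	}()

	var ys []float64
	for y := range predictStream(ctx, in, predict) {
		ys = append(ys, y)
	}
	return ys
}

func main() {
	// Hypothetical fitted model: y = 2*x0 + x1.
	linear := func(x []float64) float64 { return 2*x[0] + x[1] }
	fmt.Println(predictAll([][]float64{{1, 2}, {3, 4}, {5, 6}}, linear))
}
```

The producer also selects on ctx.Done(), so canceling the context cannot leave it blocked on a send that nobody will receive.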
type StreamingMetrics ¶
type StreamingMetrics interface {
OnlineMetrics
// GetThroughput returns current throughput (samples/second)
GetThroughput() float64
// GetProcessedSamples returns total number of processed samples
GetProcessedSamples() int64
// GetAverageLatency returns average latency in milliseconds
GetAverageLatency() float64
// GetMemoryUsage returns current memory usage in bytes
GetMemoryUsage() int64
}
StreamingMetrics provides metrics during streaming learning
type Transformer ¶
type Transformer interface {
// Fit learns parameters necessary for transformation
Fit(X mat.Matrix) error
// Transform transforms data
Transform(X mat.Matrix) (mat.Matrix, error)
// FitTransform runs Fit and then Transform on the same data in a single call
FitTransform(X mat.Matrix) (mat.Matrix, error)
}
Transformer is an interface for data transformation
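A standardization step is a compact example of the Fit / Transform / FitTransform split. Fit learns per-column statistics, Transform applies them, and FitTransform chains the two. standardScaler is hypothetical and uses [][]float64 (rows are samples) in place of mat.Matrix. It uses the population standard deviation, as scikit-learn's StandardScaler does. Leaving constant columns unscaled is this sketch's own choice.

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

// standardScaler maps each column to (x - mean) / std.
type standardScaler struct {
	mean, std []float64 // nil until Fit succeeds
}

func (s *standardScaler) Fit(X [][]float64) error {
	if len(X) == 0 || len(X[0]) == 0 {
		return errors.New("empty input")
	}
	n, d := float64(len(X)), len(X[0])
	mean, std := make([]float64, d), make([]float64, d)
	for i, row := range X {
		if len(row) != d {
			return fmt.Errorf("row %d has %d columns, want %d", i, len(row), d)
		}
		for j, v := range row {
			mean[j] += v / n
		}
	}
	for _, row := range X {
		for j, v := range row {
			std[j] += (v - mean[j]) * (v - mean[j]) / n
		}
	}
	for j := range std {
		std[j] = math.Sqrt(std[j])
		if std[j] == 0 {
			std[j] = 1 // constant column: center it but do not scale
		}
	}
	s.mean, s.std = mean, std // commit only after success
	return nil
}

func (s *standardScaler) Transform(X [][]float64) ([][]float64, error) {
	if s.mean == nil {
		return nil, errors.New("scaler is not fitted")
	}
	out := make([][]float64, len(X))
	for i, row := range X {
		if len(row) != len(s.mean) {
			return nil, fmt.Errorf("row %d has %d columns, want %d", i, len(row), len(s.mean))
		}
		out[i] = make([]float64, len(row))
		for j, v := range row {
			out[i][j] = (v - s.mean[j]) / s.std[j]
		}
	}
	return out, nil
}

// FitTransform fits on X and then transforms the same X.
func (s *standardScaler) FitTransform(X [][]float64) ([][]float64, error) {
	if err := s.Fit(X); err != nil {
		return nil, err
	}
	return s.Transform(X)
}

func main() {
	s := &standardScaler{}
	Z, err := s.FitTransform([][]float64{{1, 10}, {3, 10}})
	if err != nil {
		panic(err)
	}
	fmt.Println(Z) // [[-1 0] [1 0]]
}
```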
type TransformerMixin ¶ added in v0.3.0
type TransformerMixin interface {
// Transform transforms the data
Transform(X mat.Matrix) (mat.Matrix, error)
// FitTransform fits and transforms in a single call
FitTransform(X mat.Matrix) (mat.Matrix, error)
}
TransformerMixin is the mixin interface for transformers.
type WeightExporter ¶ added in v0.3.0
type WeightExporter interface {
// ExportWeights exports the model's weights
ExportWeights() (*ModelWeights, error)
// ImportWeights imports the model's weights
ImportWeights(weights *ModelWeights) error
// GetWeightHash computes a hash of the weights (for verification)
GetWeightHash() string
}
WeightExporter is the interface for models whose weights can be exported and imported.