Documentation
¶
Overview ¶
Package model provides core abstractions and interfaces for machine learning models.
This package defines the fundamental building blocks for machine learning estimators in the SciGo library, including:
- BaseEstimator: Core estimator with state management and serialization support
- Model persistence: Save and load trained models using Go's encoding/gob
- scikit-learn compatibility: Import/export models from Python scikit-learn
- Streaming interfaces: Support for online learning and incremental training
The BaseEstimator provides a consistent foundation for all ML algorithms with:
- Fitted state tracking to prevent usage of untrained models
- Serialization support for model persistence
- Thread-safe state management
- Integration with preprocessing pipelines
Example usage:
type MyModel struct {
model.BaseEstimator
// model-specific fields
}
func (m *MyModel) Fit(X, y mat.Matrix) error {
// training logic
m.SetFitted() // mark as trained
return nil
}
All models in SciGo embed BaseEstimator to ensure consistent behavior across the entire machine learning pipeline.
Index ¶
- func ExportSKLearnModel(modelName string, params interface{}, w io.Writer) error
- func LoadModel(model interface{}, filename string) error
- func LoadModelFromReader(model interface{}, r io.Reader) error
- func SaveModel(model interface{}, filename string) error
- func SaveModelToWriter(model interface{}, w io.Writer) error
- type AdaptiveLearning
- type BaseEstimator
- func (e *BaseEstimator) Clone() *BaseEstimator
- func (e *BaseEstimator) ExportWeights() (*ModelWeights, error)
- func (e *BaseEstimator) GetLogger() interface{}
- func (e *BaseEstimator) GetParams(deep bool) map[string]interface{}
- func (e *BaseEstimator) GetWeightHash() string
- func (e *BaseEstimator) ImportWeights(weights *ModelWeights) error
- func (e *BaseEstimator) IsFitted() bool
- func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})
- func (e *BaseEstimator) LogError(msg string, fields ...interface{})
- func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})
- func (e *BaseEstimator) Reset()
- func (e *BaseEstimator) SetFitted()
- func (e *BaseEstimator) SetLogger(logger interface{})
- func (e *BaseEstimator) SetParams(params map[string]interface{}) error
- type Batch
- type BufferedStreaming
- type ClassifierMixin
- type ClusterMixin
- type CrossValidatable
- type Estimator
- type EstimatorState
- type Fitter
- type IncrementalEstimator
- type InverseTransformerMixin
- type LinearModel
- type ModelValidation
- type ModelWeights
- type OnlineMetrics
- type ParallelStreaming
- type PartialFitMixin
- type PipelineCompatible
- type Predictor
- type RegressorMixin
- type SKLearnCompatible
- type SKLearnLinearRegressionParams
- type SKLearnModel
- type SKLearnModelSpec
- type StreamingEstimator
- type StreamingMetrics
- type Transformer
- type TransformerMixin
- type WeightExporter
Examples ¶
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func ExportSKLearnModel ¶
ExportSKLearnModel はモデルをscikit-learn互換のJSON形式でエクスポート
パラメータ:
- modelName: モデル名
- params: モデルパラメータ
- w: 出力先Writer
戻り値:
- error: エクスポート失敗時のエラー
func LoadModel ¶
LoadModel はファイルからモデルを読み込む
パラメータ:
- model: 読み込み先のモデル(BaseEstimatorを埋め込んだ構造体のポインタ)
- filename: 読み込み元のファイルパス
戻り値:
- error: 読み込みに失敗した場合のエラー
使用例:
var reg linear.Regression err := model.LoadModel(®, "model.gob")
func LoadModelFromReader ¶
LoadModelFromReader はio.Readerからモデルを読み込む
パラメータ:
- model: 読み込み先のモデル(ポインタ)
- r: 読み込み元のReader
戻り値:
- error: 読み込みに失敗した場合のエラー
func SaveModel ¶
SaveModel はモデルをファイルに保存する
パラメータ:
- model: 保存するモデル(BaseEstimatorを埋め込んだ構造体)
- filename: 保存先のファイルパス
戻り値:
- error: 保存に失敗した場合のエラー
使用例:
var reg linear.Regression // ... モデルの学習 ... err := model.SaveModel(®, "model.gob")
func SaveModelToWriter ¶
SaveModelToWriter はモデルをio.Writerに保存する
パラメータ:
- model: 保存するモデル
- w: 保存先のWriter
戻り値:
- error: 保存に失敗した場合のエラー
Types ¶
type AdaptiveLearning ¶
type AdaptiveLearning interface {
// GetLearningRate は現在の学習率を返す
GetLearningRate() float64
// SetLearningRate は学習率を設定する
SetLearningRate(lr float64)
// GetLearningRateSchedule は学習率スケジュールを返す
// "constant", "optimal", "invscaling", "adaptive" など
GetLearningRateSchedule() string
// SetLearningRateSchedule は学習率スケジュールを設定する
SetLearningRateSchedule(schedule string)
}
AdaptiveLearning は学習率を動的に調整できるモデルのインターフェース
type BaseEstimator ¶
type BaseEstimator struct {
// State はモデルの学習状態を保持します。gobでエンコードするために公開されています。
State EstimatorState
// ModelType はモデルの種類を識別
ModelType string
// Version はモデルのバージョン
Version string
// contains filtered or unexported fields
}
BaseEstimator は全てのモデルの基底となる構造体
Example ¶
ExampleBaseEstimator demonstrates BaseEstimator state management
package main
import (
"fmt"
"github.com/YuminosukeSato/scigo/core/model"
)
func main() {
// Create a BaseEstimator (typically embedded in actual models)
estimator := &model.BaseEstimator{}
// Check initial state
fmt.Printf("Initially fitted: %t\n", estimator.IsFitted())
// Mark as fitted
estimator.SetFitted()
fmt.Printf("After SetFitted: %t\n", estimator.IsFitted())
// Reset to unfitted state
estimator.Reset()
fmt.Printf("After Reset: %t\n", estimator.IsFitted())
}
Output: Initially fitted: false After SetFitted: true After Reset: false
Example (WorkflowPattern) ¶
ExampleBaseEstimator_workflowPattern demonstrates typical usage pattern
package main
import (
"fmt"
"github.com/YuminosukeSato/scigo/core/model"
)
func main() {
// This example shows how BaseEstimator is typically used in models
type MyModel struct {
model.BaseEstimator
// model-specific fields would go here
}
myModel := &MyModel{}
// Check if model needs training
if !myModel.IsFitted() {
fmt.Println("Model needs training")
// Simulate training process
// ... training logic would go here ...
// Mark as fitted after successful training
myModel.SetFitted()
fmt.Println("Model trained successfully")
}
// Now model is ready for use
if myModel.IsFitted() {
fmt.Println("Model is ready for predictions")
}
}
Output: Model needs training Model trained successfully Model is ready for predictions
func (*BaseEstimator) Clone ¶ added in v0.3.0
func (e *BaseEstimator) Clone() *BaseEstimator
Clone はモデルの新しいインスタンスを作成(基本実装)
func (*BaseEstimator) ExportWeights ¶ added in v0.3.0
func (e *BaseEstimator) ExportWeights() (*ModelWeights, error)
ExportWeights はモデルの重みをエクスポート(基本実装)
func (*BaseEstimator) GetLogger ¶
func (e *BaseEstimator) GetLogger() interface{}
GetLogger returns the logger for this estimator. Returns nil if no logger has been set.
Returns:
- interface{}: The logger instance, should be type-asserted to log.Logger
Example:
if logger := model.GetLogger(); logger != nil {
if l, ok := logger.(log.Logger); ok {
l.Info("Operation completed")
}
}
func (*BaseEstimator) GetParams ¶ added in v0.3.0
func (e *BaseEstimator) GetParams(deep bool) map[string]interface{}
GetParams はモデルのハイパーパラメータを取得(scikit-learn互換)
func (*BaseEstimator) GetWeightHash ¶ added in v0.3.0
func (e *BaseEstimator) GetWeightHash() string
GetWeightHash は重みのハッシュ値を計算(検証用)
func (*BaseEstimator) ImportWeights ¶ added in v0.3.0
func (e *BaseEstimator) ImportWeights(weights *ModelWeights) error
ImportWeights はモデルの重みをインポート(基本実装)
func (*BaseEstimator) IsFitted ¶
func (e *BaseEstimator) IsFitted() bool
IsFitted returns whether the model has been fitted with training data.
This method checks the internal state to determine if the estimator has been trained and is ready for prediction or transformation. All models must be fitted before they can be used for predictions.
Returns:
- bool: true if the model is fitted, false otherwise
Example:
if !model.IsFitted() {
err := model.Fit(X, y)
if err != nil {
log.Fatal(err)
}
}
predictions, err := model.Predict(X_test)
func (*BaseEstimator) LogDebug ¶
func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})
LogDebug logs a debug-level message if a logger is configured. This is a convenience method for debug logging in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs
func (*BaseEstimator) LogError ¶
func (e *BaseEstimator) LogError(msg string, fields ...interface{})
LogError logs an error-level message if a logger is configured. This is a convenience method for error logging in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs If the first field is an error, it will be handled specially
func (*BaseEstimator) LogInfo ¶
func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})
LogInfo logs an info-level message if a logger is configured. This is a convenience method to avoid repetitive nil checks in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs
func (*BaseEstimator) Reset ¶
func (e *BaseEstimator) Reset()
Reset returns the estimator to its initial untrained state.
This method clears the fitted state, effectively making the model untrained. Useful for reusing model instances with different data or resetting after errors. After reset, the model must be fitted again before use.
Example:
model.Reset()
err := model.Fit(newTrainingData, newLabels)
if err != nil {
log.Fatal(err)
}
func (*BaseEstimator) SetFitted ¶
func (e *BaseEstimator) SetFitted()
SetFitted marks the estimator as fitted (trained).
This method is called internally by model implementations after successful training to indicate that the model is ready for predictions or transformations. Should only be called by model implementations, not by end users.
Example (within a model's Fit method):
func (m *MyModel) Fit(X, y mat.Matrix) error {
// ... training logic ...
m.SetFitted() // Mark as trained
return nil
}
func (*BaseEstimator) SetLogger ¶
func (e *BaseEstimator) SetLogger(logger interface{})
SetLogger sets the logger for this estimator. This method is typically called during model initialization to provide structured logging capabilities for ML operations.
Parameters:
- logger: Any logger implementation (typically log.Logger interface)
Example:
import "github.com/YuminosukeSato/scigo/pkg/log"
model.SetLogger(log.GetLoggerWithName("LinearRegression"))
func (*BaseEstimator) SetParams ¶ added in v0.3.0
func (e *BaseEstimator) SetParams(params map[string]interface{}) error
SetParams はモデルのハイパーパラメータを設定(scikit-learn互換)
type BufferedStreaming ¶
type BufferedStreaming interface {
// SetBufferSize はストリーミングバッファのサイズを設定
SetBufferSize(size int)
// GetBufferSize は現在のバッファサイズを返す
GetBufferSize() int
// FlushBuffer はバッファを強制的にフラッシュ
FlushBuffer() error
}
BufferedStreaming はバッファリング機能を持つストリーミングインターフェース
type ClassifierMixin ¶ added in v0.3.0
type ClassifierMixin interface {
Estimator
// PredictProba は各クラスの確率を予測
PredictProba(X mat.Matrix) (mat.Matrix, error)
// PredictLogProba は各クラスの対数確率を予測
PredictLogProba(X mat.Matrix) (mat.Matrix, error)
// DecisionFunction は決定関数の値を計算
DecisionFunction(X mat.Matrix) (mat.Matrix, error)
// Classes は学習されたクラスラベルを返す
Classes() []interface{}
// NClasses は学習されたクラス数を返す
NClasses() int
}
ClassifierMixin は分類器のMixinインターフェース
type ClusterMixin ¶ added in v0.3.0
type ClusterMixin interface {
Fitter
// FitPredict は学習と予測を同時に実行
FitPredict(X mat.Matrix) ([]int, error)
// PredictCluster は新しいデータのクラスタを予測
PredictCluster(X mat.Matrix) ([]int, error)
// NClusters はクラスタ数を返す
NClusters() int
}
ClusterMixin はクラスタリングのMixinインターフェース
type CrossValidatable ¶ added in v0.3.0
type CrossValidatable interface {
// GetCVSplits はクロスバリデーションの分割数を返す
GetCVSplits() int
// SetCVSplits はクロスバリデーションの分割数を設定
SetCVSplits(n int)
// GetCVScores はクロスバリデーションのスコアを返す
GetCVScores() []float64
}
CrossValidatable はクロスバリデーション可能なモデルのインターフェース
type EstimatorState ¶
type EstimatorState int
EstimatorState はモデルの学習状態を表す
const ( // NotFitted はモデルが未学習の状態 NotFitted EstimatorState = iota // Fitted はモデルが学習済みの状態 Fitted )
type IncrementalEstimator ¶
type IncrementalEstimator interface {
Estimator
// PartialFit はミニバッチでモデルを逐次的に学習させる
// classes は分類問題の場合に全クラスラベルを指定(最初の呼び出し時のみ必須)
// 回帰問題の場合は nil を渡す
PartialFit(X, y mat.Matrix, classes []int) error
// NIterations は実行された学習イテレーション数を返す
NIterations() int
// WarmStart が有効かどうかを返す
// true の場合、Fit 呼び出し時に既存のパラメータから学習を継続
IsWarmStart() bool
// SetWarmStart はウォームスタートの有効/無効を設定
SetWarmStart(warmStart bool)
}
IncrementalEstimator はオンライン学習(逐次学習)可能なモデルのインターフェース scikit-learnのpartial_fit APIと互換性を持つ
type InverseTransformerMixin ¶ added in v0.3.0
type InverseTransformerMixin interface {
TransformerMixin
// InverseTransform は変換を逆方向に適用
InverseTransform(X mat.Matrix) (mat.Matrix, error)
}
InverseTransformerMixin は逆変換可能な変換器のインターフェース
type LinearModel ¶
type LinearModel interface {
// Weights は学習された重み(係数)を返す
Weights() []float64
// Intercept は学習された切片を返す
Intercept() float64
// Score はモデルの決定係数(R²)を計算する
Score(X, y mat.Matrix) (float64, error)
}
LinearModel は線形モデルのインターフェース
type ModelValidation ¶ added in v0.3.0
type ModelValidation struct {
// ValidateInput は入力データの検証
ValidateInput func(X mat.Matrix) error
// ValidateOutput は出力データの検証
ValidateOutput func(y mat.Matrix) error
// ValidateWeights は重みの検証
ValidateWeights func(weights *ModelWeights) error
}
ModelValidation はモデルの検証機能を提供
type ModelWeights ¶ added in v0.3.0
type ModelWeights struct {
// ModelType はモデルの種類(LinearRegression, SGDRegressor等)
ModelType string `json:"model_type"`
// Version はモデルのバージョン(互換性チェック用)
Version string `json:"version"`
// Coefficients は重み係数
Coefficients []float64 `json:"coefficients"`
// Intercept は切片
Intercept float64 `json:"intercept"`
// Features は特徴量の名前(オプション)
Features []string `json:"features,omitempty"`
// Hyperparameters はモデルのハイパーパラメータ
Hyperparameters map[string]interface{} `json:"hyperparameters"`
// Metadata は追加のメタデータ(学習時の統計等)
Metadata map[string]interface{} `json:"metadata,omitempty"`
// IsFitted はモデルが学習済みかどうか
IsFitted bool `json:"is_fitted"`
}
ModelWeights はモデルの重みを表す構造体(シリアライゼーション用)
func (*ModelWeights) Clone ¶ added in v0.3.0
func (mw *ModelWeights) Clone() *ModelWeights
Clone はModelWeightsのディープコピーを作成
func (*ModelWeights) FromJSON ¶ added in v0.3.0
func (mw *ModelWeights) FromJSON(data []byte) error
FromJSON はJSON形式からModelWeightsをデシリアライズ
func (*ModelWeights) ToJSON ¶ added in v0.3.0
func (mw *ModelWeights) ToJSON() ([]byte, error)
ToJSON はModelWeightsをJSON形式にシリアライズ
func (*ModelWeights) Validate ¶ added in v0.3.0
func (mw *ModelWeights) Validate() error
Validate はModelWeightsの妥当性を検証
type OnlineMetrics ¶
type OnlineMetrics interface {
// GetLoss は現在の損失値を返す
GetLoss() float64
// GetLossHistory は損失値の履歴を返す
GetLossHistory() []float64
// GetConverged は収束したかどうかを返す
GetConverged() bool
}
OnlineMetrics はオンライン学習中のメトリクスを追跡するインターフェース
type ParallelStreaming ¶
type ParallelStreaming interface {
// SetWorkers はワーカー数を設定
SetWorkers(n int)
// GetWorkers は現在のワーカー数を返す
GetWorkers() int
// SetBatchParallelism はバッチ内並列処理の有効/無効を設定
SetBatchParallelism(enabled bool)
}
ParallelStreaming は並列ストリーミング処理のインターフェース
type PartialFitMixin ¶ added in v0.3.0
type PartialFitMixin interface {
// PartialFit はミニバッチで逐次学習
PartialFit(X, y mat.Matrix, classes []int) error
// NIterations は学習イテレーション数を返す
NIterations() int
// IsWarmStart はウォームスタートが有効かどうか
IsWarmStart() bool
// SetWarmStart はウォームスタートの有効/無効を設定
SetWarmStart(warmStart bool)
}
PartialFitMixin は逐次学習可能なモデルのインターフェース
type PipelineCompatible ¶ added in v0.3.0
type PipelineCompatible interface {
SKLearnCompatible
// GetInputDim は入力次元数を返す
GetInputDim() int
// GetOutputDim は出力次元数を返す
GetOutputDim() int
// RequiresFit は学習が必要かどうか
RequiresFit() bool
}
PipelineCompatible はパイプラインで使用可能なモデルのインターフェース
type RegressorMixin ¶ added in v0.3.0
type RegressorMixin interface {
Estimator
// Score は決定係数R²を計算
Score(X, y mat.Matrix) (float64, error)
}
RegressorMixin は回帰器のMixinインターフェース
type SKLearnCompatible ¶ added in v0.3.0
type SKLearnCompatible interface {
// GetParams はモデルのハイパーパラメータを取得
GetParams(deep bool) map[string]interface{}
// SetParams はモデルのハイパーパラメータを設定
SetParams(params map[string]interface{}) error
// Clone はモデルの新しいインスタンスを同じパラメータで作成
Clone() SKLearnCompatible
}
SKLearnCompatible はscikit-learn互換のインターフェース
type SKLearnLinearRegressionParams ¶
type SKLearnLinearRegressionParams struct {
Coefficients []float64 `json:"coefficients"` // 係数(重み)
Intercept float64 `json:"intercept"` // 切片
NFeatures int `json:"n_features"` // 特徴量の数
}
SKLearnLinearRegressionParams は線形回帰モデルのパラメータ
func LoadLinearRegressionParams ¶
func LoadLinearRegressionParams(model *SKLearnModel) (*SKLearnLinearRegressionParams, error)
LoadLinearRegressionParams はLinearRegressionのパラメータを読み込む
パラメータ:
- model: SKLearnModelインスタンス
戻り値:
- *SKLearnLinearRegressionParams: パラメータ
- error: パース失敗時のエラー
type SKLearnModel ¶
type SKLearnModel struct {
ModelSpec SKLearnModelSpec `json:"model_spec"`
Params json.RawMessage `json:"params"`
}
SKLearnModel はscikit-learnからエクスポートされたモデル
func LoadSKLearnModelFromFile ¶
func LoadSKLearnModelFromFile(filename string) (*SKLearnModel, error)
LoadSKLearnModelFromFile はファイルからscikit-learnモデルを読み込む
パラメータ:
- filename: JSONファイルのパス
戻り値:
- *SKLearnModel: 読み込まれたモデル
- error: 読み込みエラー
使用例:
model, err := model.LoadSKLearnModelFromFile("sklearn_model.json")
if err != nil {
log.Fatal(err)
}
func LoadSKLearnModelFromReader ¶
func LoadSKLearnModelFromReader(r io.Reader) (*SKLearnModel, error)
LoadSKLearnModelFromReader はReaderからscikit-learnモデルを読み込む
パラメータ:
- r: JSONデータを含むReader
戻り値:
- *SKLearnModel: 読み込まれたモデル
- error: 読み込みエラー
type SKLearnModelSpec ¶
type SKLearnModelSpec struct {
Name string `json:"name"` // モデル名 (e.g., "LinearRegression")
FormatVersion string `json:"format_version"` // フォーマットバージョン
SKLearnVersion string `json:"sklearn_version,omitempty"` // scikit-learnのバージョン
}
SKLearnModelSpec はscikit-learnモデルのメタデータ
type StreamingEstimator ¶
type StreamingEstimator interface {
IncrementalEstimator
// FitStream はデータストリームからモデルを学習する
// コンテキストがキャンセルされるまで、またはチャネルがクローズされるまで学習を継続
FitStream(ctx context.Context, dataChan <-chan *Batch) error
// PredictStream は入力ストリームに対してリアルタイム予測を行う
// 入力チャネルがクローズされると出力チャネルもクローズされる
PredictStream(ctx context.Context, inputChan <-chan mat.Matrix) <-chan mat.Matrix
// FitPredictStream は学習と予測を同時に行う
// 新しいデータで学習しながら、同時に予測も返す(test-then-train方式)
FitPredictStream(ctx context.Context, dataChan <-chan *Batch) <-chan mat.Matrix
}
StreamingEstimator はチャネルベースのストリーミング学習を提供するインターフェース
type StreamingMetrics ¶
type StreamingMetrics interface {
OnlineMetrics
// GetThroughput は現在のスループット(サンプル/秒)を返す
GetThroughput() float64
// GetProcessedSamples は処理されたサンプル総数を返す
GetProcessedSamples() int64
// GetAverageLatency は平均レイテンシ(ミリ秒)を返す
GetAverageLatency() float64
// GetMemoryUsage は現在のメモリ使用量(バイト)を返す
GetMemoryUsage() int64
}
StreamingMetrics はストリーミング学習中のメトリクスを提供
type Transformer ¶
type Transformer interface {
// Fit は変換に必要なパラメータを学習する
Fit(X mat.Matrix) error
// Transform はデータを変換する
Transform(X mat.Matrix) (mat.Matrix, error)
// FitTransform はFitとTransformを同時に実行する
FitTransform(X mat.Matrix) (mat.Matrix, error)
}
Transformer はデータ変換のインターフェース
type TransformerMixin ¶ added in v0.3.0
type TransformerMixin interface {
// Transform はデータを変換
Transform(X mat.Matrix) (mat.Matrix, error)
// FitTransform は学習と変換を同時に実行
FitTransform(X mat.Matrix) (mat.Matrix, error)
}
TransformerMixin は変換器のMixinインターフェース
type WeightExporter ¶ added in v0.3.0
type WeightExporter interface {
// ExportWeights はモデルの重みをエクスポート
ExportWeights() (*ModelWeights, error)
// ImportWeights はモデルの重みをインポート
ImportWeights(weights *ModelWeights) error
// GetWeightHash は重みのハッシュ値を計算(検証用)
GetWeightHash() string
}
WeightExporter は重みをエクスポート可能なモデルのインターフェース