model

package
v0.3.0 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Aug 7, 2025 License: MIT Imports: 10 Imported by: 0

Documentation

Overview

Package model provides core abstractions and interfaces for machine learning models.

This package defines the fundamental building blocks for machine learning estimators in the SciGo library, including:

  • BaseEstimator: Core estimator with state management and serialization support
  • Model persistence: Save and load trained models using Go's encoding/gob
  • scikit-learn compatibility: Import/export models from Python scikit-learn
  • Streaming interfaces: Support for online learning and incremental training

The BaseEstimator provides a consistent foundation for all ML algorithms with:

  • Fitted state tracking to prevent usage of untrained models
  • Serialization support for model persistence
  • Thread-safe state management
  • Integration with preprocessing pipelines

Example usage:

type MyModel struct {
	model.BaseEstimator
	// model-specific fields
}

func (m *MyModel) Fit(X, y mat.Matrix) error {
	// training logic
	m.SetFitted() // mark as trained
	return nil
}

All models in SciGo embed BaseEstimator to ensure consistent behavior across the entire machine learning pipeline.

Index

Examples

Constants

This section is empty.

Variables

This section is empty.

Functions

func ExportSKLearnModel

func ExportSKLearnModel(modelName string, params interface{}, w io.Writer) error

ExportSKLearnModel はモデルをscikit-learn互換のJSON形式でエクスポート

パラメータ:

  • modelName: モデル名
  • params: モデルパラメータ
  • w: 出力先Writer

戻り値:

  • error: エクスポート失敗時のエラー

func LoadModel

func LoadModel(model interface{}, filename string) error

LoadModel はファイルからモデルを読み込む

パラメータ:

  • model: 読み込み先のモデル(BaseEstimatorを埋め込んだ構造体のポインタ)
  • filename: 読み込み元のファイルパス

戻り値:

  • error: 読み込みに失敗した場合のエラー

使用例:

var reg linear.Regression
err := model.LoadModel(&reg, "model.gob")

func LoadModelFromReader

func LoadModelFromReader(model interface{}, r io.Reader) error

LoadModelFromReader はio.Readerからモデルを読み込む

パラメータ:

  • model: 読み込み先のモデル(ポインタ)
  • r: 読み込み元のReader

戻り値:

  • error: 読み込みに失敗した場合のエラー

func SaveModel

func SaveModel(model interface{}, filename string) error

SaveModel はモデルをファイルに保存する

パラメータ:

  • model: 保存するモデル(BaseEstimatorを埋め込んだ構造体)
  • filename: 保存先のファイルパス

戻り値:

  • error: 保存に失敗した場合のエラー

使用例:

var reg linear.Regression
// ... モデルの学習 ...
err := model.SaveModel(&reg, "model.gob")

func SaveModelToWriter

func SaveModelToWriter(model interface{}, w io.Writer) error

SaveModelToWriter はモデルをio.Writerに保存する

パラメータ:

  • model: 保存するモデル
  • w: 保存先のWriter

戻り値:

  • error: 保存に失敗した場合のエラー

Types

type AdaptiveLearning

type AdaptiveLearning interface {
	// GetLearningRate は現在の学習率を返す
	GetLearningRate() float64

	// SetLearningRate は学習率を設定する
	SetLearningRate(lr float64)

	// GetLearningRateSchedule は学習率スケジュールを返す
	// "constant", "optimal", "invscaling", "adaptive" など
	GetLearningRateSchedule() string

	// SetLearningRateSchedule は学習率スケジュールを設定する
	SetLearningRateSchedule(schedule string)
}

AdaptiveLearning は学習率を動的に調整できるモデルのインターフェース

type BaseEstimator

type BaseEstimator struct {
	// State はモデルの学習状態を保持します。gobでエンコードするために公開されています。
	State EstimatorState

	// ModelType はモデルの種類を識別
	ModelType string

	// Version はモデルのバージョン
	Version string
	// contains filtered or unexported fields
}

BaseEstimator は全てのモデルの基底となる構造体

Example

ExampleBaseEstimator demonstrates BaseEstimator state management

package main

import (
	"fmt"

	"github.com/YuminosukeSato/scigo/core/model"
)

func main() {
	// Create a BaseEstimator (typically embedded in actual models)
	estimator := &model.BaseEstimator{}

	// Check initial state
	fmt.Printf("Initially fitted: %t\n", estimator.IsFitted())

	// Mark as fitted
	estimator.SetFitted()
	fmt.Printf("After SetFitted: %t\n", estimator.IsFitted())

	// Reset to unfitted state
	estimator.Reset()
	fmt.Printf("After Reset: %t\n", estimator.IsFitted())

}
Output:

Initially fitted: false
After SetFitted: true
After Reset: false
Example (WorkflowPattern)

ExampleBaseEstimator_workflowPattern demonstrates typical usage pattern

package main

import (
	"fmt"

	"github.com/YuminosukeSato/scigo/core/model"
)

func main() {
	// This example shows how BaseEstimator is typically used in models
	type MyModel struct {
		model.BaseEstimator
		// model-specific fields would go here
	}

	myModel := &MyModel{}

	// Check if model needs training
	if !myModel.IsFitted() {
		fmt.Println("Model needs training")

		// Simulate training process
		// ... training logic would go here ...

		// Mark as fitted after successful training
		myModel.SetFitted()
		fmt.Println("Model trained successfully")
	}

	// Now model is ready for use
	if myModel.IsFitted() {
		fmt.Println("Model is ready for predictions")
	}

}
Output:

Model needs training
Model trained successfully
Model is ready for predictions

func (*BaseEstimator) Clone added in v0.3.0

func (e *BaseEstimator) Clone() *BaseEstimator

Clone はモデルの新しいインスタンスを作成(基本実装)

func (*BaseEstimator) ExportWeights added in v0.3.0

func (e *BaseEstimator) ExportWeights() (*ModelWeights, error)

ExportWeights はモデルの重みをエクスポート(基本実装)

func (*BaseEstimator) GetLogger

func (e *BaseEstimator) GetLogger() interface{}

GetLogger returns the logger for this estimator. Returns nil if no logger has been set.

Returns:

  • interface{}: The logger instance, should be type-asserted to log.Logger

Example:

if logger := model.GetLogger(); logger != nil {
    if l, ok := logger.(log.Logger); ok {
        l.Info("Operation completed")
    }
}

func (*BaseEstimator) GetParams added in v0.3.0

func (e *BaseEstimator) GetParams(deep bool) map[string]interface{}

GetParams はモデルのハイパーパラメータを取得(scikit-learn互換)

func (*BaseEstimator) GetWeightHash added in v0.3.0

func (e *BaseEstimator) GetWeightHash() string

GetWeightHash は重みのハッシュ値を計算(検証用)

func (*BaseEstimator) ImportWeights added in v0.3.0

func (e *BaseEstimator) ImportWeights(weights *ModelWeights) error

ImportWeights はモデルの重みをインポート(基本実装)

func (*BaseEstimator) IsFitted

func (e *BaseEstimator) IsFitted() bool

IsFitted returns whether the model has been fitted with training data.

This method checks the internal state to determine if the estimator has been trained and is ready for prediction or transformation. All models must be fitted before they can be used for predictions.

Returns:

  • bool: true if the model is fitted, false otherwise

Example:

if !model.IsFitted() {
    err := model.Fit(X, y)
    if err != nil {
        log.Fatal(err)
    }
}
predictions, err := model.Predict(X_test)

func (*BaseEstimator) LogDebug

func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})

LogDebug logs a debug-level message if a logger is configured. This is a convenience method for debug logging in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs

func (*BaseEstimator) LogError

func (e *BaseEstimator) LogError(msg string, fields ...interface{})

LogError logs an error-level message if a logger is configured. This is a convenience method for error logging in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs If the first field is an error, it will be handled specially

func (*BaseEstimator) LogInfo

func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})

LogInfo logs an info-level message if a logger is configured. This is a convenience method to avoid repetitive nil checks in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs

func (*BaseEstimator) Reset

func (e *BaseEstimator) Reset()

Reset returns the estimator to its initial untrained state.

This method clears the fitted state, effectively making the model untrained. Useful for reusing model instances with different data or resetting after errors. After reset, the model must be fitted again before use.

Example:

model.Reset()
err := model.Fit(newTrainingData, newLabels)
if err != nil {
    log.Fatal(err)
}

func (*BaseEstimator) SetFitted

func (e *BaseEstimator) SetFitted()

SetFitted marks the estimator as fitted (trained).

This method is called internally by model implementations after successful training to indicate that the model is ready for predictions or transformations. Should only be called by model implementations, not by end users.

Example (within a model's Fit method):

func (m *MyModel) Fit(X, y mat.Matrix) error {
    // ... training logic ...
    m.SetFitted() // Mark as trained
    return nil
}

func (*BaseEstimator) SetLogger

func (e *BaseEstimator) SetLogger(logger interface{})

SetLogger sets the logger for this estimator. This method is typically called during model initialization to provide structured logging capabilities for ML operations.

Parameters:

  • logger: Any logger implementation (typically log.Logger interface)

Example:

import "github.com/YuminosukeSato/scigo/pkg/log"
model.SetLogger(log.GetLoggerWithName("LinearRegression"))

func (*BaseEstimator) SetParams added in v0.3.0

func (e *BaseEstimator) SetParams(params map[string]interface{}) error

SetParams はモデルのハイパーパラメータを設定(scikit-learn互換)

type Batch

type Batch struct {
	X mat.Matrix // 特徴量行列
	Y mat.Matrix // ターゲット行列
}

Batch はストリーミング学習用のデータバッチを表す

type BufferedStreaming

type BufferedStreaming interface {
	// SetBufferSize はストリーミングバッファのサイズを設定
	SetBufferSize(size int)

	// GetBufferSize は現在のバッファサイズを返す
	GetBufferSize() int

	// FlushBuffer はバッファを強制的にフラッシュ
	FlushBuffer() error
}

BufferedStreaming はバッファリング機能を持つストリーミングインターフェース

type ClassifierMixin added in v0.3.0

type ClassifierMixin interface {
	Estimator

	// PredictProba は各クラスの確率を予測
	PredictProba(X mat.Matrix) (mat.Matrix, error)

	// PredictLogProba は各クラスの対数確率を予測
	PredictLogProba(X mat.Matrix) (mat.Matrix, error)

	// DecisionFunction は決定関数の値を計算
	DecisionFunction(X mat.Matrix) (mat.Matrix, error)

	// Classes は学習されたクラスラベルを返す
	Classes() []interface{}

	// NClasses は学習されたクラス数を返す
	NClasses() int
}

ClassifierMixin は分類器のMixinインターフェース

type ClusterMixin added in v0.3.0

type ClusterMixin interface {
	Fitter

	// FitPredict は学習と予測を同時に実行
	FitPredict(X mat.Matrix) ([]int, error)

	// PredictCluster は新しいデータのクラスタを予測
	PredictCluster(X mat.Matrix) ([]int, error)

	// NClusters はクラスタ数を返す
	NClusters() int
}

ClusterMixin はクラスタリングのMixinインターフェース

type CrossValidatable added in v0.3.0

type CrossValidatable interface {
	// GetCVSplits はクロスバリデーションの分割数を返す
	GetCVSplits() int

	// SetCVSplits はクロスバリデーションの分割数を設定
	SetCVSplits(n int)

	// GetCVScores はクロスバリデーションのスコアを返す
	GetCVScores() []float64
}

CrossValidatable はクロスバリデーション可能なモデルのインターフェース

type Estimator

type Estimator interface {
	Fitter
	Predictor
}

Estimator は学習と予測の両方が可能なモデルのインターフェース

type EstimatorState

type EstimatorState int

EstimatorState はモデルの学習状態を表す

const (
	// NotFitted はモデルが未学習の状態
	NotFitted EstimatorState = iota
	// Fitted はモデルが学習済みの状態
	Fitted
)

type Fitter

type Fitter interface {
	// Fit はモデルを訓練データで学習させる
	Fit(X, y mat.Matrix) error
}

Fitter は学習可能なモデルのインターフェース

type IncrementalEstimator

type IncrementalEstimator interface {
	Estimator

	// PartialFit はミニバッチでモデルを逐次的に学習させる
	// classes は分類問題の場合に全クラスラベルを指定(最初の呼び出し時のみ必須)
	// 回帰問題の場合は nil を渡す
	PartialFit(X, y mat.Matrix, classes []int) error

	// NIterations は実行された学習イテレーション数を返す
	NIterations() int

	// WarmStart が有効かどうかを返す
	// true の場合、Fit 呼び出し時に既存のパラメータから学習を継続
	IsWarmStart() bool

	// SetWarmStart はウォームスタートの有効/無効を設定
	SetWarmStart(warmStart bool)
}

IncrementalEstimator はオンライン学習(逐次学習)可能なモデルのインターフェース scikit-learnのpartial_fit APIと互換性を持つ

type InverseTransformerMixin added in v0.3.0

type InverseTransformerMixin interface {
	TransformerMixin

	// InverseTransform は変換を逆方向に適用
	InverseTransform(X mat.Matrix) (mat.Matrix, error)
}

InverseTransformerMixin は逆変換可能な変換器のインターフェース

type LinearModel

type LinearModel interface {
	// Weights は学習された重み(係数)を返す
	Weights() []float64
	// Intercept は学習された切片を返す
	Intercept() float64
	// Score はモデルの決定係数(R²)を計算する
	Score(X, y mat.Matrix) (float64, error)
}

LinearModel は線形モデルのインターフェース

type ModelValidation added in v0.3.0

type ModelValidation struct {
	// ValidateInput は入力データの検証
	ValidateInput func(X mat.Matrix) error

	// ValidateOutput は出力データの検証
	ValidateOutput func(y mat.Matrix) error

	// ValidateWeights は重みの検証
	ValidateWeights func(weights *ModelWeights) error
}

ModelValidation はモデルの検証機能を提供

type ModelWeights added in v0.3.0

type ModelWeights struct {
	// ModelType はモデルの種類(LinearRegression, SGDRegressor等)
	ModelType string `json:"model_type"`

	// Version はモデルのバージョン(互換性チェック用)
	Version string `json:"version"`

	// Coefficients は重み係数
	Coefficients []float64 `json:"coefficients"`

	// Intercept は切片
	Intercept float64 `json:"intercept"`

	// Features は特徴量の名前(オプション)
	Features []string `json:"features,omitempty"`

	// Hyperparameters はモデルのハイパーパラメータ
	Hyperparameters map[string]interface{} `json:"hyperparameters"`

	// Metadata は追加のメタデータ(学習時の統計等)
	Metadata map[string]interface{} `json:"metadata,omitempty"`

	// IsFitted はモデルが学習済みかどうか
	IsFitted bool `json:"is_fitted"`
}

ModelWeights はモデルの重みを表す構造体(シリアライゼーション用)

func (*ModelWeights) Clone added in v0.3.0

func (mw *ModelWeights) Clone() *ModelWeights

Clone はModelWeightsのディープコピーを作成

func (*ModelWeights) FromJSON added in v0.3.0

func (mw *ModelWeights) FromJSON(data []byte) error

FromJSON はJSON形式からModelWeightsをデシリアライズ

func (*ModelWeights) ToJSON added in v0.3.0

func (mw *ModelWeights) ToJSON() ([]byte, error)

ToJSON はModelWeightsをJSON形式にシリアライズ

func (*ModelWeights) Validate added in v0.3.0

func (mw *ModelWeights) Validate() error

Validate はModelWeightsの妥当性を検証

type OnlineMetrics

type OnlineMetrics interface {
	// GetLoss は現在の損失値を返す
	GetLoss() float64

	// GetLossHistory は損失値の履歴を返す
	GetLossHistory() []float64

	// GetConverged は収束したかどうかを返す
	GetConverged() bool
}

OnlineMetrics はオンライン学習中のメトリクスを追跡するインターフェース

type ParallelStreaming

type ParallelStreaming interface {
	// SetWorkers はワーカー数を設定
	SetWorkers(n int)

	// GetWorkers は現在のワーカー数を返す
	GetWorkers() int

	// SetBatchParallelism はバッチ内並列処理の有効/無効を設定
	SetBatchParallelism(enabled bool)
}

ParallelStreaming は並列ストリーミング処理のインターフェース

type PartialFitMixin added in v0.3.0

type PartialFitMixin interface {
	// PartialFit はミニバッチで逐次学習
	PartialFit(X, y mat.Matrix, classes []int) error

	// NIterations は学習イテレーション数を返す
	NIterations() int

	// IsWarmStart はウォームスタートが有効かどうか
	IsWarmStart() bool

	// SetWarmStart はウォームスタートの有効/無効を設定
	SetWarmStart(warmStart bool)
}

PartialFitMixin は逐次学習可能なモデルのインターフェース

type PipelineCompatible added in v0.3.0

type PipelineCompatible interface {
	SKLearnCompatible

	// GetInputDim は入力次元数を返す
	GetInputDim() int

	// GetOutputDim は出力次元数を返す
	GetOutputDim() int

	// RequiresFit は学習が必要かどうか
	RequiresFit() bool
}

PipelineCompatible はパイプラインで使用可能なモデルのインターフェース

type Predictor

type Predictor interface {
	// Predict は入力データに対する予測を行う
	Predict(X mat.Matrix) (mat.Matrix, error)
}

Predictor は予測可能なモデルのインターフェース

type RegressorMixin added in v0.3.0

type RegressorMixin interface {
	Estimator

	// Score は決定係数R²を計算
	Score(X, y mat.Matrix) (float64, error)
}

RegressorMixin は回帰器のMixinインターフェース

type SKLearnCompatible added in v0.3.0

type SKLearnCompatible interface {
	// GetParams はモデルのハイパーパラメータを取得
	GetParams(deep bool) map[string]interface{}

	// SetParams はモデルのハイパーパラメータを設定
	SetParams(params map[string]interface{}) error

	// Clone はモデルの新しいインスタンスを同じパラメータで作成
	Clone() SKLearnCompatible
}

SKLearnCompatible はscikit-learn互換のインターフェース

type SKLearnLinearRegressionParams

type SKLearnLinearRegressionParams struct {
	Coefficients []float64 `json:"coefficients"` // 係数(重み)
	Intercept    float64   `json:"intercept"`    // 切片
	NFeatures    int       `json:"n_features"`   // 特徴量の数
}

SKLearnLinearRegressionParams は線形回帰モデルのパラメータ

func LoadLinearRegressionParams

func LoadLinearRegressionParams(model *SKLearnModel) (*SKLearnLinearRegressionParams, error)

LoadLinearRegressionParams はLinearRegressionのパラメータを読み込む

パラメータ:

  • model: SKLearnModelインスタンス

戻り値:

  • *SKLearnLinearRegressionParams: パラメータ
  • error: パース失敗時のエラー

type SKLearnModel

type SKLearnModel struct {
	ModelSpec SKLearnModelSpec `json:"model_spec"`
	Params    json.RawMessage  `json:"params"`
}

SKLearnModel はscikit-learnからエクスポートされたモデル

func LoadSKLearnModelFromFile

func LoadSKLearnModelFromFile(filename string) (*SKLearnModel, error)

LoadSKLearnModelFromFile はファイルからscikit-learnモデルを読み込む

パラメータ:

  • filename: JSONファイルのパス

戻り値:

  • *SKLearnModel: 読み込まれたモデル
  • error: 読み込みエラー

使用例:

model, err := model.LoadSKLearnModelFromFile("sklearn_model.json")
if err != nil {
    log.Fatal(err)
}

func LoadSKLearnModelFromReader

func LoadSKLearnModelFromReader(r io.Reader) (*SKLearnModel, error)

LoadSKLearnModelFromReader はReaderからscikit-learnモデルを読み込む

パラメータ:

  • r: JSONデータを含むReader

戻り値:

  • *SKLearnModel: 読み込まれたモデル
  • error: 読み込みエラー

type SKLearnModelSpec

type SKLearnModelSpec struct {
	Name           string `json:"name"`                      // モデル名 (e.g., "LinearRegression")
	FormatVersion  string `json:"format_version"`            // フォーマットバージョン
	SKLearnVersion string `json:"sklearn_version,omitempty"` // scikit-learnのバージョン
}

SKLearnModelSpec はscikit-learnモデルのメタデータ

type StreamingEstimator

type StreamingEstimator interface {
	IncrementalEstimator

	// FitStream はデータストリームからモデルを学習する
	// コンテキストがキャンセルされるまで、またはチャネルがクローズされるまで学習を継続
	FitStream(ctx context.Context, dataChan <-chan *Batch) error

	// PredictStream は入力ストリームに対してリアルタイム予測を行う
	// 入力チャネルがクローズされると出力チャネルもクローズされる
	PredictStream(ctx context.Context, inputChan <-chan mat.Matrix) <-chan mat.Matrix

	// FitPredictStream は学習と予測を同時に行う
	// 新しいデータで学習しながら、同時に予測も返す(test-then-train方式)
	FitPredictStream(ctx context.Context, dataChan <-chan *Batch) <-chan mat.Matrix
}

StreamingEstimator はチャネルベースのストリーミング学習を提供するインターフェース

type StreamingMetrics

type StreamingMetrics interface {
	OnlineMetrics

	// GetThroughput は現在のスループット(サンプル/秒)を返す
	GetThroughput() float64

	// GetProcessedSamples は処理されたサンプル総数を返す
	GetProcessedSamples() int64

	// GetAverageLatency は平均レイテンシ(ミリ秒)を返す
	GetAverageLatency() float64

	// GetMemoryUsage は現在のメモリ使用量(バイト)を返す
	GetMemoryUsage() int64
}

StreamingMetrics はストリーミング学習中のメトリクスを提供

type Transformer

type Transformer interface {
	// Fit は変換に必要なパラメータを学習する
	Fit(X mat.Matrix) error

	// Transform はデータを変換する
	Transform(X mat.Matrix) (mat.Matrix, error)

	// FitTransform はFitとTransformを同時に実行する
	FitTransform(X mat.Matrix) (mat.Matrix, error)
}

Transformer はデータ変換のインターフェース

type TransformerMixin added in v0.3.0

type TransformerMixin interface {
	// Transform はデータを変換
	Transform(X mat.Matrix) (mat.Matrix, error)

	// FitTransform は学習と変換を同時に実行
	FitTransform(X mat.Matrix) (mat.Matrix, error)
}

TransformerMixin は変換器のMixinインターフェース

type WeightExporter added in v0.3.0

type WeightExporter interface {
	// ExportWeights はモデルの重みをエクスポート
	ExportWeights() (*ModelWeights, error)

	// ImportWeights はモデルの重みをインポート
	ImportWeights(weights *ModelWeights) error

	// GetWeightHash は重みのハッシュ値を計算(検証用)
	GetWeightHash() string
}

WeightExporter は重みをエクスポート可能なモデルのインターフェース

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL