model

package
v0.2.0 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Aug 6, 2025 License: MIT Imports: 8 Imported by: 0

Documentation

Overview

Package model provides core abstractions and interfaces for machine learning models.

This package defines the fundamental building blocks for machine learning estimators in the SciGo library, including:

  • BaseEstimator: Core estimator with state management and serialization support
  • Model persistence: Save and load trained models using Go's encoding/gob
  • scikit-learn compatibility: Import/export models from Python scikit-learn
  • Streaming interfaces: Support for online learning and incremental training

The BaseEstimator provides a consistent foundation for all ML algorithms with:

  • Fitted state tracking to prevent usage of untrained models
  • Serialization support for model persistence
  • Thread-safe state management
  • Integration with preprocessing pipelines

Example usage:

type MyModel struct {
	model.BaseEstimator
	// model-specific fields
}

func (m *MyModel) Fit(X, y mat.Matrix) error {
	// training logic
	m.SetFitted() // mark as trained
	return nil
}

All models in SciGo embed BaseEstimator to ensure consistent behavior across the entire machine learning pipeline.

Index

Examples

Constants

This section is empty.

Variables

This section is empty.

Functions

func ExportSKLearnModel

func ExportSKLearnModel(modelName string, params interface{}, w io.Writer) error

ExportSKLearnModel はモデルをscikit-learn互換のJSON形式でエクスポート

パラメータ:

  • modelName: モデル名
  • params: モデルパラメータ
  • w: 出力先Writer

戻り値:

  • error: エクスポート失敗時のエラー

func LoadModel

func LoadModel(model interface{}, filename string) error

LoadModel はファイルからモデルを読み込む

パラメータ:

  • model: 読み込み先のモデル(BaseEstimatorを埋め込んだ構造体のポインタ)
  • filename: 読み込み元のファイルパス

戻り値:

  • error: 読み込みに失敗した場合のエラー

使用例:

var reg linear.Regression
err := model.LoadModel(&reg, "model.gob")

func LoadModelFromReader

func LoadModelFromReader(model interface{}, r io.Reader) error

LoadModelFromReader はio.Readerからモデルを読み込む

パラメータ:

  • model: 読み込み先のモデル(ポインタ)
  • r: 読み込み元のReader

戻り値:

  • error: 読み込みに失敗した場合のエラー

func SaveModel

func SaveModel(model interface{}, filename string) error

SaveModel はモデルをファイルに保存する

パラメータ:

  • model: 保存するモデル(BaseEstimatorを埋め込んだ構造体)
  • filename: 保存先のファイルパス

戻り値:

  • error: 保存に失敗した場合のエラー

使用例:

var reg linear.Regression
// ... モデルの学習 ...
err := model.SaveModel(&reg, "model.gob")

func SaveModelToWriter

func SaveModelToWriter(model interface{}, w io.Writer) error

SaveModelToWriter はモデルをio.Writerに保存する

パラメータ:

  • model: 保存するモデル
  • w: 保存先のWriter

戻り値:

  • error: 保存に失敗した場合のエラー

Types

type AdaptiveLearning

type AdaptiveLearning interface {
	// GetLearningRate は現在の学習率を返す
	GetLearningRate() float64

	// SetLearningRate は学習率を設定する
	SetLearningRate(lr float64)

	// GetLearningRateSchedule は学習率スケジュールを返す
	// "constant", "optimal", "invscaling", "adaptive" など
	GetLearningRateSchedule() string

	// SetLearningRateSchedule は学習率スケジュールを設定する
	SetLearningRateSchedule(schedule string)
}

AdaptiveLearning は学習率を動的に調整できるモデルのインターフェース

type BaseEstimator

type BaseEstimator struct {
	// State はモデルの学習状態を保持します。gobでエンコードするために公開されています。
	State EstimatorState
	// contains filtered or unexported fields
}

BaseEstimator は全てのモデルの基底となる構造体

Example

ExampleBaseEstimator demonstrates BaseEstimator state management

package main

import (
	"fmt"

	"github.com/YuminosukeSato/scigo/core/model"
)

func main() {
	// Create a BaseEstimator (typically embedded in actual models)
	estimator := &model.BaseEstimator{}

	// Check initial state
	fmt.Printf("Initially fitted: %t\n", estimator.IsFitted())

	// Mark as fitted
	estimator.SetFitted()
	fmt.Printf("After SetFitted: %t\n", estimator.IsFitted())

	// Reset to unfitted state
	estimator.Reset()
	fmt.Printf("After Reset: %t\n", estimator.IsFitted())

}
Output:

Initially fitted: false
After SetFitted: true
After Reset: false
Example (WorkflowPattern)

ExampleBaseEstimator_workflowPattern demonstrates typical usage pattern

package main

import (
	"fmt"

	"github.com/YuminosukeSato/scigo/core/model"
)

func main() {
	// This example shows how BaseEstimator is typically used in models
	type MyModel struct {
		model.BaseEstimator
		// model-specific fields would go here
	}

	myModel := &MyModel{}

	// Check if model needs training
	if !myModel.IsFitted() {
		fmt.Println("Model needs training")

		// Simulate training process
		// ... training logic would go here ...

		// Mark as fitted after successful training
		myModel.SetFitted()
		fmt.Println("Model trained successfully")
	}

	// Now model is ready for use
	if myModel.IsFitted() {
		fmt.Println("Model is ready for predictions")
	}

}
Output:

Model needs training
Model trained successfully
Model is ready for predictions

func (*BaseEstimator) GetLogger

func (e *BaseEstimator) GetLogger() interface{}

GetLogger returns the logger for this estimator. Returns nil if no logger has been set.

Returns:

  • interface{}: The logger instance, should be type-asserted to log.Logger

Example:

if logger := model.GetLogger(); logger != nil {
    if l, ok := logger.(log.Logger); ok {
        l.Info("Operation completed")
    }
}

func (*BaseEstimator) IsFitted

func (e *BaseEstimator) IsFitted() bool

IsFitted returns whether the model has been fitted with training data.

This method checks the internal state to determine if the estimator has been trained and is ready for prediction or transformation. All models must be fitted before they can be used for predictions.

Returns:

  • bool: true if the model is fitted, false otherwise

Example:

if !model.IsFitted() {
    err := model.Fit(X, y)
    if err != nil {
        log.Fatal(err)
    }
}
predictions, err := model.Predict(X_test)

func (*BaseEstimator) LogDebug

func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})

LogDebug logs a debug-level message if a logger is configured. This is a convenience method for debug logging in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs

func (*BaseEstimator) LogError

func (e *BaseEstimator) LogError(msg string, fields ...interface{})

LogError logs an error-level message if a logger is configured. This is a convenience method for error logging in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs If the first field is an error, it will be handled specially

func (*BaseEstimator) LogInfo

func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})

LogInfo logs an info-level message if a logger is configured. This is a convenience method to avoid repetitive nil checks in model implementations.

Parameters:

  • msg: The log message
  • fields: Optional structured logging fields as key-value pairs

func (*BaseEstimator) Reset

func (e *BaseEstimator) Reset()

Reset returns the estimator to its initial untrained state.

This method clears the fitted state, effectively making the model untrained. Useful for reusing model instances with different data or resetting after errors. After reset, the model must be fitted again before use.

Example:

model.Reset()
err := model.Fit(newTrainingData, newLabels)
if err != nil {
    log.Fatal(err)
}

func (*BaseEstimator) SetFitted

func (e *BaseEstimator) SetFitted()

SetFitted marks the estimator as fitted (trained).

This method is called internally by model implementations after successful training to indicate that the model is ready for predictions or transformations. Should only be called by model implementations, not by end users.

Example (within a model's Fit method):

func (m *MyModel) Fit(X, y mat.Matrix) error {
    // ... training logic ...
    m.SetFitted() // Mark as trained
    return nil
}

func (*BaseEstimator) SetLogger

func (e *BaseEstimator) SetLogger(logger interface{})

SetLogger sets the logger for this estimator. This method is typically called during model initialization to provide structured logging capabilities for ML operations.

Parameters:

  • logger: Any logger implementation (typically log.Logger interface)

Example:

import "github.com/YuminosukeSato/scigo/pkg/log"
model.SetLogger(log.GetLoggerWithName("LinearRegression"))

type Batch

type Batch struct {
	X mat.Matrix // 特徴量行列
	Y mat.Matrix // ターゲット行列
}

Batch はストリーミング学習用のデータバッチを表す

type BufferedStreaming

type BufferedStreaming interface {
	// SetBufferSize はストリーミングバッファのサイズを設定
	SetBufferSize(size int)

	// GetBufferSize は現在のバッファサイズを返す
	GetBufferSize() int

	// FlushBuffer はバッファを強制的にフラッシュ
	FlushBuffer() error
}

BufferedStreaming はバッファリング機能を持つストリーミングインターフェース

type Estimator

type Estimator interface {
	Fitter
	Predictor
}

Estimator は学習と予測の両方が可能なモデルのインターフェース

type EstimatorState

type EstimatorState int

EstimatorState はモデルの学習状態を表す

const (
	// NotFitted はモデルが未学習の状態
	NotFitted EstimatorState = iota
	// Fitted はモデルが学習済みの状態
	Fitted
)

type Fitter

type Fitter interface {
	// Fit はモデルを訓練データで学習させる
	Fit(X, y mat.Matrix) error
}

Fitter は学習可能なモデルのインターフェース

type IncrementalEstimator

type IncrementalEstimator interface {
	Estimator

	// PartialFit はミニバッチでモデルを逐次的に学習させる
	// classes は分類問題の場合に全クラスラベルを指定(最初の呼び出し時のみ必須)
	// 回帰問題の場合は nil を渡す
	PartialFit(X, y mat.Matrix, classes []int) error

	// NIterations は実行された学習イテレーション数を返す
	NIterations() int

	// WarmStart が有効かどうかを返す
	// true の場合、Fit 呼び出し時に既存のパラメータから学習を継続
	IsWarmStart() bool

	// SetWarmStart はウォームスタートの有効/無効を設定
	SetWarmStart(warmStart bool)
}

IncrementalEstimator はオンライン学習(逐次学習)可能なモデルのインターフェース scikit-learnのpartial_fit APIと互換性を持つ

type LinearModel

type LinearModel interface {
	// Weights は学習された重み(係数)を返す
	Weights() []float64
	// Intercept は学習された切片を返す
	Intercept() float64
	// Score はモデルの決定係数(R²)を計算する
	Score(X, y mat.Matrix) (float64, error)
}

LinearModel は線形モデルのインターフェース

type OnlineMetrics

type OnlineMetrics interface {
	// GetLoss は現在の損失値を返す
	GetLoss() float64

	// GetLossHistory は損失値の履歴を返す
	GetLossHistory() []float64

	// GetConverged は収束したかどうかを返す
	GetConverged() bool
}

OnlineMetrics はオンライン学習中のメトリクスを追跡するインターフェース

type ParallelStreaming

type ParallelStreaming interface {
	// SetWorkers はワーカー数を設定
	SetWorkers(n int)

	// GetWorkers は現在のワーカー数を返す
	GetWorkers() int

	// SetBatchParallelism はバッチ内並列処理の有効/無効を設定
	SetBatchParallelism(enabled bool)
}

ParallelStreaming は並列ストリーミング処理のインターフェース

type Predictor

type Predictor interface {
	// Predict は入力データに対する予測を行う
	Predict(X mat.Matrix) (mat.Matrix, error)
}

Predictor は予測可能なモデルのインターフェース

type SKLearnLinearRegressionParams

type SKLearnLinearRegressionParams struct {
	Coefficients []float64 `json:"coefficients"` // 係数(重み)
	Intercept    float64   `json:"intercept"`    // 切片
	NFeatures    int       `json:"n_features"`   // 特徴量の数
}

SKLearnLinearRegressionParams は線形回帰モデルのパラメータ

func LoadLinearRegressionParams

func LoadLinearRegressionParams(model *SKLearnModel) (*SKLearnLinearRegressionParams, error)

LoadLinearRegressionParams はLinearRegressionのパラメータを読み込む

パラメータ:

  • model: SKLearnModelインスタンス

戻り値:

  • *SKLearnLinearRegressionParams: パラメータ
  • error: パース失敗時のエラー

type SKLearnModel

type SKLearnModel struct {
	ModelSpec SKLearnModelSpec `json:"model_spec"`
	Params    json.RawMessage  `json:"params"`
}

SKLearnModel はscikit-learnからエクスポートされたモデル

func LoadSKLearnModelFromFile

func LoadSKLearnModelFromFile(filename string) (*SKLearnModel, error)

LoadSKLearnModelFromFile はファイルからscikit-learnモデルを読み込む

パラメータ:

  • filename: JSONファイルのパス

戻り値:

  • *SKLearnModel: 読み込まれたモデル
  • error: 読み込みエラー

使用例:

model, err := model.LoadSKLearnModelFromFile("sklearn_model.json")
if err != nil {
    log.Fatal(err)
}

func LoadSKLearnModelFromReader

func LoadSKLearnModelFromReader(r io.Reader) (*SKLearnModel, error)

LoadSKLearnModelFromReader はReaderからscikit-learnモデルを読み込む

パラメータ:

  • r: JSONデータを含むReader

戻り値:

  • *SKLearnModel: 読み込まれたモデル
  • error: 読み込みエラー

type SKLearnModelSpec

type SKLearnModelSpec struct {
	Name           string `json:"name"`                      // モデル名 (e.g., "LinearRegression")
	FormatVersion  string `json:"format_version"`            // フォーマットバージョン
	SKLearnVersion string `json:"sklearn_version,omitempty"` // scikit-learnのバージョン
}

SKLearnModelSpec はscikit-learnモデルのメタデータ

type StreamingEstimator

type StreamingEstimator interface {
	IncrementalEstimator

	// FitStream はデータストリームからモデルを学習する
	// コンテキストがキャンセルされるまで、またはチャネルがクローズされるまで学習を継続
	FitStream(ctx context.Context, dataChan <-chan *Batch) error

	// PredictStream は入力ストリームに対してリアルタイム予測を行う
	// 入力チャネルがクローズされると出力チャネルもクローズされる
	PredictStream(ctx context.Context, inputChan <-chan mat.Matrix) <-chan mat.Matrix

	// FitPredictStream は学習と予測を同時に行う
	// 新しいデータで学習しながら、同時に予測も返す(test-then-train方式)
	FitPredictStream(ctx context.Context, dataChan <-chan *Batch) <-chan mat.Matrix
}

StreamingEstimator はチャネルベースのストリーミング学習を提供するインターフェース

type StreamingMetrics

type StreamingMetrics interface {
	OnlineMetrics

	// GetThroughput は現在のスループット(サンプル/秒)を返す
	GetThroughput() float64

	// GetProcessedSamples は処理されたサンプル総数を返す
	GetProcessedSamples() int64

	// GetAverageLatency は平均レイテンシ(ミリ秒)を返す
	GetAverageLatency() float64

	// GetMemoryUsage は現在のメモリ使用量(バイト)を返す
	GetMemoryUsage() int64
}

StreamingMetrics はストリーミング学習中のメトリクスを提供

type Transformer

type Transformer interface {
	// Fit は変換に必要なパラメータを学習する
	Fit(X mat.Matrix) error

	// Transform はデータを変換する
	Transform(X mat.Matrix) (mat.Matrix, error)

	// FitTransform はFitとTransformを同時に実行する
	FitTransform(X mat.Matrix) (mat.Matrix, error)
}

Transformer はデータ変換のインターフェース

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL