Documentation
¶
Overview ¶
Package model provides core abstractions and interfaces for machine learning models.
This package defines the fundamental building blocks for machine learning estimators in the SciGo library, including:
- BaseEstimator: Core estimator with state management and serialization support
- Model persistence: Save and load trained models using Go's encoding/gob
- scikit-learn compatibility: Import/export models from Python scikit-learn
- Streaming interfaces: Support for online learning and incremental training
The BaseEstimator provides a consistent foundation for all ML algorithms with:
- Fitted state tracking to prevent usage of untrained models
- Serialization support for model persistence
- Thread-safe state management
- Integration with preprocessing pipelines
Example usage:
type MyModel struct {
model.BaseEstimator
// model-specific fields
}
func (m *MyModel) Fit(X, y mat.Matrix) error {
// training logic
m.SetFitted() // mark as trained
return nil
}
All models in SciGo embed BaseEstimator to ensure consistent behavior across the entire machine learning pipeline.
Index ¶
- func ExportSKLearnModel(modelName string, params interface{}, w io.Writer) error
- func LoadModel(model interface{}, filename string) error
- func LoadModelFromReader(model interface{}, r io.Reader) error
- func SaveModel(model interface{}, filename string) error
- func SaveModelToWriter(model interface{}, w io.Writer) error
- type AdaptiveLearning
- type BaseEstimator
- func (e *BaseEstimator) GetLogger() interface{}
- func (e *BaseEstimator) IsFitted() bool
- func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})
- func (e *BaseEstimator) LogError(msg string, fields ...interface{})
- func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})
- func (e *BaseEstimator) Reset()
- func (e *BaseEstimator) SetFitted()
- func (e *BaseEstimator) SetLogger(logger interface{})
- type Batch
- type BufferedStreaming
- type Estimator
- type EstimatorState
- type Fitter
- type IncrementalEstimator
- type LinearModel
- type OnlineMetrics
- type ParallelStreaming
- type Predictor
- type SKLearnLinearRegressionParams
- type SKLearnModel
- type SKLearnModelSpec
- type StreamingEstimator
- type StreamingMetrics
- type Transformer
Examples ¶
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func ExportSKLearnModel ¶
ExportSKLearnModel はモデルをscikit-learn互換のJSON形式でエクスポート
パラメータ:
- modelName: モデル名
- params: モデルパラメータ
- w: 出力先Writer
戻り値:
- error: エクスポート失敗時のエラー
func LoadModel ¶
LoadModel はファイルからモデルを読み込む
パラメータ:
- model: 読み込み先のモデル(BaseEstimatorを埋め込んだ構造体のポインタ)
- filename: 読み込み元のファイルパス
戻り値:
- error: 読み込みに失敗した場合のエラー
使用例:
var reg linear.Regression err := model.LoadModel(®, "model.gob")
func LoadModelFromReader ¶
LoadModelFromReader はio.Readerからモデルを読み込む
パラメータ:
- model: 読み込み先のモデル(ポインタ)
- r: 読み込み元のReader
戻り値:
- error: 読み込みに失敗した場合のエラー
func SaveModel ¶
SaveModel はモデルをファイルに保存する
パラメータ:
- model: 保存するモデル(BaseEstimatorを埋め込んだ構造体)
- filename: 保存先のファイルパス
戻り値:
- error: 保存に失敗した場合のエラー
使用例:
var reg linear.Regression // ... モデルの学習 ... err := model.SaveModel(®, "model.gob")
func SaveModelToWriter ¶
SaveModelToWriter はモデルをio.Writerに保存する
パラメータ:
- model: 保存するモデル
- w: 保存先のWriter
戻り値:
- error: 保存に失敗した場合のエラー
Types ¶
type AdaptiveLearning ¶
type AdaptiveLearning interface {
// GetLearningRate は現在の学習率を返す
GetLearningRate() float64
// SetLearningRate は学習率を設定する
SetLearningRate(lr float64)
// GetLearningRateSchedule は学習率スケジュールを返す
// "constant", "optimal", "invscaling", "adaptive" など
GetLearningRateSchedule() string
// SetLearningRateSchedule は学習率スケジュールを設定する
SetLearningRateSchedule(schedule string)
}
AdaptiveLearning は学習率を動的に調整できるモデルのインターフェース
type BaseEstimator ¶
type BaseEstimator struct {
// State はモデルの学習状態を保持します。gobでエンコードするために公開されています。
State EstimatorState
// contains filtered or unexported fields
}
BaseEstimator は全てのモデルの基底となる構造体
Example ¶
ExampleBaseEstimator demonstrates BaseEstimator state management
package main
import (
"fmt"
"github.com/YuminosukeSato/scigo/core/model"
)
func main() {
// Create a BaseEstimator (typically embedded in actual models)
estimator := &model.BaseEstimator{}
// Check initial state
fmt.Printf("Initially fitted: %t\n", estimator.IsFitted())
// Mark as fitted
estimator.SetFitted()
fmt.Printf("After SetFitted: %t\n", estimator.IsFitted())
// Reset to unfitted state
estimator.Reset()
fmt.Printf("After Reset: %t\n", estimator.IsFitted())
}
Output: Initially fitted: false After SetFitted: true After Reset: false
Example (WorkflowPattern) ¶
ExampleBaseEstimator_workflowPattern demonstrates typical usage pattern
package main
import (
"fmt"
"github.com/YuminosukeSato/scigo/core/model"
)
func main() {
// This example shows how BaseEstimator is typically used in models
type MyModel struct {
model.BaseEstimator
// model-specific fields would go here
}
myModel := &MyModel{}
// Check if model needs training
if !myModel.IsFitted() {
fmt.Println("Model needs training")
// Simulate training process
// ... training logic would go here ...
// Mark as fitted after successful training
myModel.SetFitted()
fmt.Println("Model trained successfully")
}
// Now model is ready for use
if myModel.IsFitted() {
fmt.Println("Model is ready for predictions")
}
}
Output: Model needs training Model trained successfully Model is ready for predictions
func (*BaseEstimator) GetLogger ¶
func (e *BaseEstimator) GetLogger() interface{}
GetLogger returns the logger for this estimator. Returns nil if no logger has been set.
Returns:
- interface{}: The logger instance, should be type-asserted to log.Logger
Example:
if logger := model.GetLogger(); logger != nil {
if l, ok := logger.(log.Logger); ok {
l.Info("Operation completed")
}
}
func (*BaseEstimator) IsFitted ¶
func (e *BaseEstimator) IsFitted() bool
IsFitted returns whether the model has been fitted with training data.
This method checks the internal state to determine if the estimator has been trained and is ready for prediction or transformation. All models must be fitted before they can be used for predictions.
Returns:
- bool: true if the model is fitted, false otherwise
Example:
if !model.IsFitted() {
err := model.Fit(X, y)
if err != nil {
log.Fatal(err)
}
}
predictions, err := model.Predict(X_test)
func (*BaseEstimator) LogDebug ¶
func (e *BaseEstimator) LogDebug(msg string, fields ...interface{})
LogDebug logs a debug-level message if a logger is configured. This is a convenience method for debug logging in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs
func (*BaseEstimator) LogError ¶
func (e *BaseEstimator) LogError(msg string, fields ...interface{})
LogError logs an error-level message if a logger is configured. This is a convenience method for error logging in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs If the first field is an error, it will be handled specially
func (*BaseEstimator) LogInfo ¶
func (e *BaseEstimator) LogInfo(msg string, fields ...interface{})
LogInfo logs an info-level message if a logger is configured. This is a convenience method to avoid repetitive nil checks in model implementations.
Parameters:
- msg: The log message
- fields: Optional structured logging fields as key-value pairs
func (*BaseEstimator) Reset ¶
func (e *BaseEstimator) Reset()
Reset returns the estimator to its initial untrained state.
This method clears the fitted state, effectively making the model untrained. Useful for reusing model instances with different data or resetting after errors. After reset, the model must be fitted again before use.
Example:
model.Reset()
err := model.Fit(newTrainingData, newLabels)
if err != nil {
log.Fatal(err)
}
func (*BaseEstimator) SetFitted ¶
func (e *BaseEstimator) SetFitted()
SetFitted marks the estimator as fitted (trained).
This method is called internally by model implementations after successful training to indicate that the model is ready for predictions or transformations. Should only be called by model implementations, not by end users.
Example (within a model's Fit method):
func (m *MyModel) Fit(X, y mat.Matrix) error {
// ... training logic ...
m.SetFitted() // Mark as trained
return nil
}
func (*BaseEstimator) SetLogger ¶
func (e *BaseEstimator) SetLogger(logger interface{})
SetLogger sets the logger for this estimator. This method is typically called during model initialization to provide structured logging capabilities for ML operations.
Parameters:
- logger: Any logger implementation (typically log.Logger interface)
Example:
import "github.com/YuminosukeSato/scigo/pkg/log"
model.SetLogger(log.GetLoggerWithName("LinearRegression"))
type BufferedStreaming ¶
type BufferedStreaming interface {
// SetBufferSize はストリーミングバッファのサイズを設定
SetBufferSize(size int)
// GetBufferSize は現在のバッファサイズを返す
GetBufferSize() int
// FlushBuffer はバッファを強制的にフラッシュ
FlushBuffer() error
}
BufferedStreaming はバッファリング機能を持つストリーミングインターフェース
type EstimatorState ¶
type EstimatorState int
EstimatorState はモデルの学習状態を表す
const ( // NotFitted はモデルが未学習の状態 NotFitted EstimatorState = iota // Fitted はモデルが学習済みの状態 Fitted )
type IncrementalEstimator ¶
type IncrementalEstimator interface {
Estimator
// PartialFit はミニバッチでモデルを逐次的に学習させる
// classes は分類問題の場合に全クラスラベルを指定(最初の呼び出し時のみ必須)
// 回帰問題の場合は nil を渡す
PartialFit(X, y mat.Matrix, classes []int) error
// NIterations は実行された学習イテレーション数を返す
NIterations() int
// WarmStart が有効かどうかを返す
// true の場合、Fit 呼び出し時に既存のパラメータから学習を継続
IsWarmStart() bool
// SetWarmStart はウォームスタートの有効/無効を設定
SetWarmStart(warmStart bool)
}
IncrementalEstimator はオンライン学習(逐次学習)可能なモデルのインターフェース scikit-learnのpartial_fit APIと互換性を持つ
type LinearModel ¶
type LinearModel interface {
// Weights は学習された重み(係数)を返す
Weights() []float64
// Intercept は学習された切片を返す
Intercept() float64
// Score はモデルの決定係数(R²)を計算する
Score(X, y mat.Matrix) (float64, error)
}
LinearModel は線形モデルのインターフェース
type OnlineMetrics ¶
type OnlineMetrics interface {
// GetLoss は現在の損失値を返す
GetLoss() float64
// GetLossHistory は損失値の履歴を返す
GetLossHistory() []float64
// GetConverged は収束したかどうかを返す
GetConverged() bool
}
OnlineMetrics はオンライン学習中のメトリクスを追跡するインターフェース
type ParallelStreaming ¶
type ParallelStreaming interface {
// SetWorkers はワーカー数を設定
SetWorkers(n int)
// GetWorkers は現在のワーカー数を返す
GetWorkers() int
// SetBatchParallelism はバッチ内並列処理の有効/無効を設定
SetBatchParallelism(enabled bool)
}
ParallelStreaming は並列ストリーミング処理のインターフェース
type SKLearnLinearRegressionParams ¶
type SKLearnLinearRegressionParams struct {
Coefficients []float64 `json:"coefficients"` // 係数(重み)
Intercept float64 `json:"intercept"` // 切片
NFeatures int `json:"n_features"` // 特徴量の数
}
SKLearnLinearRegressionParams は線形回帰モデルのパラメータ
func LoadLinearRegressionParams ¶
func LoadLinearRegressionParams(model *SKLearnModel) (*SKLearnLinearRegressionParams, error)
LoadLinearRegressionParams はLinearRegressionのパラメータを読み込む
パラメータ:
- model: SKLearnModelインスタンス
戻り値:
- *SKLearnLinearRegressionParams: パラメータ
- error: パース失敗時のエラー
type SKLearnModel ¶
type SKLearnModel struct {
ModelSpec SKLearnModelSpec `json:"model_spec"`
Params json.RawMessage `json:"params"`
}
SKLearnModel はscikit-learnからエクスポートされたモデル
func LoadSKLearnModelFromFile ¶
func LoadSKLearnModelFromFile(filename string) (*SKLearnModel, error)
LoadSKLearnModelFromFile はファイルからscikit-learnモデルを読み込む
パラメータ:
- filename: JSONファイルのパス
戻り値:
- *SKLearnModel: 読み込まれたモデル
- error: 読み込みエラー
使用例:
model, err := model.LoadSKLearnModelFromFile("sklearn_model.json")
if err != nil {
log.Fatal(err)
}
func LoadSKLearnModelFromReader ¶
func LoadSKLearnModelFromReader(r io.Reader) (*SKLearnModel, error)
LoadSKLearnModelFromReader はReaderからscikit-learnモデルを読み込む
パラメータ:
- r: JSONデータを含むReader
戻り値:
- *SKLearnModel: 読み込まれたモデル
- error: 読み込みエラー
type SKLearnModelSpec ¶
type SKLearnModelSpec struct {
Name string `json:"name"` // モデル名 (e.g., "LinearRegression")
FormatVersion string `json:"format_version"` // フォーマットバージョン
SKLearnVersion string `json:"sklearn_version,omitempty"` // scikit-learnのバージョン
}
SKLearnModelSpec はscikit-learnモデルのメタデータ
type StreamingEstimator ¶
type StreamingEstimator interface {
IncrementalEstimator
// FitStream はデータストリームからモデルを学習する
// コンテキストがキャンセルされるまで、またはチャネルがクローズされるまで学習を継続
FitStream(ctx context.Context, dataChan <-chan *Batch) error
// PredictStream は入力ストリームに対してリアルタイム予測を行う
// 入力チャネルがクローズされると出力チャネルもクローズされる
PredictStream(ctx context.Context, inputChan <-chan mat.Matrix) <-chan mat.Matrix
// FitPredictStream は学習と予測を同時に行う
// 新しいデータで学習しながら、同時に予測も返す(test-then-train方式)
FitPredictStream(ctx context.Context, dataChan <-chan *Batch) <-chan mat.Matrix
}
StreamingEstimator はチャネルベースのストリーミング学習を提供するインターフェース
type StreamingMetrics ¶
type StreamingMetrics interface {
OnlineMetrics
// GetThroughput は現在のスループット(サンプル/秒)を返す
GetThroughput() float64
// GetProcessedSamples は処理されたサンプル総数を返す
GetProcessedSamples() int64
// GetAverageLatency は平均レイテンシ(ミリ秒)を返す
GetAverageLatency() float64
// GetMemoryUsage は現在のメモリ使用量(バイト)を返す
GetMemoryUsage() int64
}
StreamingMetrics はストリーミング学習中のメトリクスを提供