LOOM WASM Module
WebAssembly bindings for the LOOM neural network framework, enabling neural network creation, training, and transformer inference directly in the browser with zero dependencies.
Quick Start
cd wasm
./build_wasm.sh # Build main.wasm
python3 -m http.server 8080
# Open http://localhost:8080/grid_scatter_demo.html
NEW: Simple API
Streamlined functions for common operations with cross-platform consistency:
// Create network from JSON
const config = {
batch_size: 1,
grid_rows: 1,
grid_cols: 3,
layers_per_cell: 1,
layers: [
{ type: "dense", input_size: 8, output_size: 16, activation: "relu" },
{
type: "parallel",
combine_mode: "grid_scatter",
grid_output_rows: 3,
grid_output_cols: 1,
grid_output_layers: 1,
grid_positions: [
{ branch_index: 0, target_row: 0, target_col: 0, target_layer: 0 },
{ branch_index: 1, target_row: 1, target_col: 0, target_layer: 0 },
{ branch_index: 2, target_row: 2, target_col: 0, target_layer: 0 },
],
branches: [
{
type: "parallel",
combine_mode: "add",
branches: [
{
type: "dense",
input_size: 16,
output_size: 8,
activation: "relu",
},
{
type: "dense",
input_size: 16,
output_size: 8,
activation: "gelu",
},
],
},
{ type: "lstm", input_size: 16, hidden_size: 8, seq_length: 1 },
{ type: "rnn", input_size: 16, hidden_size: 8, seq_length: 1 },
],
},
{ type: "dense", input_size: 24, output_size: 2, activation: "sigmoid" },
],
};
const network = createNetworkFromJSON(JSON.stringify(config));
// Training
const batches = [
{ Input: [0.2, 0.2, 0.2, 0.2, 0.8, 0.8, 0.8, 0.8], Target: [1.0, 0.0] },
{ Input: [0.9, 0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 0.1], Target: [0.0, 1.0] },
{ Input: [0.7, 0.7, 0.7, 0.7, 0.3, 0.3, 0.3, 0.3], Target: [0.0, 1.0] },
{ Input: [0.3, 0.3, 0.3, 0.3, 0.7, 0.7, 0.7, 0.7], Target: [1.0, 0.0] },
];
const trainingConfig = {
Epochs: 800,
LearningRate: 0.15,
UseGPU: false,
PrintEveryBatch: 0,
GradientClip: 1.0,
LossType: "mse",
Verbose: false,
};
const [result, error] = network.Train(
JSON.stringify([batches, trainingConfig])
);
console.log("Training complete! Final loss:", JSON.parse(result).FinalLoss);
// Forward pass
const [output] = network.ForwardCPU(
JSON.stringify([[0.2, 0.2, 0.2, 0.2, 0.8, 0.8, 0.8, 0.8]])
);
console.log("Output:", JSON.parse(output)); // [0.950, 0.050]
// Evaluate network
const inputs = batches.map((b) => b.Input);
const expected = [0, 1, 1, 0];
const [metrics] = network.EvaluateNetwork(JSON.stringify([inputs, expected]));
const metricsData = JSON.parse(metrics);
console.log(
`Quality Score: ${metricsData.score}/100, Avg Deviation: ${metricsData.avg_deviation}%`
);
// Save/Load
const [modelJSON] = network.SaveModelToString(JSON.stringify(["my_model"]));
console.log(`Model saved (${modelJSON.length} bytes)`);
// Load model
const loadedNetwork = loadLoomNetwork(modelJSON, "my_model");
const [output2] = loadedNetwork.ForwardCPU(
JSON.stringify([[0.2, 0.2, 0.2, 0.2, 0.8, 0.8, 0.8, 0.8]])
);
// output2 === output (bit-for-bit identical!)
Simple API Functions:
- `createNetworkFromJSON(jsonConfig)` - Create from JSON configuration
- `loadLoomNetwork(jsonString, modelID)` - Load saved model
- `network.ForwardCPU(inputJSON)` - Forward pass
- `network.BackwardCPU(gradientsJSON)` - Backward pass
- `network.Train(batchesJSON)` - Train network
- `network.SaveModelToString(idJSON)` - Save to JSON string
- `network.EvaluateNetwork(inputsJSON)` - Evaluate with metrics
- `network.UpdateWeights(lrJSON)` - Update weights
Cross-Platform Consistency: The simple API matches Python, TypeScript, C#, and C - identical behavior and results!
See grid_scatter_demo.html and grid_scatter_demo.js for complete working examples.
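Every Simple API call returns a `[result, error]` pair, as in the examples above. A minimal sketch of checking the pair before use (assuming `error` is falsy on success; the exact error shape isn't specified here):

```javascript
// Hedged sketch: guard against a failed call before parsing the result.
const [trainResult, trainError] = network.Train(
  JSON.stringify([batches, trainingConfig])
);
if (trainError) {
  console.error("Training failed:", trainError);
} else {
  const { FinalLoss } = JSON.parse(trainResult);
  console.log("Final loss:", FinalLoss);
}
```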
Stepping API - Fine-Grained Execution Control
NEW: Execute networks one step at a time for online learning in the browser:
// Create network
const config = {
  batch_size: 1,
  layers: [
    { type: "dense", input_height: 4, output_height: 8, activation: "relu" },
    { type: "lstm", input_size: 8, hidden_size: 12, seq_length: 1 },
    { type: "dense", input_height: 12, output_height: 3, activation: "softmax" },
  ],
};
const network = createLoomNetwork(JSON.stringify(config));
// Initialize stepping state
const state = network.createStepState(4);
// Training loop - update weights after EACH step
const target = [1, 0, 0]; // example target for the 3-class softmax output
const learningRate = 0.01; // example learning rate
for (let step = 0; step < 100000; step++) {
  state.setInput(new Float32Array([0.1, 0.2, 0.1, 0.3]));
  state.stepForward();
  const output = state.getOutput();
  // Calculate gradients (output minus target)
  const gradients = new Float32Array(output.length);
  for (let i = 0; i < output.length; i++)
    gradients[i] = output[i] - target[i];
  // Backward pass
  state.stepBackward(gradients);
  // Update weights immediately
  network.ApplyGradients(JSON.stringify([learningRate]));
}
Stepping API:
- `network.createStepState(inputSize)` - Initialize stepping state
- `state.setInput(data)` - Set input for current step
- `state.stepForward()` - Execute forward pass
- `state.getOutput()` - Get output from last layer
- `state.stepBackward(gradients)` - Execute backward pass
- `network.ApplyGradients(paramsJSON)` - Update network weights
See step_example.html for a complete interactive example achieving 100% accuracy.
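Since the example network ends in a three-way softmax, a small plain-JavaScript helper (no LOOM API involved) turns the `Float32Array` from `state.getOutput()` into a predicted class index:

```javascript
// Plain JS helper: index of the largest value in the output array.
function argmax(output) {
  let best = 0;
  for (let i = 1; i < output.length; i++) {
    if (output[i] > output[best]) best = i;
  }
  return best;
}

const predictedClass = argmax(state.getOutput());
```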
Neural Tweening API - Gradient-Free Learning
NEW: Use Neural Tweening for adaptive learning without explicit backpropagation:
// Create network from JSON config
const config = { batch_size: 1, layers: [
{ type: "dense", input_size: 8, output_size: 32, activation: "leaky_relu" },
{ type: "dense", input_size: 32, output_size: 4, activation: "sigmoid" }
]};
const network = createLoomNetwork(JSON.stringify(config));
// Create tween state (true = use chain rule for better gradients)
const tweenState = network.createTweenState(true);
// Training loop
for (let step = 0; step < 10000; step++) {
const input = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);
const targetClass = 2; // Target output neuron
const outputSize = 4;
const learningRate = 0.02;
// Single call does forward, backward, and weight update
const loss = tweenState.TweenStep(input, targetClass, outputSize, learningRate);
}
Tweening API:
- `network.createTweenState(useChainRule)` - Create tween state (chainRule=true recommended)
- `tweenState.TweenStep(input, targetClass, outputSize, lr)` - Complete learning step, returns loss
- `tweenState.setChainRule(enabled)` - Enable/disable chain rule mode
- `tweenState.getChainRule()` - Check current chain rule setting
- `tweenState.getTweenSteps()` - Get number of tween steps performed
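Because `TweenStep` returns the loss for each step, convergence can be monitored without any extra API. A minimal sketch smoothing it with an exponential moving average (the 0.99 smoothing factor is an arbitrary choice, not a LOOM parameter):

```javascript
// Track an exponential moving average of the per-step tween loss.
const input = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);
const targetClass = 2;
const outputSize = 4;
const learningRate = 0.02;
let emaLoss = null;
for (let step = 0; step < 10000; step++) {
  const loss = tweenState.TweenStep(input, targetClass, outputSize, learningRate);
  emaLoss = emaLoss === null ? loss : 0.99 * emaLoss + 0.01 * loss;
  if (step % 1000 === 0) console.log(`step ${step}: EMA loss = ${emaLoss.toFixed(4)}`);
}
```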
Adaptation Tracker API - Benchmark Mid-Stream Task Changes
Track accuracy over time with scheduled task changes (like Test 18):
// Create tracker: 1-second windows, 10-second total duration
const tracker = createAdaptationTracker(1000, 10000);
tracker.setModelInfo("Dense-5L", "TweenChain");
// Schedule task changes at 1/3 and 2/3 of duration
tracker.scheduleTaskChange(3333, 1, "AVOID"); // Switch to task 1 at 3.3s
tracker.scheduleTaskChange(6666, 0, "CHASE"); // Switch back to task 0 at 6.6s
// Start tracking
tracker.start("CHASE", 0);
// During training loop
const currentTask = tracker.getCurrentTask(); // Returns 0 or 1
// ... do inference and training ...
tracker.recordOutput(isCorrect); // Record each prediction
// Get results
const resultsJSON = tracker.finalize();
const results = JSON.parse(resultsJSON);
console.log(`Avg Accuracy: ${results.avg_accuracy}%`);
console.log(`Task Changes:`, results.task_changes);
Tracker API:
- `createAdaptationTracker(windowMs, totalMs)` - Create tracker
- `tracker.setModelInfo(modelName, modeName)` - Set descriptive names
- `tracker.scheduleTaskChange(atMs, taskID, taskName)` - Schedule task switch
- `tracker.start(initialTask, initialTaskID)` - Begin tracking
- `tracker.getCurrentTask()` - Get current task ID (handles scheduled changes)
- `tracker.recordOutput(isCorrect)` - Record prediction result
- `tracker.finalize()` - Get JSON results with windows and task change metrics
See adaptation_demo.html for a complete architecture adaptation benchmark.
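A sketch of how the tracker slots into a training loop; `runInference` and `trainOnExample` are hypothetical placeholders for your own model code, not LOOM functions:

```javascript
// Hedged sketch: drive the loop off the tracker's scheduled task changes.
tracker.start("CHASE", 0);
const deadline = Date.now() + 10000; // matches the 10-second total duration
while (Date.now() < deadline) {
  const taskID = tracker.getCurrentTask(); // 0 or 1, honoring scheduled changes
  const { prediction, target } = runInference(taskID); // hypothetical
  tracker.recordOutput(prediction === target);
  trainOnExample(taskID); // hypothetical
}
const results = JSON.parse(tracker.finalize());
console.log(`Avg Accuracy: ${results.avg_accuracy}%`);
```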
What's New: Dynamic Method Exposure
Every Network method automatically available in JavaScript!
The WASM wrapper uses Go reflection to dynamically expose ALL nn.Network methods - no manual bindings required!
// Create network from JSON
const network = createLoomNetwork(jsonConfig);
// All 27+ methods automatically available:
network.ForwardCPU(inputJSON);
network.BackwardCPU(gradientsJSON);
network.Train(batchesJSON);
network.SaveModelToString(idJSON);
network.GetMethodsJSON();
// ... and 20+ more!
Features
- ✅ Zero Manual Bindings: All Network methods auto-exposed via reflection
- ✅ 27+ Methods: Complete API including Train, Forward, Backward, SaveModel, LoadModel
- ✅ JSON-Based Network Creation: Build networks from JSON config (no Go code needed)
- ✅ Full Training Support: `Train(batches, config)` with automatic gradient computation
- ✅ All Layer Types: Dense, Conv2D, LSTM, RNN, MHA, Parallel, Grid Scatter, Softmax (10+ variants)
- ✅ Grid Scatter Demo: Multi-agent heterogeneous neural networks training in browser
- ✅ Runtime Introspection: Query available methods and signatures
- ✅ Type Conversion: Automatic JavaScript → Go type conversion
- ✅ 6.4MB Binary: Complete framework in single WASM module
- ✅ Pure CPU: All operations (GPU via WebGPU coming soon)
API Overview
Creating Networks from JSON
// Define network architecture
const config = {
input_size: 10,
batch_size: 1,
grid_rows: 1,
grid_cols: 1,
layers_per_cell: 3,
layers: [
{
type: "dense",
output_size: 8,
activation: "relu",
},
{
type: "dense",
output_size: 4,
activation: "relu",
},
{
type: "dense",
output_size: 2,
activation: "sigmoid",
},
],
};
// Create network (returns object with ALL methods)
const network = createLoomNetwork(JSON.stringify(config));
Training Networks
// Prepare training data
const batches = [
{
Input: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
Target: [1.0, 0.0],
},
// ... more samples
];
// Training configuration
const config = {
Epochs: 100,
LearningRate: 0.01,
UseGPU: false,
PrintEveryBatch: 0,
GradientClip: 1.0,
LossType: "mse",
Verbose: false,
};
// Train the network
const result = JSON.parse(network.Train(JSON.stringify([batches, config])));
console.log("Initial Loss:", result[0].LossHistory[0]);
console.log("Final Loss:", result[0].FinalLoss);
console.log(
"Improvement:",
(
((result[0].LossHistory[0] - result[0].FinalLoss) /
result[0].LossHistory[0]) *
100
).toFixed(2) + "%"
);
Forward Pass
// Run inference
const input = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
const result = JSON.parse(network.ForwardCPU(JSON.stringify([input])));
const output = result[0]; // Output array
const duration = result[1]; // Execution time (nanoseconds)
console.log("Predictions:", output);
Method Introspection
// List all available methods
const methods = JSON.parse(network.ListMethods())[0];
console.log("Available methods:", methods);
// Output: ["Activations", "BackwardCPU", "ForwardCPU", "GetMethodsJSON",
// "InitializeWeights", "ListMethods", "SaveModelToString", "Train", ...]
// Get detailed method information
const methodsJSON = JSON.parse(network.GetMethodsJSON())[0];
const parsedMethods = JSON.parse(methodsJSON);
parsedMethods.forEach((method) => {
console.log(
`${method.method_name}(${method.parameters
.map((p) => p.type)
.join(", ")}) -> ${method.returns.join(", ")}`
);
});
Save/Load Models
// Save model to JSON string
const modelJSON = JSON.parse(
network.SaveModelToString(JSON.stringify(["my_model"]))
)[0];
// Store in localStorage
localStorage.setItem("loom_model", modelJSON);
// Load model later
const savedModel = localStorage.getItem("loom_model");
const loadedNetwork = createLoomNetwork(savedModel);
// Use loaded network immediately
const output = JSON.parse(loadedNetwork.ForwardCPU(JSON.stringify([input])))[0];
Interactive Demos
test.html - Complete Neural Network Demo
Features:
- Beautiful gradient UI with multiple test sections
- JSON config editor for network architecture
- Training demo with pattern recognition
- Grid Scatter multi-agent training
- Real-time loss tracking and predictions
- Method introspection and network info
Test 1: Grid Scatter - Multi-Agent Coordination
- 3 heterogeneous agents (Feature Extractor, LSTM, RNN)
- Binary classification task
- 800 epochs training in ~0.4 seconds
- 99.5% improvement (0.25 → 0.001 loss)
- 100% classification accuracy
Test 2: Pattern Recognition Training
- Learns to classify patterns: high values in first half vs second half
- Configurable epochs, learning rate, and sample count
- Real-time loss display
- Post-training validation
Try it:
cd wasm
python3 -m http.server 8080
# Open http://localhost:8080/test.html
Example Output: Grid Scatter Training
Running Grid Scatter Multi-Agent Training...
Task: 3 agents learn to collaborate for binary classification
Architecture:
Shared Layer → Grid Scatter (3 agents) → Decision
Agent 0: Feature Extractor (ensemble of 2 dense)
Agent 1: Transformer (LSTM)
Agent 2: Integrator (RNN)
✅ Training complete!
Training time: 0.36 seconds
Initial Loss: 0.252357
Final Loss: 0.001175
Improvement: 99.53%
Final predictions:
Sample 0: [0.989, 0.011] → Class 0 (expected 0) ✓
Sample 1: [0.023, 0.977] → Class 1 (expected 1) ✓
Sample 2: [0.049, 0.951] → Class 1 (expected 1) ✓
Sample 3: [0.960, 0.040] → Class 0 (expected 0) ✓
Architecture
Dynamic Method Wrapping
The WASM module (wasm/main.go) uses reflection to expose methods:
- Network Creation: `createLoomNetwork(jsonConfig)` builds network from JSON
- Method Discovery: Uses `reflect.ValueOf(network)` to find all methods
- Dynamic Wrapping: Each method wrapped with `js.FuncOf(methodWrapper)`
- Type Conversion: Automatic JSON → Go type conversion
- Result Serialization: All results returned as JSON arrays
// Pseudo-code of the wrapper
func createNetworkFromJSON(jsonConfig string) js.Value {
network := nn.BuildNetworkFromJSON(jsonConfig)
networkObj := js.Global().Get("Object").New()
// Wrap EVERY method dynamically
networkValue := reflect.ValueOf(network)
for i := 0; i < networkValue.NumMethod(); i++ {
method := networkValue.Type().Method(i)
networkObj.Set(method.Name, js.FuncOf(methodWrapper(network, method.Name)))
}
return networkObj
}
Key Components
- `main.go`: WASM entry point, reflection-based method exposure
- `grid_scatter_demo.js`: Multi-agent training demo
- `test.html`: Interactive UI with all demos
- `build_wasm.sh`: Build script
- `nn/introspection.go`: Method discovery and signature extraction
Available Network Methods
All 27+ methods automatically exposed:
Core Operations:
- `ForwardCPU([]float32)` - Forward pass
- `BackwardCPU([]float32)` - Backward pass
- `UpdateWeights(float32)` - Update weights with learning rate
- `Train([]TrainingBatch, *TrainingConfig)` - Full training loop
Model Management:
- `SaveModelToString(string)` - Export model as JSON
- `LoadModelFromString(string, string)` - Import model from JSON
- `InitializeWeights()` - Initialize random weights
- `ResetState()` - Reset RNN/LSTM hidden states (see the sketch after this list)
Introspection:
- `GetMethodsJSON()` - Get all method signatures
- `ListMethods()` - Get method names only
- `GetMethodSignature(string)` - Get specific method signature
- `HasMethod(string)` - Check if method exists
Layer Operations:
- `GetLayer(int, int, int)` - Get layer config
- `SetLayer(int, int, int, LayerConfig)` - Set layer config
- `TotalLayers()` - Get layer count
- `Activations()` - Get activation outputs
And many more! Use `network.ListMethods()` to see all available methods.
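For example, when feeding independent sequences through a network with LSTM/RNN layers, `ResetState` keeps hidden state from one sequence from leaking into the next. A minimal sketch (assuming zero-argument methods take an empty JSON array, per the calling convention described later; `sequences` is a hypothetical array of input vectors):

```javascript
// Hedged sketch: clear recurrent hidden state between unrelated sequences.
for (const sequence of sequences) {
  network.ResetState(JSON.stringify([])); // assumed zero-arg convention
  const [output] = JSON.parse(network.ForwardCPU(JSON.stringify([sequence])));
  console.log(output);
}
```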
Supported Layer Types
All layer types from the Go framework:
- Dense: Fully connected layers with 15+ activation functions
- Conv2D: 2D convolution (CPU implementation)
- LSTM: Long Short-Term Memory
- RNN: Recurrent Neural Network
- Multi-Head Attention: Transformer attention mechanism
- Layer Norm: Layer normalization
- RMS Norm: Root Mean Square normalization
- SwiGLU: Gated linear units
- Softmax: 10+ variants (standard, temperature, grid, spatial, etc.)
- Parallel: 4 combine modes (concat, add, avg, grid_scatter)
Grid Scatter: Multi-Agent Networks
What makes this special:
Grid Scatter enables heterogeneous multi-agent neural networks where each agent has a completely different architecture:
{
type: "parallel",
combine_mode: "grid_scatter",
grid_output_rows: 3,
grid_output_cols: 1,
branches: [
{
type: "parallel",
combine_mode: "add",
branches: [
{type: "dense", output_size: 8, activation: "relu"},
{type: "dense", output_size: 8, activation: "gelu"}
]
},
{type: "lstm", hidden_size: 8},
{type: "rnn", hidden_size: 8}
]
}
Key Features:
- ✅ Heterogeneous Architectures: LSTM + RNN + Dense ensemble in same layer
- ✅ Spatial Topology: Explicit 2D/3D grid positioning
- ✅ Emergent Specialization: Agents learn complementary roles
- ✅ Trainable: Full gradient flow through all agents
Real-world Applications:
- Multi-robot coordination (heterogeneous robots)
- Hierarchical reinforcement learning
- Multi-agent game playing
- Distributed sensor networks
- Ensemble methods with architectural diversity
Current Limitations
- CPU Only: WebGPU support coming soon
- No Transformer Inference: Not included in the main.wasm build (the separate loom.wasm build documented under Transformer Inference API below includes it)
- 4GB Memory Limit: Standard 32-bit WASM
- Performance: 2-3x slower than native Go
Future Enhancements
- WebGPU acceleration for GPU support
- Memory64 support (unlimited memory)
- Optimize binary size (tree shaking)
- Add transformer inference back
- Web Workers for parallel training
- Quantization support (int8/int4)
- Streaming model loading
- Performance profiling tools
Building
cd wasm
./build_wasm.sh
Output:
- `main.wasm` (6.4MB) - Complete LOOM framework
- `wasm_exec.js` (17KB) - Go WASM runtime
Requirements:
- Go 1.21+ with WASM support
- Any modern browser with WebAssembly support
Transformer Inference API
JavaScript API
// 1. Load tokenizer
const tokenizerData = new Uint8Array(
await (
await fetch("models/SmolLM2-135M-Instruct/tokenizer.json")
).arrayBuffer()
);
const tokResult = JSON.parse(LoadTokenizerFromBytes(tokenizerData));
// 2. Load transformer model
const configData = new Uint8Array(
await (await fetch("models/SmolLM2-135M-Instruct/config.json")).arrayBuffer()
);
const weightsData = new Uint8Array(
await (
await fetch("models/SmolLM2-135M-Instruct/model.safetensors")
).arrayBuffer()
);
const modelResult = JSON.parse(
LoadTransformerFromBytes(configData, weightsData)
);
// 3. Generate text
const result = JSON.parse(GenerateText("Once upon a time", 50, 0.7));
console.log(result.generated_text);
Available Functions
- `LoadTokenizerFromBytes(tokenizerData)` - Load BPE tokenizer from tokenizer.json
- `LoadTransformerFromBytes(configData, weightsData)` - Load transformer from config + safetensors
- `EncodeText(text, addSpecialTokens)` - Tokenize text to token IDs
- `DecodeTokens(tokenIDs, skipSpecialTokens)` - Convert token IDs back to text
- `GenerateText(prompt, maxTokens, temperature)` - Generate text (auto-handles tokenization)
- `GenerateNextToken(tokenIDs, temperature)` - Generate single next token
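A sketch of the tokenizer round trip using the functions above. The examples in this README only show the top-level generation result, so this assumes `EncodeText` and `DecodeTokens` return JSON that parses to a plain token-ID array and a string respectively:

```javascript
// Hedged sketch: encode a prompt, then decode it back.
const ids = JSON.parse(EncodeText("Once upon a time", true)); // assumed: array of token IDs
const text = JSON.parse(DecodeTokens(ids, true));             // assumed: original string
console.log(ids, text);
```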
Neural Network API
Building
cd wasm
./build.sh
This produces:
- `loom.wasm` (6.0MB) - The compiled WebAssembly binary with transformer support
- `wasm_exec.js` (17KB) - Go's WASM runtime
Running Traditional NN Demos
./serve.sh # Starts server on port 8080
# Open http://localhost:8080/example.html
# Open http://localhost:8080/all_layers_test.html
Demos included:
- `example.html` - Network creation, training, introspection
- `all_layers_test.html` - Load complete models from JSON
- `inference.html` - Transformer text generation
Model Loading (The Easy Way)
// Load complete model from JSON string
const network = LoadModelFromString(modelJSONString, "model_id");
// That's it! Network has all layers + weights loaded
// Use it immediately:
const output = JSON.parse(network.ForwardCPU(JSON.stringify([inputData])))[0];
// Train it:
const batches = [{ Input: inputData, Target: targetData }];
const config = {
Epochs: 10,
LearningRate: 0.01,
LossType: "mse",
};
network.Train(JSON.stringify([batches, config]));
// Save it:
const savedJSON = JSON.parse(
network.SaveModelToString(JSON.stringify(["model_id"]))
)[0];
Creating Networks from Scratch
// Create a network: 784 input → 392 hidden → 10 output
const network = NewNetwork(784, 1, 1, 2);
// Use registry-based initialization for all layer types
const layer0Config = CallLayerInit(
"InitDenseLayer",
JSON.stringify([784, 392, 0]) // ReLU activation
);
const layer1Config = CallLayerInit(
"InitDenseLayer",
JSON.stringify([392, 10, 1]) // Sigmoid activation
);
// Apply configurations to network
network.SetLayer(JSON.stringify([0, 0, 0, JSON.parse(layer0Config)]));
network.SetLayer(JSON.stringify([0, 0, 1, JSON.parse(layer1Config)]));
Runtime Introspection
Discover all available methods at runtime:
// Get all methods with metadata
const methodsJSON = network.GetMethods();
const methods = JSON.parse(methodsJSON);
console.log(`Network has ${methods.length} methods`);
methods.forEach((method) => {
const params = method.parameters
.map((p) => `${p.name}: ${p.type}`)
.join(", ");
const returns = method.returns.join(", ");
console.log(`${method.method_name}(${params}) -> ${returns}`);
});
// Check if specific method exists
if (network.HasMethod("ForwardCPU")) {
const sig = network.GetMethodSignature(JSON.stringify(["ForwardCPU"]));
console.log("Signature:", sig);
// Output: "ForwardCPU([]float32) ([]float32, time.Duration)"
}
// List all method names
const names = JSON.parse(network.ListMethods());
console.log("Available:", names);
// Output: ["BackwardCPU", "BackwardGPU", "ForwardCPU", "ForwardGPU", ...]
Forward Pass
// Create input (784 values for MNIST-sized input)
const input = new Array(784).fill(0).map(() => Math.random());
// Run forward pass
const resultJSON = network.ForwardCPU(JSON.stringify([input]));
const result = JSON.parse(resultJSON);
const output = result[0]; // Output array (10 values)
const duration = result[1]; // Execution time (nanoseconds)
console.log("Output:", output);
console.log("Time:", duration / 1e6, "ms");
console.log("Mean:", output.reduce((a, b) => a + b) / output.length);
Training Loop
// Manual training (forward passes only)
for (let epoch = 0; epoch < 5; epoch++) {
const input = new Array(784).fill(0).map(() => Math.random());
const resultJSON = network.ForwardCPU(JSON.stringify([input]));
const output = JSON.parse(resultJSON)[0];
// Compute loss (MSE against zero)
const loss = output.reduce((sum, val) => sum + val * val, 0) / output.length;
console.log(`Epoch ${epoch + 1}: Loss = ${loss.toFixed(6)}`);
// For full training, you would:
// 1. Compute gradients from loss
// 2. Call BackwardCPU with gradients
// 3. Update weights (manual or via Train method)
}
Save/Load Models
// Save model to JSON string
const modelJSON = network.SaveModelToString(JSON.stringify(["my_model"]));
const model = JSON.parse(JSON.parse(modelJSON)[0]);
console.log("Model size:", JSON.stringify(model).length, "bytes");
console.log("Layers:", model.models[0].cfg.layers.length);
// Store in localStorage (persistent across page reloads)
localStorage.setItem("loom_model", JSON.stringify(model));
// Load model from localStorage
const savedModel = JSON.parse(localStorage.getItem("loom_model"));
const loadedNetwork = LoadModelFromString(
JSON.stringify(savedModel),
"my_model"
);
// Verify loaded model works
const testInput = new Array(784).fill(0).map(() => Math.random());
const output = JSON.parse(
loadedNetwork.ForwardCPU(JSON.stringify([testInput]))
)[0];
console.log("Loaded model output:", output);
Helper Functions
// Create layer configurations
const denseConfigJSON = InitDenseLayer(784, 128, 0); // 0 = ScaledReLU
const denseConfig = JSON.parse(denseConfigJSON);
// Create a multi-head attention layer
const mhaConfigJSON = InitMultiHeadAttentionLayer(
512, // dModel
8, // numHeads
32, // batchSize
256 // seqLength
);
const mhaConfig = JSON.parse(mhaConfigJSON);
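The parsed configs plug straight into `SetLayer`, following the same pattern as the registry-based example earlier; the grid coordinates and layer index here are illustrative:

```javascript
// Hedged sketch: apply a helper-created config at grid cell (0,0), layer 0.
network.SetLayer(JSON.stringify([0, 0, 0, denseConfig]));
```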
Training Example
// Create network
const network = NewNetwork(784, 1, 1, 2);
// Training parameters
const epochs = 10;
const learningRate = 0.01;
const batchSize = 32;
// Training data (simplified)
const trainData = generateTrainingData(); // Your data
for (let epoch = 0; epoch < epochs; epoch++) {
let totalLoss = 0;
for (let i = 0; i < trainData.length; i += batchSize) {
const batch = trainData.slice(i, i + batchSize);
// Forward pass
batch.forEach((sample) => {
const input = sample.input;
const target = sample.target;
// Forward
const outputJSON = network.ForwardCPU(JSON.stringify([input]));
const output = JSON.parse(outputJSON)[0];
// Compute loss (MSE example)
const loss =
output.reduce(
(sum, val, idx) => sum + Math.pow(val - target[idx], 2),
0
) / output.length;
totalLoss += loss;
// Compute gradient
const gradOutput = output.map(
(val, idx) => (2 * (val - target[idx])) / output.length
);
// Backward
network.BackwardCPU(JSON.stringify([gradOutput]));
// Update weights so the loss can actually decrease
// (UpdateWeights is listed in the method reference above)
network.UpdateWeights(JSON.stringify([learningRate]));
});
}
console.log(`Epoch ${epoch + 1}: Loss = ${totalLoss / trainData.length}`);
}
Method Calling Convention
All methods follow this pattern:
// Parameters are passed as JSON array
const params = [param1, param2, param3];
const paramsJSON = JSON.stringify(params);
// Call the method
const resultJSON = network.MethodName(paramsJSON);
// Parse results (returns are also JSON array)
const results = JSON.parse(resultJSON);
const result1 = results[0];
const result2 = results[1];
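Since every method follows this convention, a tiny wrapper can hide the JSON plumbing. `callJSON` is our own helper name, not part of the LOOM API:

```javascript
// JSON-encode the arguments, invoke the method by name, JSON-decode the results.
function callJSON(network, methodName, ...args) {
  return JSON.parse(network[methodName](JSON.stringify(args)));
}

// Usage:
const [output, duration] = callJSON(network, "ForwardCPU", input);
```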
Type Conversion
JavaScript types are automatically converted to Go types:
| JavaScript Type | Go Type |
|---|---|
| `number` | `int`, `float32`, `float64` |
| `boolean` | `bool` |
| `string` | `string` |
| `Array` | `[]T` (slice) |
| `Object` | `map[string]T` or struct |
| Custom integers | `LayerType`, `ActivationType` (automatic conversion) |
For complex types (structs), pass as objects:
const config = {
Type: 0, // LayerType (automatically converted to Go type)
Activation: 1, // ActivationType
Kernel: [1.0, 2.0], // []float32
Bias: [0.1, 0.2], // []float32
};
network.SetLayer(JSON.stringify([0, 0, 0, config]));
The WASM wrapper handles:
- ✅ Nil values: JavaScript `null` → Go `nil` for optional fields
- ✅ Custom types: Type conversion for enums like `LayerType`, `ActivationType`
- ✅ Nested structs: Recursive conversion of complex objects
- ✅ Slices: Multi-dimensional arrays properly converted
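For example, per the nil-handling rule above, optional struct fields can be passed as `null` from JavaScript. A sketch; which fields are actually optional depends on the layer type:

```javascript
// Hedged sketch: null-valued fields map to Go nil and are left unset.
const partialConfig = {
  Type: 0,       // LayerType
  Activation: 1, // ActivationType
  Kernel: null,  // -> Go nil (presumably left for the framework to initialize)
  Bias: null,    // -> Go nil
};
network.SetLayer(JSON.stringify([0, 0, 0, partialConfig]));
```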
Demo Results
The included example.html demo successfully demonstrates:
Network Creation
Network: 784 → 392 → 10
Layer 0: 307,328 weights initialized
Layer 1: 3,920 weights initialized
Total: 2 layers with real weights
Forward Pass
Input: 784 random values [0, 1]
Output: [0.3456, 0.6710, 0.4669, 0.5165, 0.5758, 0.6556, 0.5595, 0.6136, 0.6537, 0.3036]
Range: [0.3036, 0.6710]
Mean: 0.5362
Training (5 epochs)
Epoch 1/5: Loss = 0.2946
Epoch 2/5: Loss = 0.2401 (improving)
Epoch 3/5: Loss = 0.2857
Epoch 4/5: Loss = 0.3392
Epoch 5/5: Loss = 0.3121
Model Serialization
Saved: 1,486 bytes
Loaded: Successfully restored network
Verified: Forward pass produces identical outputs
Current Limitations
- No KV Cache: Transformer generation reprocesses full sequence each token (slow but correct)
- No GPU Support: WebGPU integration coming soon
- CPU Only: All operations run on the CPU
- Performance: WASM is 2-3x slower than native Go (but instant deployment!)
- Memory Limit: 4GB for standard WASM (32-bit addressing)
- Binary Size: 6.0MB (includes full framework + transformer support)
Future Enhancements
- KV Caching for transformers (10-100x speedup)
- WebGPU integration for GPU acceleration
- Memory64 support (unlimited memory when Go supports it)
- Streaming model loading for large models
- Web Workers for parallel training
- Quantization (int8/int4) for smaller models
- Optimize binary size (tree shaking, compression)
- Model visualization tools
- Performance benchmarking tools
Architecture
Transformer Inference
The transformer inference system (wasm/inference.go) provides:
- Byte-based Loading: Load models from `Uint8Array` (no file system needed)
- Pure Go BPE Tokenizer: Complete tokenizer implementation in the `/tokenizer` package
- Safetensors Support: Direct loading of HuggingFace model weights
- Full Sequence Context: Processes entire token sequence for proper attention
- Auto-architecture Detection: Supports LLaMA, GPT-2, and other architectures
Neural Network Architecture
The WASM module uses reflection to automatically wrap all public methods of the nn.Network struct:
- Introspection (`nn/introspection.go`): Discovers methods and signatures via reflection
- Method Wrapper (`main.go`): Dynamically wraps methods for JavaScript calls
- Type Conversion (`convertParameter`): Converts between JavaScript and Go types
- Result Serialization (`serializeResults`): Returns results as JSON arrays
This approach means:
- ✅ Zero manual bindings - new methods automatically available
- ✅ 24+ methods exposed - all Network methods callable from JavaScript
- ✅ Type-safe - runtime validation with helpful error messages
- ✅ Self-documenting - introspection reveals complete API
Examples
Transformer Inference Demo
The inference.html demo showcases transformer text generation in the browser:
Features:
- Load models from local files (downloaded via `huggingface-cli`)
- Beautiful gradient UI with model selection cards
- Real-time text generation with progress tracking
- Live statistics (tokens/sec, time elapsed)
- Adjustable temperature and max tokens
Example Output (SmolLM2-135M-Instruct):
Prompt: "Once upon a time"
Generated: "hi
I'm excited to see what you come up with! Let me know if you have any"
Traditional Neural Network Demos
See example.html for a complete interactive demo including:
- ✅ Network creation with layer initialization
- ✅ Method introspection (24 methods discovered)
- ✅ Forward/backward passes with real outputs
- ✅ Model save/load with localStorage
- ✅ Training workflow with loss tracking
- ✅ Console logging for debugging
Try it now:
# Transformer inference:
cd wasm
bash serve_wasm.sh
# Open http://localhost:8888/inference.html
# Traditional neural networks:
cd wasm
./serve.sh
# Open http://localhost:8080/example.html
# Open http://localhost:8080/all_layers_test.html
Browser Compatibility
Requires WebAssembly support (all modern browsers):
- Chrome/Edge 57+
- Firefox 52+
- Safari 11+
Tested and working on:
- ✅ Firefox 120+ (Linux)
- ✅ Chrome 119+ (Linux, macOS, Windows)
License
Apache License 2.0 - see LICENSE file for details.