Models / NN Module Reference
Source: kappa/nn.py
This page documents the model and training-loop code block by block. It is written for people extending the ablation harness.
Imports
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, Dict, Mapping, Union, Optional, Tuple
import numpy as np
import tensorflow as tf
Local dependencies:
from .dataset import BaseDataset
from .history import FitHistory
from .precision import PrecisionDict, ensure_precision_dict
from .quantization import quantize_tensor, rail_stats
from .utils import ...
Important utilities:
| Utility | Purpose |
|---|---|
| flatten_tensors | Convert trainable variables or gradients into one global vector. |
| half_mse_batch_loss | Loss used by analytic Hessian experiments. |
| tensor_l2_norm | Stable tensor norm for logging. |
| safe_cosine | Update direction diagnostic. |
| analytic_single_dense_hessian | Exact Hessian for one-layer no-bias regression. |
| stability_metrics_from_hessian | Computes eta * lambda_max and update-map spectral radius. |
BaseModel
@dataclass(eq=False)
class BaseModel(ABC):
dataset: BaseDataset
loss: Union[str, tf.keras.losses.Loss] = "mse"
optimizer: Union[str, tf.keras.optimizers.Optimizer] = "sgd"
metrics: list = field(default_factory=lambda: [])
model: Optional[tf.keras.Model] = field(init=False, default=None)
input_shape: Optional[Tuple[int, ...]] = field(init=False, default=None)
output_shape: Optional[Tuple[int, ...]] = field(init=False, default=None)
name: str = "BaseModel"
verbose: bool = False
seed: Optional[int] = None
BaseModel owns the dataset reference, the Keras model object, and the custom training loop.
Subclasses are responsible for constructing self.model.
__post_init__()
def __post_init__(self):
self.input_shape = self.dataset.input_shape[1:]
self.output_shape = self.dataset.output_shape[1:]
Removes the batch dimension from dataset shapes.
Example:
dataset.X.shape == (1000, 4)
model.input_shape == (4,)
_compile(...)
def _compile(self, optimizer=None, loss=None, metrics=None, **kwargs):
...
Thin wrapper over self.model.compile(...).
Used by the simple train() method. The instrumented ablation loop does not use Keras optimizer application because it needs direct control over every update.
summary()
def summary(self) -> None:
...
Prints Keras model summary if the model exists.
Usage:
model.summary()
reinitialize_weights()
def reinitialize_weights(self):
...
Loops over Keras layers and reassigns kernels/biases using the layer initializers.
Usage:
model.reinitialize_weights()
Used before comparing baseline and controller runs so they start from comparable initial conditions.
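A minimal sketch of that idea, assuming standard Keras Dense attributes (kernel_initializer, bias_initializer, use_bias); the loop in nn.py may handle more layer types:

for layer in self.model.layers:
    if isinstance(layer, tf.keras.layers.Dense):
        # resample weights from the layer's own initializers
        layer.kernel.assign(layer.kernel_initializer(shape=layer.kernel.shape))
        if layer.use_bias:
            layer.bias.assign(layer.bias_initializer(shape=layer.bias.shape))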
train(...)
def train(self, X, Y, epochs=10, batch_size=32) -> dict[str, np.ndarray]:
...
Simple Keras-style training loop:
- Builds a tf.data.Dataset.
- Calls _compile().
- Uses the compiled optimizer and loss.
- Applies gradients with optimizer.apply_gradients.
- Returns only loss history.
This is mostly a convenience method. Ablation work should use train_instrumented().
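A hedged sketch of what that loop amounts to, not the exact code in nn.py; the loss lookup and the returned history dictionary are simplified assumptions:

def train(self, X, Y, epochs=10, batch_size=32):
    self._compile()
    loss_fn = tf.keras.losses.get(self.loss)   # assumes a string/Loss the Keras registry knows
    ds = tf.data.Dataset.from_tensor_slices(
        (X.astype(np.float32), Y.astype(np.float32))
    ).batch(batch_size)
    losses = []
    for _ in range(epochs):
        for x_batch, y_batch in ds:
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(loss_fn(y_batch, self.model(x_batch, training=True)))
            grads = tape.gradient(loss, self.model.trainable_variables)
            self.model.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            losses.append(float(loss))
    return {"loss": np.asarray(losses)}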
train_instrumented(...)
def train_instrumented(
self,
X,
Y,
epochs=10,
batch_size=32,
learning_rate=0.05,
shuffle=True,
loss_mode="half_mse",
curvature_ema_rho=0.05,
chi=1.5,
eps=1e-12,
use_controller=False,
compute_analytic_hessian=True,
reference_A=None,
precision_dict=None,
) -> FitHistory:
...
This is the main ablation loop.
Inputs
| Argument | Purpose |
|---|---|
X, Y | NumPy training arrays. |
learning_rate | Base SGD learning rate eta. |
loss_mode | "half_mse" for Hessian-clean experiments or "keras_mse". |
curvature_ema_rho | EMA smoothing factor for curvature proxy. |
chi | Target stability margin for the throttle. |
use_controller | If true, applies alpha_t. If false, only logs would-be alpha_t. |
reference_A | Teacher matrix for one-layer weight error. |
precision_dict | Optional PrecisionDict; None means full floating point. |
Setup Block
The method:
- Converts precision_dict using ensure_precision_dict.
- Validates precision names against the Keras model.
- Converts X and Y to np.float32.
- Selects the loss function.
- Creates a batched tf.data.Dataset.
- Allocates the history dictionary.
Backward compatibility rule: precision_dict=None must preserve the original Experiment 000 behavior.
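A hedged sketch of those setup steps; ensure_precision_dict is assumed to tolerate None, and the loss selection and history allocation are simplified:

precision = ensure_precision_dict(precision_dict)   # None must keep Experiment 000 behavior
X = np.asarray(X, dtype=np.float32)
Y = np.asarray(Y, dtype=np.float32)
loss_fn = half_mse_batch_loss if loss_mode == "half_mse" else tf.keras.losses.MeanSquaredError()
ds = tf.data.Dataset.from_tensor_slices((X, Y))
if shuffle:
    ds = ds.shuffle(buffer_size=len(X), seed=self.seed)
ds = ds.batch(batch_size)
history = {}   # metric name -> list of per-step values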
Main Step Block
At each batch:
theta_before = flatten_tensors(self.model.trainable_variables)
with tf.GradientTape() as tape:
    y_pred = self.model(x_batch, training=True)
    loss = loss_fn(y_batch, y_pred)
grads = tape.gradient(loss, self.model.trainable_variables)
grad_flat = flatten_tensors(grads)
If precision_dict is present, the forward path uses _forward_with_precision() and gradients are passed through _quantize_gradients().
Curvature Proxy Block
dG = grad_flat - prev_grad
dtheta = theta_before - prev_theta
curvature_proxy = ||dG|| / (||dtheta|| + eps)
This estimates directional update-field sensitivity.
EMA:
curvature_ema = (1-rho) * curvature_ema + rho * curvature_proxy
Control signal:
curvature_for_control = max(curvature_proxy, curvature_ema)
Controller Block
alpha_would = min(1.0, chi / (eta * (curvature_for_control + eps)))
alpha = alpha_would if use_controller else 1.0
eta_eff = alpha * eta
If use_controller=False, the loop still logs alpha_would so we can see whether the controller would have intervened.
Update Block
Floating-point path:
var.assign_sub(eta_eff * grad)
Quantized path:
delta = -eta_eff * grad
delta = quantize_tensor(delta, update_dtype, ste=False)
var.assign_add(delta)
self._quantize_variable_storage(var, precision)
This is where update quantization and stored-weight quantization happen.
Metric Block
The loop logs:
loss
rmse
theta_norm
grad_norm
raw_update_norm
actual_update_norm
update_cosine
update_angle_rad
update_radius_ratio
curvature_proxy
curvature_ema
alpha
alpha_would
eta_eff
forward_gain_spectral
hessian_lambda_max
stability_margin_lambda_raw
stability_margin_lambda_ctrl
spectral_radius_raw
spectral_radius_ctrl
weight_error_fro
finite/divergence flags
With quantization enabled, it also logs:
weight_saturation_fraction_max
weight_near_rail_fraction_max
gradient_saturation_fraction_max
gradient_near_rail_fraction_max
update_saturation_fraction_max
update_underflow_fraction_max
Update Geometry Diagnostics
Four metrics are especially important for the quantized global-throttle ablation:
actual_update_norm
update_cosine
update_angle_rad
update_radius_ratio
raw_update_norm is the norm of the intended update before storage effects.
actual_update_norm is computed after the update has been applied and after any quantized variable storage has been enforced.
This catches silent learning death. A run can look numerically stable because the applied update has underflowed to zero; that is not a successful controller result.
update_cosine compares the applied update to the intended update.
For pure global throttling, this should stay close to 1 because the throttle scales the full update vector uniformly. If quantization, clipping, or a future row/column projection changes the update direction, this cosine should fall.
The phase-distortion diagnostics are derived from the same quantities:
update_angle_rad stores beta_t; update_radius_ratio stores r_t.
These are offline software diagnostics. They are meant to visualize and compare numerical update distortion during ablations. They are not currently proposed as hardware controller inputs because the high-precision raw update may not exist as a physical hardware signal.
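A NumPy sketch of how these diagnostics fit together, assuming theta_after is the flattened parameter vector read back after the update; the exact helper calls in nn.py (for example safe_cosine) may differ:

raw_update = -eta_eff * grad_flat                     # intended global update
actual_update = theta_after - theta_before            # applied update, incl. storage quantization
raw_update_norm = np.linalg.norm(raw_update)
actual_update_norm = np.linalg.norm(actual_update)
update_cosine = np.dot(raw_update, actual_update) / (raw_update_norm * actual_update_norm + eps)
update_angle_rad = np.arccos(np.clip(update_cosine, -1.0, 1.0))       # beta_t
update_radius_ratio = actual_update_norm / (raw_update_norm + eps)    # r_t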
Return Value
return FitHistory(**{k: np.asarray(v) for k, v in history.items()})
Use:
h = model.train_instrumented(...)
h.plot_results()
_forward_with_precision(...)
def _forward_with_precision(self, x, precision, *, training):
...
Manual forward path with fake quantization hooks.
Current behavior:
- Quantize input with precision.dtype("input", "value").
- For each Dense layer:
  - quantize kernel as "weight",
  - matrix multiply,
  - quantize dot-product as "accumulator",
  - add quantized bias if present,
  - quantize layer output as "activation".
- For non-Dense layers:
  - call layer normally,
  - quantize output as "activation".
The quantizers use STE inside GradientTape, so the forward value is quantized but gradients still flow.
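A hedged sketch of that walk over the layers; _q is an illustrative helper (not part of nn.py) that assumes precision.dtype returns None for unconfigured fields:

def _q(t, dtype, ste=True):
    # skip quantization when no dtype is configured
    return t if dtype is None else quantize_tensor(t, dtype, ste=ste)

x = _q(x, precision.dtype("input", "value"))
for layer in self.model.layers:
    if isinstance(layer, tf.keras.layers.InputLayer):
        continue
    if isinstance(layer, tf.keras.layers.Dense):
        kernel = _q(layer.kernel, precision.dtype(layer.name, "weight"))
        acc = _q(tf.matmul(x, kernel), precision.dtype(layer.name, "accumulator"))
        if layer.use_bias:
            acc = acc + _q(layer.bias, precision.dtype(layer.name, "bias"))
        x = _q(acc, precision.dtype(layer.name, "activation"))
    else:
        x = _q(layer(x, training=training), precision.dtype(layer.name, "activation"))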
Extension point:
If Conv layers are added, implement their explicit forward path here.
_quantize_gradients(...)
def _quantize_gradients(self, grads, trainable_vars, precision):
...
For each gradient:
- Find the owning layer with _layer_and_field_for_variable.
- Look up precision.dtype(layer_name, "gradient").
- Quantize the gradient if a dtype exists.
Missing gradient dtype means floating point gradient.
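A hedged sketch of that loop, again assuming precision.dtype returns None when no gradient dtype is configured:

quantized = []
for grad, var in zip(grads, trainable_vars):
    layer_name, _field = self._layer_and_field_for_variable(var)
    dtype = precision.dtype(layer_name, "gradient")
    quantized.append(grad if dtype is None else quantize_tensor(grad, dtype, ste=False))
return quantized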
_quantize_variable_storage(...)
def _quantize_variable_storage(self, var, precision):
...
After the update is applied, this quantizes the stored variable:
- Dense kernel uses "weight",
- Dense bias uses "bias",
- other trainable variables use "value".
This models fixed-point storage.
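A hedged sketch, reusing the same field mapping; the None check is an assumption about unconfigured fields:

layer_name, field = self._layer_and_field_for_variable(var)   # e.g. ("dense0", "weight")
dtype = precision.dtype(layer_name, field)
if dtype is not None:
    var.assign(quantize_tensor(var, dtype, ste=False))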
_layer_and_field_for_variable(...)
def _layer_and_field_for_variable(self, var) -> tuple[str, str]:
...
Maps a Keras variable to:
(layer_name, field_name)
Examples:
dense0/kernel -> ("dense0", "weight")
dense0/bias -> ("dense0", "bias")
This function is central to PrecisionDict integration. If a new layer type has special trainable variables, update this mapping.
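One plausible reading of the mapping, assuming variable paths look like dense0/kernel; treat it only as a sketch of the intended result, not the code in nn.py:

name = getattr(var, "path", None) or var.name          # e.g. "dense0/kernel" or "dense0/kernel:0"
layer_name, _, tail = name.partition("/")
field = {"kernel": "weight", "bias": "bias"}.get(tail.split(":")[0], "value")
return layer_name, field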
_same_variable(a, b)
@staticmethod
def _same_variable(a, b) -> bool:
...
Defensive Keras variable comparison helper. It first checks identity, then tries path, then falls back to name.
This avoids fragile behavior across Keras/TensorFlow versions.
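A hedged sketch of that fallback chain:

if a is b:
    return True
path_a, path_b = getattr(a, "path", None), getattr(b, "path", None)
if path_a is not None and path_b is not None:
    return path_a == path_b
return a.name == b.name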
_rail_max_for_variables(...)
def _rail_max_for_variables(self, vars_, precision, *, fields):
...
Computes max saturation and near-rail fractions across trainable variables for selected fields.
Used for:
weight_saturation_fraction_max
weight_near_rail_fraction_max
_rail_max_for_tensors(...)
def _rail_max_for_tensors(self, tensors, trainable_vars, precision, *, field):
...
Computes max rail pressure across non-variable tensors, currently gradients.
Used for:
gradient_saturation_fraction_max
gradient_near_rail_fraction_max
LinearBlockModel
@dataclass
class LinearBlockModel(BaseModel):
num_hidden: list = field(default_factory=lambda: [64, 64])
activation: Optional[Union[str, tf.keras.layers.Activation]] = None
use_batchnorm: bool = False
use_bias: bool = True
name: str = "LinearBlockModel"
Dense-block model used by the first ablations.
Each block is:
Dense -> optional Activation -> optional BatchNorm
LinearBlockModel.__post_init__()
def __post_init__(self):
super().__post_init__()
self.model = self._build_model(...)
Builds the Keras model immediately after dataclass initialization.
_build_model(input_shape, output_shape, verbose=True)
def _build_model(self, input_shape, output_shape, verbose=True) -> tf.keras.Model:
...
Build process:
- Create Keras input named model_input.
- Add Dense blocks using stable names: dense0, dense1, ...
- Add optional activations: activation0, activation1, ...
- Add optional batchnorm: batchnorm0, batchnorm1, ...
- If the last Dense output dimension does not match the dataset output, add a final Dense layer.
Usage:
model = kappa.LinearBlockModel(
dataset=dataset,
num_hidden=[8, 2],
activation="relu",
use_batchnorm=False,
use_bias=True,
)
Extension Checklist
When adding a new model class:
- Subclass BaseModel (sketched below).
- Build self.model in __post_init__().
- Use stable layer names.
- Avoid reserved layer names input and loss.
- If the model has non-Dense trainable variables, update _layer_and_field_for_variable.
- If fake quantization must happen inside special ops, update _forward_with_precision.
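A minimal sketch of a new model class that follows this checklist; TinyDenseModel and width are illustrative names, not part of kappa:

@dataclass(eq=False)
class TinyDenseModel(BaseModel):
    width: int = 16
    name: str = "TinyDenseModel"

    def __post_init__(self):
        super().__post_init__()   # fills input_shape / output_shape from the dataset
        inputs = tf.keras.Input(shape=self.input_shape, name="model_input")
        x = tf.keras.layers.Dense(self.width, name="dense0")(inputs)
        outputs = tf.keras.layers.Dense(self.output_shape[0], name="dense1")(x)
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs, name=self.name)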
When adding a new optimizer:
- Keep the flattened global update available for logging.
- Log intended update and actual applied update.
- Keep update_cosine meaningful.
- Add precision fields for optimizer state, such as momentum, adam_m, or adam_v.