
Experiment 000: One-Layer Global Throttle Sanity Check

Status: Valid

Notebook: workspace/ablations/000_global_throttle_sanity/notebooks/affine_drift_controler_no_quant.ipynb

Purpose

This is the first working sanity check for the dynamic global throttle idea.

The goal is not yet to test quantization, fixed-point rails, multiple layers, or the full ENABOL hardware path. The goal is narrower:

Can we build a custom Keras training loop that detects unstable online learning dynamics and globally throttles the learning update so training does not diverge?

This test validates the instrumentation and controller on the simplest possible model.

Dataset

We start with a linear teacher with no bias. First, the inputs are sampled as:

$$x \sim \mathcal{U}\big([0,1]^{d_{\text{in}}}\big) \qquad \text{with} \qquad d_{\text{in}} = 4, \quad d_{\text{out}} = 2.$$

Then the output is generated by a linear teacher:

$$y = A x$$

We use this exact teacher matrix:

$$A = \begin{bmatrix} 1.25 & -0.75 & 0.50 & 0.20 \\ -0.40 & 0.90 & 1.10 & -0.60 \end{bmatrix}.$$

so that

$$\begin{bmatrix} y_1 \\ y_2 \end{bmatrix} = \begin{bmatrix} 1.25 & -0.75 & 0.50 & 0.20 \\ -0.40 & 0.90 & 1.10 & -0.60 \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \end{bmatrix}.$$

The following image shows scatter plots of the input $x$ and target $y$ distributions, as well as their histograms at the bottom.

Dataset

Dataset histograms

Code to generate the dataset:

If you need to understand the structure of this dataset, check here.

import kappa
# Create the dataset
dataset = kappa.AffineDataset(num_samples=1000, use_bias=False)
# Plot it
dataset.plot()
dataset.plot_histogram()
print(dataset)

# Get the data
X, Y = dataset.get()

which renders:

AffineDataset(
[Input] X:
Shape: (1000, 4)
Dtype: DataType.D2_FLOAT
X <- Uniform(-1, 1)
[Output] Y:
Shape: (1000, 2)
Dtype: DataType.D2_FLOAT
Y <- X @ A.T + b
---------
A = [[ 1.25 -0.75 0.5 0.2 ]
[-0.4 0.9 1.1 -0.6 ]]
b = [0. 0.]
----------
Analytic Hessian:
Lambda max: 1.0841
Eta max: 1.8448
)

Note that the dataset object also reports the analytic Hessian metrics. This is important because it means that:

$$\begin{aligned} \lambda_{\max}(H_{\text{nom}}) &\approx 1.0841 \\ \eta_{\max}^{\text{nom}} = \frac{2}{\lambda_{\max}(H_{\text{nom}})} &\approx 1.8448. \end{aligned}$$
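These numbers can be cross-checked with plain NumPy, independently of the `kappa` library. The sketch below assumes inputs uniform on $[0,1]^4$ and a half-MSE loss, for which the per-output Hessian is the input second-moment matrix; the exact eigenvalue depends on the random draw, so it lands near the reported 1.0841 rather than matching it exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_in = 1000, 4
X = rng.uniform(0.0, 1.0, size=(N, d_in))

# For the half-MSE loss 0.5 * ||W x - y||^2 of a linear model,
# the Hessian w.r.t. each row of W is the input second-moment matrix.
H = X.T @ X / N

lam_max = np.linalg.eigvalsh(H).max()  # near 1.08 for U([0,1])^4 inputs
eta_max = 2.0 / lam_max                # largest stable full-batch GD step

print(f"lambda_max ~ {lam_max:.4f}, eta_max ~ {eta_max:.4f}")
```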

Drift Model

Now let's assume a learning rate of $\eta = 0.5$. With no drift, the nominal margin is:

$$\eta\,\lambda_{\max}(H_{\text{nom}}) = 1.0841 \cdot 0.5 \approx 0.54 < 2.$$

However, with a gain drift of $\gamma = 4$, the Hessian grows approximately as:

$$\lambda_{\max}(H_{\text{drift}}) \approx \gamma^2\,\lambda_{\max}(H_{\text{nom}}) \approx 17.3456.$$

which means the post-drift margin is:

$$\eta\,\lambda_{\max}(H_{\text{drift}}) \approx 8.6728 > 2$$

When the input gain drifts by a factor of 4, the previously stable learning rate of $0.5$ becomes unstable. This is the regime where we expect the global throttle controller to intervene and prevent divergence.
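The $\gamma^2$ law is exact for this quadratic setup: scaling the inputs by $\gamma$ scales the second-moment matrix by $\gamma^2$. A minimal NumPy check (the variable names here are illustrative, not from the notebook):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(1000, 4))
gamma = 4.0

H_nom = X.T @ X / len(X)                        # nominal Hessian
H_drift = (gamma * X).T @ (gamma * X) / len(X)  # Hessian after input gain drift

# The top eigenvalue grows by exactly gamma**2 = 16.
ratio = np.linalg.eigvalsh(H_drift).max() / np.linalg.eigvalsh(H_nom).max()
print(f"lambda_max ratio ~ {ratio:.2f}")
```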

Controller Behavior

Given the instability introduced by the drift, we expect the controller to choose a throttle $\alpha_t$ such that the effective learning rate $\alpha_t\,\eta$ returns to the stable region. In other words, we expect:

$$\alpha_t \approx \frac{\chi}{\eta\, C_t^{\text{ctrl}}}.$$

With $\chi = 1.5$, the ideal post-drift throttle is:

$$\alpha^\star \approx \frac{1.5}{8.6728} \approx 0.173.$$

Then:

$$\alpha_t\,\eta\,\lambda_{\max}(H_{\text{drift}}) \approx 1.5 < 2.$$

The global throttle scales the effective learning rate down so that it stays in the stable region even after drift.
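This target rule can be sketched as a small helper. The clamp at 1 (never amplify, only throttle) and the name `throttle` are assumptions for illustration; the actual controller lives inside `kappa`'s instrumented trainer.

```python
def throttle(eta, curvature_est, chi=1.5):
    """Global throttle: keep the effective step eta * alpha * curvature <= chi < 2."""
    # Never amplify the update; only scale it down when curvature is too high.
    return min(1.0, chi / (eta * curvature_est))

# Nominal curvature: the margin is fine, so the controller stays out of the way.
print(throttle(0.5, 1.0841))   # 1.0
# Post-drift curvature: throttles down to ~0.173, matching the analysis above.
print(throttle(0.5, 17.3456))
```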

Model

The student model is a one-layer linear network without bias:

$$\hat{y} = W x$$

This keeps the Hessian and closed-loop stability story simple.

Code to build the model:

model = kappa.LinearBlockModel(
    dataset=dataset,
    num_hidden=[dataset.A.shape[0]],
    activation=None,
    use_batchnorm=False,
    use_bias=False,
    verbose=True,
    seed=0,
)
model.summary()

which returns

[INFO] - Building model with input shape (4,) and output shape (2,)
[INFO] - Added Dense layer with 2 units
Model: "LinearBlockModel"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer) │ (None, 4) │ 0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense) │ (None, 2) │ 8 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 8 (32.00 B)
Trainable params: 8 (32.00 B)
Non-trainable params: 0 (0.00 B)

Current Notebook Flow

The notebook currently does three things:

  1. Builds a controlled affine dataset.
  2. Trains the one-layer model normally to confirm the task is learnable.
  3. Reinitializes the model, performs a nominal warmup, then enters a drifted online phase:
$$x_{\mathrm{drift}} = \gamma x$$

The notebook compares:

| Run | Controller | Purpose |
| --- | --- | --- |
| Nominal baseline | off | Confirm the one-layer model learns the teacher. |
| Drift baseline | off | Show unstable behavior after gain drift. |
| Drift controlled | on | Show the global throttle can prevent divergence. |

What Is Being Tested

The custom trainer logs the quantities needed for closed-loop analysis:

  • loss,
  • RMSE,
  • weight error,
  • parameter norm,
  • gradient norm,
  • raw update norm,
  • actual update norm,
  • curvature proxy,
  • curvature EMA,
  • controller value alpha(t),
  • effective learning rate,
  • Hessian eigenvalue estimates,
  • stability margins,
  • update-map spectral radius,
  • finite/divergence flags.

The controller globally scales the SGD update:

$$\Delta \theta_{\mathrm{actual}}(t) = \alpha(t)\,\Delta \theta_{\mathrm{raw}}(t)$$

where:

$$0 < \alpha(t) \le 1$$

The important property is that this scaling should preserve the update direction while reducing the effective learning rate.
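Direction preservation is easy to verify: scaling the whole raw step by a scalar $\alpha$ leaves it parallel to the original. A sketch (the function name is hypothetical, not the trainer's API):

```python
import numpy as np

def throttled_step(grad, eta, alpha):
    # Raw SGD step -eta * grad, scaled globally by alpha.
    # Only the magnitude changes, never the direction.
    return -alpha * eta * grad

grad = np.array([0.2, -0.4, 0.8])
full = throttled_step(grad, eta=0.5, alpha=1.0)
small = throttled_step(grad, eta=0.5, alpha=0.25)

# Same direction, a quarter of the magnitude.
print(np.allclose(small, 0.25 * full))
```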

Experiment 000A: Sanity check without drift or quantization

Here we simply verify that our dataset and model classes work as expected: we train with a sane learning rate and no drift. The code for this experiment:

h = model.train_instrumented(
    X,
    Y,
    epochs=100,
    batch_size=32,  # or dataset.num_samples for full-batch, which gives cleaner Hessian metrics
    learning_rate=0.05,
    loss_mode="half_mse",
    curvature_ema_rho=0.05,
    chi=1.5,
    use_controller=False,
    compute_analytic_hessian=True,
    reference_A=dataset.reference_weight_matrix,
)
print(h)
h.plot_results(title="Training History")

The result is shown in the image below, which confirms that the loss curve is smooth and converges to zero, as expected, without any instability or divergence.

Loss curve

Experiment 000B: Sanity check with drift but no controller

Here we introduce a gain drift in the input:

$$x' = \gamma x \qquad \leadsto \qquad y = A x = \left(\tfrac{1}{\gamma} A\right) x'$$

so that the task remains consistent while the Hessian grows approximately as:

$$\lambda_{\max}(H_{\mathrm{drift}}) \approx \gamma^2\,\lambda_{\max}(H_{\mathrm{nom}})$$
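Because the targets stay clean while the inputs are scaled, the drifted task has the exact optimum $W^\star = A/\gamma$, which is why the runs below pass `reference_A=dataset.A / gamma`. A quick NumPy check with the teacher matrix from above:

```python
import numpy as np

A = np.array([[ 1.25, -0.75, 0.50,  0.20],
              [-0.40,  0.90, 1.10, -0.60]])
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(1000, 4))
Y = X @ A.T            # clean targets from the teacher

gamma = 4.0
Xd = gamma * X         # drifted inputs; targets untouched

W_star = A / gamma     # new optimum reproduces the clean targets exactly
print(np.allclose(Xd @ W_star.T, Y))
```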

Stage 0: Normal training without drift

First, we reinitialize the model and run a nominal warmup phase to confirm that training starts in a stable regime. The code to generate this is:

# Reinit
model.reinitialize_weights()

# Phase 1: nominal warmup
h_nom = model.train_instrumented(
    X,
    Y,
    epochs=20,
    batch_size=len(X),
    learning_rate=0.5,
    use_controller=False,
    reference_A=dataset.A,
)
h_nom.plot_results(title="Warmup History")

The result is shown in the image below, which confirms that the model learns the task and converges to zero loss in a stable way.

Warmup loss curve

Stage 1: drifted training without controller

Next, we introduce the gain drift and continue training without the controller to confirm that the training dynamics become unstable. The code to generate this is:

# Phase 2: sensor gain drift
gamma = 4.0
Xd = gamma * X
Yd = Y  # important: the target remains clean

h_drift = model.train_instrumented(
    Xd,
    Yd,
    epochs=100,
    batch_size=len(Xd),
    learning_rate=0.5,
    use_controller=False,
    reference_A=dataset.A / gamma,
)

h_drift.plot_results(title="Drift History")

The result is shown in the image below, which confirms that the loss curve diverges after the drift is introduced, as expected.

Drift loss curve

Experiment 000C: Sanity check with drift and controller

Finally, we run the same loop on the drifted data, but now with the controller turned on, to confirm that it prevents divergence. The code is below. Two details matter: we reinitialize the model so all runs start from the same initial conditions, and we repeat the nominal warmup phase so the controller can estimate the curvature before the drift starts.

# Reinit
model.reinitialize_weights()

# Phase 1: nominal warmup (lets the controller estimate curvature before drift)
h_nom = model.train_instrumented(
    X,
    Y,
    epochs=20,
    batch_size=len(X),
    learning_rate=0.5,
    use_controller=False,
    reference_A=dataset.A,
)

# Phase 2: sensor gain drift, now with the controller enabled
h_drift = model.train_instrumented(
    Xd,
    Yd,
    epochs=100,
    batch_size=len(Xd),
    learning_rate=0.5,
    use_controller=True,
    reference_A=dataset.A / gamma,
)

h_nom.plot_results(title="Warmup History")
h_drift.plot_results(title="Drift History")

The result is shown in the image below, which confirms that the loss curve remains stable and converges to zero even after the drift is introduced, thanks to the global throttle controller.

Drift with controller loss curve

Summary of Results

Key Findings: This preliminary notebook supports the basic claim that a global controller can throttle the total learning rate and stabilize a one-layer online learning loop.

Notebook Preview

The notebook can be viewed directly on GitHub:

Open notebook on GitHub

If the iframe does not load, use the GitHub link above.