Even Smaller Models of Superposition#
Introduction: The superposition problem#
Anthropic’s 2022 paper on toy models of superposition makes a counterintuitive claim: if you have two statistically independent variables, you can linearly combine them into a single number in a way that allows you to reconstruct both variables separately later.
For someone like me with a background in linear algebra and information theory, this seems impossible. If you project points from a two-dimensional space onto a one-dimensional line, you end up with a representation that only has one degree of freedom. There is no way that you can recover two separate dimensions from there.
The trick, it turns out, involves three ingredients: sparsity (most values are zero most of the time), positivity (values lie in a known range like [0,1]), and adding a simple nonlinearity (specifically, something as simple as a ReLU activation). Together, these properties allow us to do something that seems like it violates basic principles of linear algebra.
In this notebook, we’ll build up an understanding of how this works through a series of experiments with progressively more structure.
Data generation#
Our data consists of pairs of values \(x = (x_1, x_2)\), which we’ll call features. Each feature has three key properties:
Uniform distribution: Both features are uniformly distributed in the interval \([0, 1]\).
Statistical independence: The value of \(x_1\) tells us nothing about the value of \(x_2\), and vice versa.
Sparsity: With probability \(s\) (the sparsity parameter), each feature is set to exactly zero. This happens independently for each feature.
For now, we’ll start with \(s = 0\), meaning no sparsity—all points lie in the interior of the unit square.
We’ll also introduce the concept of importances: we can weight how much we care about accurately reconstructing each feature. If one feature has higher importance, we’ll penalize reconstruction errors on that feature more heavily in our loss function. This will be useful for understanding the trade-offs the model makes when it can’t perfectly reconstruct both dimensions.
import torch as t
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "notebook_connected"
DEF_N_POINTS = 1000
DEF_N_FEATURES = 2
DEF_SPARSITY = 0
def generate_data(n_points=DEF_N_POINTS, n_features=DEF_N_FEATURES, sparsity=DEF_SPARSITY):
# Points are uniformly distributed in [0, 1]
# Independently across features, x_i = 0 with probability `sparsity`
sparsity_mask = t.rand(n_points, n_features) < sparsity
x = t.rand(n_points, n_features)
x[sparsity_mask] = 0
return x
def plot_data(x, title=None):
fig = px.scatter(x=x[:, 0], y=x[:, 1], opacity=0.5)
fig.update_layout(width=600, height=600, title=title)
fig.update_xaxes(title_text="x<sub>1</sub>")
fig.update_yaxes(title_text="x<sub>2</sub>")
fig.show()
data = generate_data()
plot_data(data, title="Initial data distribution")
Why this should be impossible#
Look at the scatter plot above. The points fill a two-dimensional region. If we project these points onto any one-dimensional line, we collapse them down to a single axis. Then, when we try to reconstruct the original points, we can only place them somewhere along that same line.
A line embedded in a two-dimensional space only has one degree of freedom. We can slide points back and forth along it, but we can’t move them perpendicular to it. This means that the reconstructed points will form a one-dimensional structure, no matter how we choose the line.
The original data, by contrast, has two degrees of freedom—it occupies a full two-dimensional region. There’s no way to perfectly reconstruct a two-dimensional distribution from a one-dimensional projection using only linear operations. Information has been destroyed in the compression step, and linear operations can’t create information out of nothing.
Or so it seems.
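Before getting to the trick, here is a quick numeric check of the rank argument for the purely linear case (a minimal sketch; W_demo is just an arbitrary example matrix, not a learned parameter): for any 1×2 matrix \(W\), the linear part of the reconstruction, \(W^T W\), has rank one, so every output \(W^T W x + b\) lies on a single line.
# Numeric check of the rank argument: for any 1x2 W, the linear part of the
# reconstruction, W^T W, has rank 1, so all reconstructions lie on one line.
W_demo = t.tensor([[0.7, -0.4]])
WtW = W_demo.T @ W_demo
print(t.linalg.matrix_rank(WtW))  # expected: tensor(1)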
Visualizing the transformation#
The plot below shows what this transformation looks like for a specific choice of \(W\) and \(b\). The visualization has three panels:
Left panel: The input points \(x\) in two-dimensional space.
Middle panel: The hidden dimension \(h = Wx\), which compresses the input down to a single number. Each input point is mapped to a position on the one-dimensional axis.
Right panel: The reconstructed points \(y = W^T h + b\) back in two-dimensional space.
In the left and right panels, the arrows show the parameters of the transformation. The black arrow shows the direction of \(W\), which determines how we project onto the line. The blue arrow shows \(b\), which determines where the line is centered in the output space. Points are colored by their hidden value \(h\), making it easier to see how the compression and reconstruction work.
from plotly.subplots import make_subplots
import plotly.graph_objects as go
def plot_mapping(x: t.Tensor, h: t.Tensor, y: t.Tensor, W: t.Tensor, b: t.Tensor, title: str = None, margin: float = 0.1):
mse = ((x - y) ** 2).mean().item()
h_flat = h.numpy().flatten()
w_vec = W.squeeze().numpy()
w_str = f"[{w_vec[0]:.2f}, {w_vec[1]:.2f}]"
b_vec = b.numpy()
b_str = f"[{b_vec[0]:.2f}, {b_vec[1]:.2f}]"
subtitle = f"W = {w_str}, b = {b_str}, MSE = {mse:.4f}"
x_np = x.numpy()
y_np = y.numpy()
x_range_0 = x_np[:, 0].max() - x_np[:, 0].min()
x_range_1 = x_np[:, 1].max() - x_np[:, 1].min()
x_lim_0 = [x_np[:, 0].min() - margin * x_range_0, x_np[:, 0].max() + margin * x_range_0]
x_lim_1 = [x_np[:, 1].min() - margin * x_range_1, x_np[:, 1].max() + margin * x_range_1]
y_range_0 = y_np[:, 0].max() - y_np[:, 0].min()
y_range_1 = y_np[:, 1].max() - y_np[:, 1].min()
y_lim_0 = [y_np[:, 0].min() - margin * y_range_0, y_np[:, 0].max() + margin * y_range_0]
y_lim_1 = [y_np[:, 1].min() - margin * y_range_1, y_np[:, 1].max() + margin * y_range_1]
fig = make_subplots(
rows=1, cols=3,
subplot_titles=(
"Input Points (X)",
"Hidden Dimension (H)",
"Output Points (Y)"
),
specs=[[{"type": "scatter"}, {"type": "scatter"}, {"type": "scatter"}]],
horizontal_spacing=0.1
)
fig.add_trace(
go.Scatter(
x=x[:, 0].numpy(),
y=x[:, 1].numpy(),
mode='markers',
marker=dict(
opacity=0.7,
size=4,
color=h_flat,
colorscale='Viridis',
showscale=False
),
name='Input'
),
row=1, col=1
)
w_normalized = w_vec / (t.norm(W).item() + 1e-8)
arrow_scale = 0.3
fig.add_annotation(
x=b_vec[0] + w_normalized[0] * arrow_scale,
y=b_vec[1] + w_normalized[1] * arrow_scale,
ax=b_vec[0],
ay=b_vec[1],
xref='x1',
yref='y1',
axref='x1',
ayref='y1',
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=1.5,
arrowcolor='black',
row=1, col=1
)
fig.add_annotation(
x=b_vec[0] + w_normalized[0] * arrow_scale,
y=b_vec[1] + w_normalized[1] * arrow_scale,
text=f"W = {w_str}",
xref='x1',
yref='y1',
showarrow=False,
xanchor='left',
yanchor='bottom',
xshift=10,
yshift=5,
font=dict(size=10),
row=1, col=1
)
fig.add_trace(
go.Scatter(
x=h_flat,
y=t.zeros_like(h).numpy().flatten(),
mode='markers',
marker=dict(
opacity=0.7,
size=4,
color=h_flat,
colorscale='Viridis',
showscale=False
),
name='Hidden'
),
row=1, col=2
)
fig.add_trace(
go.Scatter(
x=y[:, 0].numpy(),
y=y[:, 1].numpy(),
mode='markers',
marker=dict(
opacity=0.7,
size=4,
color=h_flat,
colorscale='Viridis',
showscale=True,
colorbar=dict(title="h")
),
name='Output'
),
row=1, col=3
)
fig.add_annotation(
x=b_vec[0],
y=b_vec[1],
ax=0,
ay=0,
xref='x3',
yref='y3',
axref='x3',
ayref='y3',
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=1.5,
arrowcolor='blue',
row=1, col=3
)
fig.add_annotation(
x=b_vec[0],
y=b_vec[1],
text=f"b = {b_str}",
xref='x3',
yref='y3',
showarrow=False,
xanchor='left',
yanchor='bottom',
xshift=10,
yshift=5,
font=dict(size=10),
row=1, col=3
)
fig.add_annotation(
x=b_vec[0] + w_normalized[0] * arrow_scale,
y=b_vec[1] + w_normalized[1] * arrow_scale,
ax=b_vec[0],
ay=b_vec[1],
xref='x3',
yref='y3',
axref='x3',
ayref='y3',
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=1.5,
arrowcolor='black',
row=1, col=3
)
fig.add_annotation(
x=b_vec[0] + w_normalized[0] * arrow_scale,
y=b_vec[1] + w_normalized[1] * arrow_scale,
text=f"W = {w_str}",
xref='x3',
yref='y3',
showarrow=False,
xanchor='left',
yanchor='bottom',
xshift=10,
yshift=5,
font=dict(size=10),
row=1, col=3
)
fig.update_xaxes(title_text="x<sub>1</sub>", range=x_lim_0, row=1, col=1)
fig.update_yaxes(title_text="x<sub>2</sub>", range=x_lim_1, row=1, col=1)
fig.update_xaxes(title_text="h", row=1, col=2)
fig.update_yaxes(showticklabels=False, row=1, col=2)
fig.update_xaxes(title_text="y<sub>1</sub>", range=y_lim_0, row=1, col=3)
fig.update_yaxes(title_text="y<sub>2</sub>", range=y_lim_1, row=1, col=3)
main_title = title if title else "Mapping Visualization"
full_title = f"{main_title}<br><sub>{subtitle}</sub><br><br>"
fig.update_layout(
height=500,
width=900,
showlegend=False,
title_text=full_title
)
fig.show()
import einops
W = t.tensor([[1.0,-1.0]], dtype=t.float32)
b = t.tensor([0.5, 0.5], dtype=t.float32)
def f(x):
return einops.einsum(x, W, "b f, h f -> b h")
def g(h):
return einops.einsum(h, W, "b h, h f -> b f") + b
x = generate_data()
h = f(x)
y = g(h)
plot_mapping(x, h, y, W, b)
To put it formally:
We have a dataset \(D\) of 2-dimensional input points \(x \in \mathbb{R}^{2}\). We will denote the two components of each point \(x\) as “features” \((x_1, x_2)\).
The points \(x\) are uniformly distributed in \([0,1]\), independently across both dimensions.
We want to find a mapping \(f: \mathbb{R}^{2} \rightarrow \mathbb{R}\) that gives us a single “hidden” value \(h = f(x)\) for each input pair \(x\).
The mapping \(f\) should have a “reverse mapping” \(g: \mathbb{R} \rightarrow \mathbb{R}^{2}\) that gives us an output pair \(y \in \mathbb{R}^{2}\) for each hidden value \(h\).
The mappings \(f\) and \(g\) should be chosen such that they minimize the mean squared error between \(x\) and \(y\), i.e. \(L_{\mathrm{MSE}} = \frac{1}{|D|}\sum_{x \in D}\lVert x - y \rVert^2 = \frac{1}{|D|}\sum_{x \in D}\lVert x - g(f(x)) \rVert^2\).
Learning a linear model#
We’ll use PyTorch to learn the parameters \(W\) and \(b\) via gradient descent. The objective is to minimize the importance-weighted mean squared error between the input \(x\) and the reconstruction \(y\): \(L = \frac{1}{|D|}\sum_{x \in D}\frac{1}{n}\sum_{i=1}^{n} I_i\,(x_i - y_i)^2\), where \(I_i\) is the importance of feature \(i\) and \(n = 2\) is the number of features.
The importances allow us to weight the reconstruction error differently for each feature. When both importances are equal, we care equally about reconstructing both dimensions. When one importance is higher, we’re willing to accept larger errors on the less important dimension if it means better reconstruction of the more important one.
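As a toy illustration (hypothetical numbers, mirroring the compute_loss function defined further below), the same per-feature errors contribute very differently once they are weighted:
# Toy illustration of importance weighting: errors of 0.2 on x1 and 0.4 on x2.
x_toy = t.tensor([[0.8, 0.2]])
y_toy = t.tensor([[0.6, 0.6]])
for imp in (t.tensor([1.0, 1.0]), t.tensor([1.0, 0.1])):
    print(imp, ((x_toy - y_toy) ** 2 * imp).mean())  # 0.10 vs 0.028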
The model#
The LinearMapping class implements the compression and reconstruction we described earlier:
compute_hidden(x): Computes \(h = Wx\), projecting the input onto the one-dimensional hidden space.
compute_output(h): Computes \(y = W^T h + b\), reconstructing from the hidden dimension back to two dimensions.
forward(x): Combines both operations to map from input \(x\) to reconstruction \(y\).
The model is entirely linear—just matrix multiplication and addition, with no nonlinearities anywhere.
Side note: why is one of the two mappings affine but not both?
You might be wondering why \(f\) is linear and \(g\) is affine, that is, why only \(g\) has a bias term. This is because a composition of two affine functions can be rewritten as a composition of one linear and one affine function, with both translations being factored out into a single final translation. Compare both cases below:
Case 1: one linear function and one affine function:
\(h = f(x) = Wx\)
\(y = g(h) = W^{T}h + b\)
Composition: \(y = g(f(x)) = W^{T}Wx + b\)
Case 2: two affine functions:
\(h = f(x) = Wx + b_h\)
\(y = g(h) = W^{T}h + b_y\)
Composition: \(y = g(f(x)) = W^{T}(Wx + b_h) + b_y = W^{T}Wx + W^{T}b_h + b_y\)
We can rewrite \(W^{T}b_h + b_y\), which is a constant vector in \(\mathbb{R}^2\), as the combined bias \(b_c\) for convenience.
This leaves us with \(y = W^{T}Wx + b_c\).
Both cases have exactly the same form—any bias in the first mapping can be absorbed as a constant term in the bias of the second mapping.
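A quick numeric check of this absorption argument (a minimal sketch with arbitrary values; b_h and b_y are hypothetical biases, not parameters of the model below):
# Check that an encoder bias can be absorbed into the decoder bias:
# W^T (W x + b_h) + b_y  ==  W^T W x + (W^T b_h + b_y)
W_demo = t.tensor([[0.7, -0.4]])
b_h = t.tensor([0.3])
b_y = t.tensor([0.1, -0.2])
x_demo = t.rand(5, 2)
two_affine = (x_demo @ W_demo.T + b_h) @ W_demo + b_y
one_affine = x_demo @ (W_demo.T @ W_demo) + (W_demo.T @ b_h + b_y)
print(t.allclose(two_affine, one_affine))  # expected: True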
from torch import nn
from tqdm.notebook import tqdm
def get_device():
if t.backends.mps.is_available():
return t.device("mps")
elif t.cuda.is_available():
return t.device("cuda")
else:
return t.device("cpu")
DEVICE = get_device()
from torch import nn
DEF_INPUT_DIM = 2
DEF_HIDDEN_DIM = 1
class LinearMapping(nn.Module):
def __init__(self, input_dim: int = DEF_INPUT_DIM, hidden_dim: int = DEF_HIDDEN_DIM):
super().__init__()
W = t.empty(hidden_dim, input_dim)
b = t.empty(input_dim)
nn.init.xavier_normal_(W)
nn.init.zeros_(b)
self.W = nn.Parameter(W)
self.b = nn.Parameter(b)
def compute_hidden(self, x):
return einops.einsum(x, self.W, "b f, h f -> b h")
def compute_output(self, h):
return einops.einsum(h, self.W, "b h, h f -> b f") + self.b
def forward(self, x):
h = self.compute_hidden(x)
y = self.compute_output(h)
return y
The training loop#
The train function uses standard gradient descent with the AdamW optimizer. It splits the data into training and evaluation sets, then iteratively updates \(W\) and \(b\) to minimize the weighted mean squared error.
Throughout training, we track several metrics: the training loss, the parameters \(W\) and \(b\), and samples of the input and reconstructed points. The visualization in the next cells will show an animation of how these evolve during training, with the left panel showing the transformation in 2D space and the right panel showing the loss curves.
from torch import nn
from tqdm.notebook import tqdm
DEF_IMPORTANCES = t.ones(DEF_INPUT_DIM)
DEF_LR = 1e-4
DEF_NUM_EPOCHS = 1000
def compute_loss(x, y, importances=DEF_IMPORTANCES):
return ((x - y) ** 2 * importances).mean()
def train(model, x, train_eval_split=0.2, importances=DEF_IMPORTANCES, lr=DEF_LR, num_epochs=DEF_NUM_EPOCHS, sample_fraction=0.3):
indices = t.randperm(len(x))
x_shuffled = x[indices]
n_train = int(len(x) * (1 - train_eval_split))
x_train = x_shuffled[:n_train]
x_eval = x_shuffled[n_train:]
n_sample = int(n_train * sample_fraction)
sample_indices = t.randperm(n_train)[:n_sample]
x_sample = x_train[sample_indices]
optimizer = t.optim.AdamW(model.parameters(), lr=lr)
metrics = {
'train_loss': [],
'eval_loss': [],
'W': [],
'b': [],
'steps': [],
'x_sample': [],
'y_sample': []
}
for epoch in tqdm(range(num_epochs)):
optimizer.zero_grad()
y_pred = model(x_train)
loss = compute_loss(x_train, y_pred, importances)
loss.backward()
optimizer.step()
metrics['train_loss'].append(loss.item())
metrics['W'].append(model.W.detach().clone())
metrics['b'].append(model.b.detach().clone())
metrics['steps'].append(epoch)
with t.no_grad():
y_sample_pred = model(x_sample)
metrics['x_sample'].append(x_sample.detach().clone())
metrics['y_sample'].append(y_sample_pred.detach().clone())
if epoch % 50 == 0:
with t.no_grad():
y_eval_pred = model(x_eval)
eval_loss = compute_loss(x_eval, y_eval_pred, importances)
metrics['eval_loss'].append(eval_loss.item())
else:
metrics['eval_loss'].append(None)
return metrics
def plot_training_dashboard(x: t.Tensor, metrics: dict, margin: float = 0.2, frame_step: int = 100):
import numpy as np
x_np = x.cpu().numpy() if x.is_cuda or x.device.type == 'mps' else x.numpy()
colors = px.colors.qualitative.Plotly
input_color = colors[0]
output_color = colors[1]
connection_color = colors[6]
background_color = colors[7]
eval_steps = [i for i, loss in enumerate(metrics['eval_loss']) if loss is not None]
eval_losses = [loss for loss in metrics['eval_loss'] if loss is not None]
arrow_scale = 0.3
all_mins_x = []
all_maxs_x = []
all_mins_y = []
all_maxs_y = []
for step in range(0, len(metrics['steps']), frame_step):
W_val = metrics['W'][step].cpu().numpy() if metrics['W'][step].is_cuda or metrics['W'][step].device.type == 'mps' else metrics['W'][step].numpy()
b_val = metrics['b'][step].cpu().numpy() if metrics['b'][step].is_cuda or metrics['b'][step].device.type == 'mps' else metrics['b'][step].numpy()
y_sample_step = metrics['y_sample'][step]
y_sample_np = y_sample_step.cpu().numpy() if y_sample_step.is_cuda or y_sample_step.device.type == 'mps' else y_sample_step.numpy()
w_vec = W_val.squeeze()
w_normalized = w_vec / (np.linalg.norm(w_vec) + 1e-8)
x_sample_step_calc = metrics['x_sample'][step]
x_sample_np_calc = x_sample_step_calc.cpu().numpy() if x_sample_step_calc.is_cuda or x_sample_step_calc.device.type == 'mps' else x_sample_step_calc.numpy()
all_points_x = np.concatenate([
x_sample_np_calc[:, 0],
y_sample_np[:, 0],
[b_val[0]],
[b_val[0] + w_normalized[0] * arrow_scale]
])
all_points_y = np.concatenate([
x_sample_np_calc[:, 1],
y_sample_np[:, 1],
[b_val[1]],
[b_val[1] + w_normalized[1] * arrow_scale]
])
all_mins_x.append(all_points_x.min())
all_maxs_x.append(all_points_x.max())
all_mins_y.append(all_points_y.min())
all_maxs_y.append(all_points_y.max())
global_min_x = min(all_mins_x)
global_max_x = max(all_maxs_x)
global_min_y = min(all_mins_y)
global_max_y = max(all_maxs_y)
global_min = min(global_min_x, global_min_y)
global_max = max(global_max_x, global_max_y)
global_range = global_max - global_min
axis_lim = [global_min - margin * global_range, global_max + margin * global_range]
axis_lim = [max(axis_lim[0], -1.5), min(axis_lim[1], 1.5)]
frames = []
for step in range(0, len(metrics['steps']), frame_step):
W_val = metrics['W'][step].cpu().numpy() if metrics['W'][step].is_cuda or metrics['W'][step].device.type == 'mps' else metrics['W'][step].numpy()
b_val = metrics['b'][step].cpu().numpy() if metrics['b'][step].is_cuda or metrics['b'][step].device.type == 'mps' else metrics['b'][step].numpy()
x_sample_step = metrics['x_sample'][step]
y_sample_step = metrics['y_sample'][step]
x_sample_np = x_sample_step.cpu().numpy() if x_sample_step.is_cuda or x_sample_step.device.type == 'mps' else x_sample_step.numpy()
y_sample_np = y_sample_step.cpu().numpy() if y_sample_step.is_cuda or y_sample_step.device.type == 'mps' else y_sample_step.numpy()
w_vec = W_val.squeeze()
w_str = f"[{w_vec[0]:.2f}, {w_vec[1]:.2f}]"
b_str = f"[{b_val[0]:.2f}, {b_val[1]:.2f}]"
w_normalized = w_vec / (t.norm(metrics['W'][step]).item() + 1e-8)
line_extend = 2.0
line_start = b_val - w_normalized * line_extend
line_end = b_val + w_normalized * line_extend
train_loss = metrics['train_loss'][step]
connection_lines_x = []
connection_lines_y = []
for i in range(len(x_sample_np)):
connection_lines_x.extend([x_sample_np[i, 0], y_sample_np[i, 0], None])
connection_lines_y.extend([x_sample_np[i, 1], y_sample_np[i, 1], None])
frame_data = [
go.Scatter(
x=connection_lines_x,
y=connection_lines_y,
mode='lines',
line=dict(color=connection_color, width=1, dash='dot'),
name='Transformations',
showlegend=False,
xaxis='x1',
yaxis='y1'
),
go.Scatter(
x=x_sample_np[:, 0],
y=x_sample_np[:, 1],
mode='markers',
marker=dict(opacity=0.8, size=6, color=input_color),
name='Input Sample',
showlegend=False,
xaxis='x1',
yaxis='y1'
),
go.Scatter(
x=y_sample_np[:, 0],
y=y_sample_np[:, 1],
mode='markers',
marker=dict(opacity=0.8, size=6, color=output_color),
name='Output Points',
showlegend=False,
xaxis='x1',
yaxis='y1'
),
go.Scatter(
x=[line_start[0], line_end[0]],
y=[line_start[1], line_end[1]],
mode='lines',
line=dict(color='gray', width=1, dash='dot'),
name='W direction',
showlegend=False,
xaxis='x1',
yaxis='y1'
),
go.Scatter(
x=metrics['steps'][:step+1],
y=metrics['train_loss'][:step+1],
mode='lines',
name='Train Loss',
line=dict(color=input_color, width=2),
showlegend=True,
legendgroup='loss',
xaxis='x2',
yaxis='y2'
),
go.Scatter(
x=[s for s in eval_steps if s <= step],
y=[eval_losses[i] for i, s in enumerate(eval_steps) if s <= step],
mode='lines+markers',
name='Eval Loss',
line=dict(color=output_color, width=2),
marker=dict(size=6, color=output_color),
showlegend=True,
legendgroup='loss',
xaxis='x2',
yaxis='y2'
),
go.Scatter(
x=[step],
y=[train_loss],
mode='markers',
marker=dict(size=12, color=input_color, symbol='circle'),
name='Current Step',
showlegend=False,
xaxis='x2',
yaxis='y2'
)
]
frame_annotations = [
dict(
x=b_val[0],
y=b_val[1],
ax=0,
ay=0,
xref='x1',
yref='y1',
axref='x1',
ayref='y1',
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=3,
arrowcolor='blue'
),
dict(
x=b_val[0] + w_normalized[0] * arrow_scale,
y=b_val[1] + w_normalized[1] * arrow_scale,
ax=b_val[0],
ay=b_val[1],
xref='x1',
yref='y1',
axref='x1',
ayref='y1',
showarrow=True,
arrowhead=2,
arrowsize=1,
arrowwidth=3,
arrowcolor='black'
)
]
frames.append(go.Frame(
data=frame_data,
name=str(step),
layout=go.Layout(
title_text=f"Training<br><sub>Step {step}, MSE = {train_loss:.4f}, W = {w_str}, b = {b_str}</sub>",
annotations=frame_annotations
)
))
fig = make_subplots(
rows=1, cols=2,
specs=[[{"type": "scatter"}, {"type": "scatter"}]],
horizontal_spacing=0.15,
column_widths=[0.5, 0.5]
)
x_sample_0 = metrics['x_sample'][0]
y_sample_0 = metrics['y_sample'][0]
x_sample_0_np = x_sample_0.cpu().numpy() if x_sample_0.is_cuda or x_sample_0.device.type == 'mps' else x_sample_0.numpy()
y_sample_0_np = y_sample_0.cpu().numpy() if y_sample_0.is_cuda or y_sample_0.device.type == 'mps' else y_sample_0.numpy()
connection_lines_x_0 = []
connection_lines_y_0 = []
for i in range(len(x_sample_0_np)):
connection_lines_x_0.extend([x_sample_0_np[i, 0], y_sample_0_np[i, 0], None])
connection_lines_y_0.extend([x_sample_0_np[i, 1], y_sample_0_np[i, 1], None])
fig.add_trace(
go.Scatter(
x=connection_lines_x_0,
y=connection_lines_y_0,
mode='lines',
line=dict(color=connection_color, width=1, dash='dot'),
name='Transformations',
showlegend=False
),
row=1, col=1
)
fig.add_trace(
go.Scatter(
x=x_sample_0_np[:, 0],
y=x_sample_0_np[:, 1],
mode='markers',
marker=dict(opacity=0.8, size=6, color=input_color),
name='Input Sample',
showlegend=False
),
row=1, col=1
)
fig.add_trace(
go.Scatter(
x=y_sample_0_np[:, 0],
y=y_sample_0_np[:, 1],
mode='markers',
marker=dict(opacity=0.8, size=6, color=output_color),
name='Output Points',
showlegend=False
),
row=1, col=1
)
fig.add_trace(
go.Scatter(
x=[0],
y=[metrics['train_loss'][0]],
mode='lines',
name='Train Loss',
line=dict(color=input_color, width=2),
showlegend=True,
legendgroup='loss'
),
row=1, col=2
)
fig.add_trace(
go.Scatter(
x=[],
y=[],
mode='lines+markers',
name='Eval Loss',
line=dict(color=output_color, width=2),
marker=dict(size=6, color=output_color),
showlegend=True,
legendgroup='loss'
),
row=1, col=2
)
fig.add_trace(
go.Scatter(
x=[0],
y=[metrics['train_loss'][0]],
mode='markers',
marker=dict(size=12, color=input_color, symbol='circle'),
name='Current Step',
showlegend=False
),
row=1, col=2
)
W_0 = metrics['W'][0].cpu().numpy() if metrics['W'][0].is_cuda or metrics['W'][0].device.type == 'mps' else metrics['W'][0].numpy()
b_0 = metrics['b'][0].cpu().numpy() if metrics['b'][0].is_cuda or metrics['b'][0].device.type == 'mps' else metrics['b'][0].numpy()
w_0_vec = W_0.squeeze()
w_0_str = f"[{w_0_vec[0]:.2f}, {w_0_vec[1]:.2f}]"
b_0_str = f"[{b_0[0]:.2f}, {b_0[1]:.2f}]"
fig.update_xaxes(title_text="x₁", range=axis_lim, row=1, col=1)
fig.update_yaxes(title_text="x₂", range=axis_lim, scaleanchor="x", scaleratio=1, row=1, col=1)
fig.update_xaxes(title_text="Step", row=1, col=2)
fig.update_yaxes(title_text="Loss", row=1, col=2)
fig.update_layout(
title=f"Training<br><sub>Step 0, MSE = {metrics['train_loss'][0]:.4f}, W = {w_0_str}, b = {b_0_str}</sub>",
title_y=0.92,
width=900,
height=600,
hovermode='closest',
legend=dict(x=0.55, y=0.95),
updatemenus=[dict(
type="buttons",
buttons=[
dict(label="Play", method="animate", args=[None, {"frame": {"duration": 100, "redraw": True}, "fromcurrent": True}]),
dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}])
],
x=0.1,
y=-0.1
)],
sliders=[dict(
active=0,
steps=[dict(args=[[f.name], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
label=f.name,
method="animate") for f in frames],
x=0.1,
len=0.8,
y=-0.1
)]
)
fig.frames = frames
fig.show()
Experiment 1: Equal importances, no sparsity#
In this first experiment, we train the linear model with equal importances on both features and no sparsity in the data.
The model converges to a solution where the line passes roughly through the center of the unit square. The angle of the line changes between training runs—different angles won’t yield substantially different mean squared errors, since the data is symmetric. The key observation is that the reconstructed points all lie along a one-dimensional line in the two-dimensional output space.
The resulting MSE confirms what we expected from our earlier analysis: projecting 2D data onto a 1D line destroys information, and no amount of optimization can recover it with a purely linear model. The loss is substantial because the model is forced to make a compromise—it can’t simultaneously represent variation in both dimensions.
t.manual_seed(44)
device = get_device()
importances = t.ones(DEF_INPUT_DIM).to(device)
x = generate_data().to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=1500, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)
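As a rough sanity check on the converged loss (a back-of-the-envelope estimate under our setup, not a result from the paper): the best a rank-one linear map can do here is reproduce the component of each point along \(W\) and replace the orthogonal component with its mean, leaving a residual equal to the variance of a uniform \([0,1]\) variable, \(1/12\). Since the loss averages over both features, that predicts a floor of roughly \(1/24 \approx 0.042\).
# Back-of-the-envelope floor for the linear model on non-sparse data:
# residual variance orthogonal to W is 1/12; averaged over 2 features -> 1/24.
print("predicted floor:", 1 / 24)
print("final train loss:", metrics["train_loss"][-1])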
Experiment 2: Unequal importances, no sparsity#
Now we train with unequal importances: one feature matters more for our loss function than the other.
The model makes a rational trade-off. It dedicates its single degree of freedom, the direction of \(W\), to representing the important feature as accurately as possible. The less important feature is handled by the bias term \(b\), which essentially sets it to its mean value across the dataset. Since the data is uniform on [0,1], the unimportant dimension is reconstructed as roughly 0.5.
This is still far from perfect. The model can represent one dimension well, but the other dimension shows no variation at all in the reconstruction. We’re still bound by the fundamental limitation: a one-dimensional representation can’t capture two-dimensional variation.
t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 0.1]).to(device)
x = generate_data().to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=3300, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)
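One way to sanity-check this interpretation is to look at the learned parameters directly (a minimal sketch; the exact numbers depend on the seed): \(W\) should point almost entirely along the important first feature, and the second component of \(b\) should sit near 0.5.
# Inspect the learned parameters: W should align with x1, b[1] should be near 0.5.
w_learned = model.W.detach().squeeze()
print("W direction:", w_learned / w_learned.norm())
print("b:", model.b.detach())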
Introducing sparsity#
We’ll now introduce sparsity into our data. With a sparsity parameter \(s = 0.8\), each feature independently has an 80% chance of being exactly zero.
This changes the structure of the problem. When \(x_1 = 0\) and \(x_2 \neq 0\), the projection \(h = Wx\) only captures information from \(x_2\). There’s no interference from \(x_1\) in the down-projection. Similarly, when \(x_2 = 0\) and \(x_1 \neq 0\), only \(x_1\) contributes to \(h\).
The question is whether a linear model can exploit this structure to achieve better reconstruction. With 80% sparsity, most of the time at least one of the features is zero. In these cases, the projection is effectively one-dimensional, which seems like it should make the reconstruction problem easier. But can gradient descent discover a way to take advantage of this?
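To make “most of the time” concrete (a quick empirical check; the exact fractions vary slightly from sample to sample): with \(s = 0.8\), each feature is zero with probability 0.8, so at least one feature is zero with probability \(1 - 0.2^2 = 0.96\), and both are zero with probability \(0.8^2 = 0.64\).
# Empirical check of the sparsity structure at s = 0.8.
x_sparse = generate_data(n_points=100_000, sparsity=0.8)
at_least_one_zero = (x_sparse == 0).any(dim=1).float().mean().item()
both_zero = (x_sparse == 0).all(dim=1).float().mean().item()
print(f"at least one feature zero: {at_least_one_zero:.3f} (expected 0.96)")
print(f"both features zero:        {both_zero:.3f} (expected 0.64)")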
Experiment 3: Sparse data, unequal importances#
With sparse data and unequal importances, the model behaves similarly to the non-sparse case from Experiment 2. It uses the direction of \(W\) to represent the more important feature, and uses the bias \(b\) to approximate the less important one.
The main difference is in the value of the bias. In the non-sparse case, the bias was around 0.5, since that’s the mean of a uniform distribution on [0,1]. Here, with 80% sparsity, each feature is zero most of the time, so its mean drops to \((1 - 0.8) \times 0.5 = 0.1\), and the learned bias reflects this by landing much closer to 0.
The sparsity helps reduce the overall loss, but the linear model still can’t do anything fundamentally different from what it did without sparsity. It’s still making the same basic trade-off: represent one dimension well, approximate the other with its mean.
t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 0.1]).to(device)
x = generate_data(sparsity=0.8).to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=2500, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)
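As before, we can peek at the learned parameters (a minimal check; the exact value depends on the run): with \(s = 0.8\) the mean of each feature is \((1 - 0.8) \times 0.5 = 0.1\), so the bias for the unimportant feature should land near 0.1 rather than 0.5.
# The unimportant feature is approximated by its mean: (1 - s) * 0.5 = 0.1.
print("learned b:", model.b.detach())
print("empirical feature means:", x.mean(dim=0))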
Experiment 4: Sparse data, equal importances#
This configuration—sparse data with equal importances—turns out to be less stable than the others. The model sometimes gets stuck in local minima, and different training runs can produce quite different results.
The issue is that the model faces ambiguity. With equal importances, it has no clear signal about which dimension to prioritize. Should it commit fully to representing one dimension well, or should it try to hedge and partially represent both? The sparse structure of the data means that committing to one dimension could work well, but the symmetry in the loss function means there’s no gradient pushing it clearly in either direction.
As a result, you’ll often see solutions that don’t fully commit to either dimension, or training dynamics that oscillate between different strategies. This instability is itself informative: it shows that without additional structure (like unequal importances or, as we’ll see next, a nonlinearity), the model struggles to find a consistent solution even when the data has favorable properties like sparsity.
t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 1.0]).to(device)
x = generate_data(sparsity=0.8).to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=10000, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)
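To see this run-to-run variability directly, one option (a rough sketch; the shorter runs below are only meant for a quick comparison, not to reproduce the cell above, and the names x_run, model_run, run_metrics are made up for this check) is to repeat training with a few different seeds and compare the final losses.
# Rough illustration of instability: final train loss across a few seeds.
for seed in (0, 1, 2, 3):
    t.manual_seed(seed)
    x_run = generate_data(sparsity=0.8).to(device)
    model_run = LinearMapping(input_dim=2, hidden_dim=1).to(device)
    run_metrics = train(model_run, x_run, num_epochs=2000, lr=1e-3, importances=importances)
    print(f"seed {seed}: final train loss {run_metrics['train_loss'][-1]:.4f}")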