Even Smaller Models of Superposition#

Introduction: The superposition problem#

Anthropic’s 2022 paper on toy models of superposition makes a counterintuitive claim: if you have two statistically independent variables, you can linearly combine them into a single number in a way that allows you to reconstruct both variables separately later.

For someone like me with a background in linear algebra and information theory, this seems impossible. If you project points from a two-dimensional space onto a one-dimensional line, you end up with a representation that only has one degree of freedom. There is no way that you can recover two separate dimensions from there.

The trick, it turns out, involves three ingredients: sparsity (most values are zero most of the time), positivity (values lie in a known range like [0,1]), and a simple nonlinearity (a ReLU activation is enough). Together, these properties allow us to do something that seems to violate basic principles of linear algebra.

In this notebook, we’ll build up an understanding of how this works through a series of experiments with progressively more structure.

Data generation#

Our data consists of pairs of values \(x = (x_1, x_2)\), which we’ll call features. Each feature has three key properties:

  • Uniform distribution: Both features are uniformly distributed in the interval \([0, 1]\).

  • Statistical independence: The value of \(x_1\) tells us nothing about the value of \(x_2\), and vice versa.

  • Sparsity: With probability \(s\) (the sparsity parameter), each feature is set to exactly zero. This happens independently for each feature.

For now, we’ll start with \(s = 0\), meaning no sparsity—all points lie in the interior of the unit square.

We’ll also introduce the concept of importances: we can weight how much we care about accurately reconstructing each feature. If one feature has higher importance, we’ll penalize reconstruction errors on that feature more heavily in our loss function. This will be useful for understanding the trade-offs the model makes when it can’t perfectly reconstruct both dimensions.

import torch as t
import plotly.express as px
import plotly.io as pio

pio.renderers.default = "notebook_connected"

DEF_N_POINTS = 1000
DEF_N_FEATURES = 2
DEF_SPARSITY = 0

def generate_data(n_points=DEF_N_POINTS, n_features=DEF_N_FEATURES, sparsity=DEF_SPARSITY):
    # Points are uniformly distributed in [0, 1]
    # Independently across features, x_i = 0 with probability `sparsity`
    sparsity_mask = t.rand(n_points, n_features) < sparsity
    x = t.rand(n_points, n_features)
    x[sparsity_mask] = 0
    return x

def plot_data(x, title=None):
    fig = px.scatter(x=x[:, 0], y=x[:, 1], opacity=0.5)
    fig.update_layout(width=600, height=600, title=title)
    fig.update_xaxes(title_text="x<sub>1</sub>")
    fig.update_yaxes(title_text="x<sub>2</sub>")
    fig.show()

data = generate_data()
plot_data(data, title="Initial data distribution")

Why this should be impossible#

Look at the scatter plot above. The points fill a two-dimensional region. If we project these points onto any one-dimensional line, we collapse them down to a single axis. Then, when we try to reconstruct the original points, we can only place them somewhere along that same line.

A line embedded in a two-dimensional space only has one degree of freedom. We can slide points back and forth along it, but we can’t move them perpendicular to it. This means that the reconstructed points will form a one-dimensional structure, no matter how we choose the line.

The original data, by contrast, has two degrees of freedom—it occupies a full two-dimensional region. There’s no way to perfectly reconstruct a two-dimensional distribution from a one-dimensional projection using only linear operations. Information has been destroyed in the compression step, and linear operations can’t create information out of nothing.
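The argument can be checked numerically. The following is a minimal sketch in plain Python; the direction `w`, the two sample points, and the helper names are illustrative choices, not part of the notebook. It shows that two different inputs with the same projection get the same reconstruction, and that the combined linear map has determinant zero.

```python
# Illustrative direction for the projection line (an arbitrary choice).
w = (0.6, 0.8)  # unit length, since 0.36 + 0.64 = 1


def project(x):
    """Compress a 2D point to one number: h = w . x."""
    return w[0] * x[0] + w[1] * x[1]


def reconstruct(h):
    """Map the hidden value back into 2D along the same direction."""
    return (w[0] * h, w[1] * h)


# Two different inputs that differ only perpendicular to w.
x_a = (0.5, 0.5)
x_b = (0.5 + 0.08, 0.5 - 0.06)  # shifted along (0.8, -0.6), orthogonal to w

h_a, h_b = project(x_a), project(x_b)
print(h_a, h_b)                            # same hidden value (up to rounding)
print(reconstruct(h_a), reconstruct(h_b))  # hence the same reconstruction

# The combined map x -> reconstruct(project(x)) is the matrix w w^T.
# Its determinant is zero: it has rank 1 and cannot be inverted.
det = (w[0] * w[0]) * (w[1] * w[1]) - (w[0] * w[1]) ** 2
print(det)
```

Whatever direction is chosen, the component of the input perpendicular to it never reaches the hidden value, so no linear map can bring it back.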

Or so it seems.

Visualizing the transformation#

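To make this concrete, we compress each point with a linear map \(h = Wx\) and reconstruct it with an affine map \(y = W^T h + b\). The plot below shows what this transformation looks like for a specific choice of \(W\) and \(b\). The visualization has three panels: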

  • Left panel: The input points \(x\) in two-dimensional space.

  • Middle panel: The hidden dimension \(h = Wx\), which compresses the input down to a single number. Each input point is mapped to a position on the one-dimensional axis.

  • Right panel: The reconstructed points \(y = W^T h + b\) back in two-dimensional space.

In the left and right panels, the arrows show the parameters of the transformation. The black arrow shows the direction of \(W\), which determines how we project onto the line. The blue arrow shows \(b\), which determines where the line is centered in the output space. Points are colored by their hidden value \(h\), making it easier to see how the compression and reconstruction work.

from plotly.subplots import make_subplots
import plotly.graph_objects as go

def plot_mapping(x: t.Tensor, h: t.Tensor, y: t.Tensor, W: t.Tensor, b: t.Tensor, title: str = None, margin: float = 0.1):
    mse = ((x - y) ** 2).mean().item()
    
    h_flat = h.numpy().flatten()
    
    w_vec = W.squeeze().numpy()
    w_str = f"[{w_vec[0]:.2f}, {w_vec[1]:.2f}]"
    b_vec = b.numpy()
    b_str = f"[{b_vec[0]:.2f}, {b_vec[1]:.2f}]"
    
    subtitle = f"W = {w_str}, b = {b_str}, MSE = {mse:.4f}"
    
    x_np = x.numpy()
    y_np = y.numpy()
    
    x_range_0 = x_np[:, 0].max() - x_np[:, 0].min()
    x_range_1 = x_np[:, 1].max() - x_np[:, 1].min()
    x_lim_0 = [x_np[:, 0].min() - margin * x_range_0, x_np[:, 0].max() + margin * x_range_0]
    x_lim_1 = [x_np[:, 1].min() - margin * x_range_1, x_np[:, 1].max() + margin * x_range_1]
    
    y_range_0 = y_np[:, 0].max() - y_np[:, 0].min()
    y_range_1 = y_np[:, 1].max() - y_np[:, 1].min()
    y_lim_0 = [y_np[:, 0].min() - margin * y_range_0, y_np[:, 0].max() + margin * y_range_0]
    y_lim_1 = [y_np[:, 1].min() - margin * y_range_1, y_np[:, 1].max() + margin * y_range_1]
    
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=(
            "Input Points (X)",
            "Hidden Dimension (H)",
            "Output Points (Y)"
        ),
        specs=[[{"type": "scatter"}, {"type": "scatter"}, {"type": "scatter"}]],
        horizontal_spacing=0.1
    )
    
    fig.add_trace(
        go.Scatter(
            x=x[:, 0].numpy(),
            y=x[:, 1].numpy(),
            mode='markers',
            marker=dict(
                opacity=0.7,
                size=4,
                color=h_flat,
                colorscale='Viridis',
                showscale=False
            ),
            name='Input'
        ),
        row=1, col=1
    )
    
    w_normalized = w_vec / (t.norm(W).item() + 1e-8)
    arrow_scale = 0.3
    fig.add_annotation(
        x=b_vec[0] + w_normalized[0] * arrow_scale,
        y=b_vec[1] + w_normalized[1] * arrow_scale,
        ax=b_vec[0],
        ay=b_vec[1],
        xref='x1',
        yref='y1',
        axref='x1',
        ayref='y1',
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=1.5,
        arrowcolor='black',
        row=1, col=1
    )
    
    fig.add_annotation(
        x=b_vec[0] + w_normalized[0] * arrow_scale,
        y=b_vec[1] + w_normalized[1] * arrow_scale,
        text=f"W = {w_str}",
        xref='x1',
        yref='y1',
        showarrow=False,
        xanchor='left',
        yanchor='bottom',
        xshift=10,
        yshift=5,
        font=dict(size=10),
        row=1, col=1
    )
    
    fig.add_trace(
        go.Scatter(
            x=h_flat,
            y=t.zeros_like(h).numpy().flatten(),
            mode='markers',
            marker=dict(
                opacity=0.7,
                size=4,
                color=h_flat,
                colorscale='Viridis',
                showscale=False
            ),
            name='Hidden'
        ),
        row=1, col=2
    )
    
    fig.add_trace(
        go.Scatter(
            x=y[:, 0].numpy(),
            y=y[:, 1].numpy(),
            mode='markers',
            marker=dict(
                opacity=0.7,
                size=4,
                color=h_flat,
                colorscale='Viridis',
                showscale=True,
                colorbar=dict(title="h")
            ),
            name='Output'
        ),
        row=1, col=3
    )
    
    fig.add_annotation(
        x=b_vec[0],
        y=b_vec[1],
        ax=0,
        ay=0,
        xref='x3',
        yref='y3',
        axref='x3',
        ayref='y3',
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=1.5,
        arrowcolor='blue',
        row=1, col=3
    )
    
    fig.add_annotation(
        x=b_vec[0],
        y=b_vec[1],
        text=f"b = {b_str}",
        xref='x3',
        yref='y3',
        showarrow=False,
        xanchor='left',
        yanchor='bottom',
        xshift=10,
        yshift=5,
        font=dict(size=10),
        row=1, col=3
    )
    
    fig.add_annotation(
        x=b_vec[0] + w_normalized[0] * arrow_scale,
        y=b_vec[1] + w_normalized[1] * arrow_scale,
        ax=b_vec[0],
        ay=b_vec[1],
        xref='x3',
        yref='y3',
        axref='x3',
        ayref='y3',
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=1.5,
        arrowcolor='black',
        row=1, col=3
    )
    
    fig.add_annotation(
        x=b_vec[0] + w_normalized[0] * arrow_scale,
        y=b_vec[1] + w_normalized[1] * arrow_scale,
        text=f"W = {w_str}",
        xref='x3',
        yref='y3',
        showarrow=False,
        xanchor='left',
        yanchor='bottom',
        xshift=10,
        yshift=5,
        font=dict(size=10),
        row=1, col=3
    )
    
    fig.update_xaxes(title_text="x<sub>1</sub>", range=x_lim_0, row=1, col=1)
    fig.update_yaxes(title_text="x<sub>2</sub>", range=x_lim_1, row=1, col=1)
    fig.update_xaxes(title_text="h", row=1, col=2)
    fig.update_yaxes(showticklabels=False, row=1, col=2)
    fig.update_xaxes(title_text="y<sub>1</sub>", range=y_lim_0, row=1, col=3)
    fig.update_yaxes(title_text="y<sub>2</sub>", range=y_lim_1, row=1, col=3)
    
    main_title = title if title else "Mapping Visualization"
    full_title = f"{main_title}<br><sub>{subtitle}</sub><br><br>"
    
    fig.update_layout(
        height=500,
        width=900,
        showlegend=False,
        title_text=full_title
    )
    
    fig.show()

    
import einops
W = t.tensor([[1.0,-1.0]], dtype=t.float32)
b = t.tensor([0.5, 0.5], dtype=t.float32)

def f(x):
    return einops.einsum(x, W, "b f, h f -> b h")

def g(h):
    return einops.einsum(h, W, "b h, h f -> b f") + b

x = generate_data()
h = f(x)
y = g(h)
plot_mapping(x, h, y, W, b)

To put it formally:

  • We have a dataset of 2-dimensional input points \(D: x \in \mathbb{R}^{2}\). We will denote each of the two components of a pair \(x\) as “features” \((x_1, x_2)\).

  • Each component of \(x\) is drawn independently from the uniform distribution on \([0,1]\).

  • We want to find a mapping \(f: \mathbb{R}^{2} \rightarrow \mathbb{R}\) that gives us a single “hidden” value \(h = f(x)\) for each input pair \(x\).

  • The mapping \(f\) should have a “reverse mapping” \(g: \mathbb{R} \rightarrow \mathbb{R}^{2}\) that gives us an output pair \(y \in \mathbb{R}^{2}\) for each hidden value \(h\).

  • The mappings \(f\) and \(g\) should be chosen such that they minimize the mean squared error between \(x\) and \(y\), i.e. \(L_{MSE} = \frac{1}{|D|}\sum_{x \in D}\lVert x-y\rVert^2 = \frac{1}{|D|}\sum_{x \in D}\lVert x-g(f(x))\rVert^2\)
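As a worked example of these definitions, here is a small sketch in plain Python that uses the same \(W = [1, -1]\) and \(b = [0.5, 0.5]\) as the visualization cell. The three-point dataset is hand-picked for illustration and is not sampled from `generate_data`.

```python
# Parameters from the visualization cell: W = [1, -1], b = [0.5, 0.5].
W = (1.0, -1.0)
b = (0.5, 0.5)


def f(x):
    """Down-projection: h = W x (a single number)."""
    return W[0] * x[0] + W[1] * x[1]


def g(h):
    """Up-projection: y = W^T h + b (back to two numbers)."""
    return (W[0] * h + b[0], W[1] * h + b[1])


def squared_error(x):
    """Squared reconstruction error ||x - g(f(x))||^2 for one point."""
    y = g(f(x))
    return (x[0] - y[0]) ** 2 + (x[1] - y[1]) ** 2


# A tiny hand-picked dataset (illustrative only).
dataset = [(0.5, 0.5), (1.0, 0.0), (0.0, 1.0)]
errors = [squared_error(x) for x in dataset]
mse = sum(errors) / len(errors)  # averaged over points
print(errors, mse)
```

The center of the square is reconstructed exactly, while the two corners are pushed outward along the line. This particular \(W\) is not unit length, so it overshoots.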

Learning a linear model#

We’ll use PyTorch to learn the parameters \(W\) and \(b\) via gradient descent. The objective is to minimize the importance-weighted squared error between the input \(x\) and the reconstruction \(y\) (the code averages instead of summing, which only rescales the loss):

\[ \text{Loss} = \sum_{x \in D}\sum_{i=1}^{2} \text{importance}_i \cdot (x_i - y_i)^2 \]

The importances allow us to weight the reconstruction error differently for each feature. When both importances are equal, we care equally about reconstructing both dimensions. When one importance is higher, we’re willing to accept larger errors on the less important dimension if it means better reconstruction of the more important one.
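To see the effect of the weights on a single point, here is a minimal sketch in plain Python. The point, the two reconstructions, and the helper name `weighted_squared_error` are made up for illustration; the importance values match the ones used in the later experiments.

```python
def weighted_squared_error(x, y, importances):
    """Sum over features of importance_i * (x_i - y_i)^2 for one point."""
    return sum(w * (xi - yi) ** 2 for xi, yi, w in zip(x, y, importances))


x = (0.8, 0.2)
y_miss_first = (0.6, 0.2)   # error of 0.2 on feature 1 only
y_miss_second = (0.8, 0.4)  # error of 0.2 on feature 2 only

equal = (1.0, 1.0)
unequal = (1.0, 0.1)  # same values as the unequal-importance experiments

# With equal importances both mistakes cost the same.
print(weighted_squared_error(x, y_miss_first, equal),
      weighted_squared_error(x, y_miss_second, equal))

# With unequal importances the mistake on feature 2 is ten times cheaper.
print(weighted_squared_error(x, y_miss_first, unequal),
      weighted_squared_error(x, y_miss_second, unequal))
```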

The model#

The LinearMapping class implements the compression and reconstruction we described earlier:

  • compute_hidden(x): Computes \(h = Wx\), projecting the input onto the one-dimensional hidden space.

  • compute_output(h): Computes \(y = W^T h + b\), reconstructing from the hidden dimension back to two dimensions.

  • forward(x): Combines both operations to map from input \(x\) to reconstruction \(y\).

The model is entirely linear—just matrix multiplication and addition, with no nonlinearities anywhere.

Side note: why is one of the two mappings affine but not both?

You might be wondering why \(f\) is linear and \(g\) is affine, that is, why only \(g\) has a bias term. This is because a composition of two affine functions can be rewritten as a composition of one linear and one affine function, with both translations being factored out into a single final translation. Compare both cases below:

Case 1: one linear function and one affine function:

  • \(h = f(x) = Wx\)

  • \(y = g(h) = W^{T}h + b\)

  • Composition: \(y = g(f(x)) = W^{T}Wx + b\)

Case 2: two affine functions:

  • \(h = f(x) = Wx + b_h\)

  • \(y = g(h) = W^{T}h + b_y\)

  • Composition: \(y = g(f(x)) = W^{T}(Wx + b_h) + b_y = W^{T}Wx + W^{T}b_h + b_y\)

  • We can rewrite \(W^{T}b_h + b_y\), which is a constant vector in \(\mathbb{R}^2\), as the combined bias \(b_c\) for convenience.

  • This leaves us with \(y = W^{T}Wx + b_c\).

Both cases have exactly the same form—any bias in the first mapping can be absorbed as a constant term in the bias of the second mapping.
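The equivalence is easy to confirm numerically. The sketch below uses plain Python with arbitrary, hypothetical values for \(W\), \(b_h\), and \(b_y\), and compares the two cases on a few inputs.

```python
W = (0.6, -0.3)   # shared weights, a 1x2 matrix written as a pair (illustrative)
b_h = 0.25        # hypothetical bias on the hidden value
b_y = (0.1, 0.4)  # hypothetical bias on the output


def two_affine(x):
    """Case 2: h = W x + b_h, then y = W^T h + b_y."""
    h = W[0] * x[0] + W[1] * x[1] + b_h
    return (W[0] * h + b_y[0], W[1] * h + b_y[1])


# Combined bias b_c = W^T b_h + b_y, folded into the output mapping.
b_c = (W[0] * b_h + b_y[0], W[1] * b_h + b_y[1])


def linear_then_affine(x):
    """Case 1: h = W x, then y = W^T h + b_c."""
    h = W[0] * x[0] + W[1] * x[1]
    return (W[0] * h + b_c[0], W[1] * h + b_c[1])


# Both columns agree up to floating-point rounding.
for x in [(0.0, 0.0), (1.0, 0.0), (0.3, 0.9)]:
    print(two_affine(x), linear_then_affine(x))
```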

from torch import nn
from tqdm.notebook import tqdm

def get_device():
    if t.backends.mps.is_available():
        return t.device("mps")
    elif t.cuda.is_available():
        return t.device("cuda")
    else:
        return t.device("cpu")

DEVICE = get_device()

DEF_INPUT_DIM = 2
DEF_HIDDEN_DIM = 1

class LinearMapping(nn.Module):
    def __init__(self, input_dim: int = DEF_INPUT_DIM, hidden_dim: int = DEF_HIDDEN_DIM):
        super().__init__()
        W = t.empty(hidden_dim, input_dim)
        b = t.empty(input_dim)
        nn.init.xavier_normal_(W)
        nn.init.zeros_(b)
        self.W = nn.Parameter(W)
        self.b = nn.Parameter(b)

    def compute_hidden(self, x):
        return einops.einsum(x, self.W, "b f, h f -> b h")

    def compute_output(self, h):
        return einops.einsum(h, self.W, "b h, h f -> b f") + self.b

    def forward(self, x):
        h = self.compute_hidden(x)
        y = self.compute_output(h)
        return y

The training loop#

The train function runs full-batch gradient descent with the AdamW optimizer. It splits the data into training and evaluation sets, then iteratively updates \(W\) and \(b\) to minimize the weighted mean squared error.

Throughout training, we track several metrics: the training loss, the evaluation loss (computed every 50 steps), the parameters \(W\) and \(b\), and samples of the input and reconstructed points. The visualization in the next cells will show an animation of how these evolve during training, with the left panel showing the transformation in 2D space and the right panel showing the loss curves.

from torch import nn
from tqdm.notebook import tqdm

DEF_IMPORTANCES = t.ones(DEF_INPUT_DIM)
DEF_LR = 1e-4
DEF_NUM_EPOCHS = 1000

def compute_loss(x, y, importances=DEF_IMPORTANCES):
    return ((x - y) ** 2 * importances).mean()

def train(model, x, train_eval_split=0.2, importances=DEF_IMPORTANCES, lr=DEF_LR, num_epochs=DEF_NUM_EPOCHS, sample_fraction=0.3):
    indices = t.randperm(len(x))
    x_shuffled = x[indices]
    
    n_train = int(len(x) * (1 - train_eval_split))
    x_train = x_shuffled[:n_train]
    x_eval = x_shuffled[n_train:]
    
    n_sample = int(n_train * sample_fraction)
    sample_indices = t.randperm(n_train)[:n_sample]
    x_sample = x_train[sample_indices]
    
    optimizer = t.optim.AdamW(model.parameters(), lr=lr)
    
    metrics = {
        'train_loss': [],
        'eval_loss': [],
        'W': [],
        'b': [],
        'steps': [],
        'x_sample': [],
        'y_sample': []
    }
    
    for epoch in tqdm(range(num_epochs)):
        optimizer.zero_grad()
        y_pred = model(x_train)
        loss = compute_loss(x_train, y_pred, importances)
        loss.backward()
        optimizer.step()
        
        metrics['train_loss'].append(loss.item())
        metrics['W'].append(model.W.detach().clone())
        metrics['b'].append(model.b.detach().clone())
        metrics['steps'].append(epoch)
        
        with t.no_grad():
            y_sample_pred = model(x_sample)
            metrics['x_sample'].append(x_sample.detach().clone())
            metrics['y_sample'].append(y_sample_pred.detach().clone())
        
        if epoch % 50 == 0:
            with t.no_grad():
                y_eval_pred = model(x_eval)
                eval_loss = compute_loss(x_eval, y_eval_pred, importances)
                metrics['eval_loss'].append(eval_loss.item())
        else:
            metrics['eval_loss'].append(None)
    
    return metrics
def plot_training_dashboard(x: t.Tensor, metrics: dict, margin: float = 0.2, frame_step: int = 100):
    import numpy as np
    
    x_np = x.cpu().numpy() if x.is_cuda or x.device.type == 'mps' else x.numpy()
    
    colors = px.colors.qualitative.Plotly
    input_color = colors[0]
    output_color = colors[1]
    connection_color = colors[6]
    background_color = colors[7]
    
    eval_steps = [i for i, loss in enumerate(metrics['eval_loss']) if loss is not None]
    eval_losses = [loss for loss in metrics['eval_loss'] if loss is not None]
    
    arrow_scale = 0.3
    
    all_mins_x = []
    all_maxs_x = []
    all_mins_y = []
    all_maxs_y = []
    
    for step in range(0, len(metrics['steps']), frame_step):
        W_val = metrics['W'][step].cpu().numpy() if metrics['W'][step].is_cuda or metrics['W'][step].device.type == 'mps' else metrics['W'][step].numpy()
        b_val = metrics['b'][step].cpu().numpy() if metrics['b'][step].is_cuda or metrics['b'][step].device.type == 'mps' else metrics['b'][step].numpy()
        y_sample_step = metrics['y_sample'][step]
        y_sample_np = y_sample_step.cpu().numpy() if y_sample_step.is_cuda or y_sample_step.device.type == 'mps' else y_sample_step.numpy()
        
        w_vec = W_val.squeeze()
        w_normalized = w_vec / (np.linalg.norm(w_vec) + 1e-8)
        
        x_sample_step_calc = metrics['x_sample'][step]
        x_sample_np_calc = x_sample_step_calc.cpu().numpy() if x_sample_step_calc.is_cuda or x_sample_step_calc.device.type == 'mps' else x_sample_step_calc.numpy()
        
        all_points_x = np.concatenate([
            x_sample_np_calc[:, 0],
            y_sample_np[:, 0],
            [b_val[0]],
            [b_val[0] + w_normalized[0] * arrow_scale]
        ])
        all_points_y = np.concatenate([
            x_sample_np_calc[:, 1],
            y_sample_np[:, 1],
            [b_val[1]],
            [b_val[1] + w_normalized[1] * arrow_scale]
        ])
        
        all_mins_x.append(all_points_x.min())
        all_maxs_x.append(all_points_x.max())
        all_mins_y.append(all_points_y.min())
        all_maxs_y.append(all_points_y.max())
    
    global_min_x = min(all_mins_x)
    global_max_x = max(all_maxs_x)
    global_min_y = min(all_mins_y)
    global_max_y = max(all_maxs_y)
    
    global_min = min(global_min_x, global_min_y)
    global_max = max(global_max_x, global_max_y)
    
    global_range = global_max - global_min
    axis_lim = [global_min - margin * global_range, global_max + margin * global_range]
    
    axis_lim = [max(axis_lim[0], -1.5), min(axis_lim[1], 1.5)]
    
    frames = []
    
    for step in range(0, len(metrics['steps']), frame_step):
        W_val = metrics['W'][step].cpu().numpy() if metrics['W'][step].is_cuda or metrics['W'][step].device.type == 'mps' else metrics['W'][step].numpy()
        b_val = metrics['b'][step].cpu().numpy() if metrics['b'][step].is_cuda or metrics['b'][step].device.type == 'mps' else metrics['b'][step].numpy()
        
        x_sample_step = metrics['x_sample'][step]
        y_sample_step = metrics['y_sample'][step]
        x_sample_np = x_sample_step.cpu().numpy() if x_sample_step.is_cuda or x_sample_step.device.type == 'mps' else x_sample_step.numpy()
        y_sample_np = y_sample_step.cpu().numpy() if y_sample_step.is_cuda or y_sample_step.device.type == 'mps' else y_sample_step.numpy()
        
        w_vec = W_val.squeeze()
        w_str = f"[{w_vec[0]:.2f}, {w_vec[1]:.2f}]"
        b_str = f"[{b_val[0]:.2f}, {b_val[1]:.2f}]"
        w_normalized = w_vec / (t.norm(metrics['W'][step]).item() + 1e-8)
        
        line_extend = 2.0
        line_start = b_val - w_normalized * line_extend
        line_end = b_val + w_normalized * line_extend
        
        train_loss = metrics['train_loss'][step]
        
        connection_lines_x = []
        connection_lines_y = []
        for i in range(len(x_sample_np)):
            connection_lines_x.extend([x_sample_np[i, 0], y_sample_np[i, 0], None])
            connection_lines_y.extend([x_sample_np[i, 1], y_sample_np[i, 1], None])
        
        frame_data = [
            go.Scatter(
                x=connection_lines_x,
                y=connection_lines_y,
                mode='lines',
                line=dict(color=connection_color, width=1, dash='dot'),
                name='Transformations',
                showlegend=False,
                xaxis='x1',
                yaxis='y1'
            ),
            go.Scatter(
                x=x_sample_np[:, 0],
                y=x_sample_np[:, 1],
                mode='markers',
                marker=dict(opacity=0.8, size=6, color=input_color),
                name='Input Sample',
                showlegend=False,
                xaxis='x1',
                yaxis='y1'
            ),
            go.Scatter(
                x=y_sample_np[:, 0],
                y=y_sample_np[:, 1],
                mode='markers',
                marker=dict(opacity=0.8, size=6, color=output_color),
                name='Output Points',
                showlegend=False,
                xaxis='x1',
                yaxis='y1'
            ),
            go.Scatter(
                x=[line_start[0], line_end[0]],
                y=[line_start[1], line_end[1]],
                mode='lines',
                line=dict(color='gray', width=1, dash='dot'),
                name='W direction',
                showlegend=False,
                xaxis='x1',
                yaxis='y1'
            ),
            go.Scatter(
                x=metrics['steps'][:step+1],
                y=metrics['train_loss'][:step+1],
                mode='lines',
                name='Train Loss',
                line=dict(color=input_color, width=2),
                showlegend=True,
                legendgroup='loss',
                xaxis='x2',
                yaxis='y2'
            ),
            go.Scatter(
                x=[s for s in eval_steps if s <= step],
                y=[eval_losses[i] for i, s in enumerate(eval_steps) if s <= step],
                mode='lines+markers',
                name='Eval Loss',
                line=dict(color=output_color, width=2),
                marker=dict(size=6, color=output_color),
                showlegend=True,
                legendgroup='loss',
                xaxis='x2',
                yaxis='y2'
            ),
            go.Scatter(
                x=[step],
                y=[train_loss],
                mode='markers',
                marker=dict(size=12, color=input_color, symbol='circle'),
                name='Current Step',
                showlegend=False,
                xaxis='x2',
                yaxis='y2'
            )
        ]
        
        frame_annotations = [
            dict(
                x=b_val[0],
                y=b_val[1],
                ax=0,
                ay=0,
                xref='x1',
                yref='y1',
                axref='x1',
                ayref='y1',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=3,
                arrowcolor='blue'
            ),
            dict(
                x=b_val[0] + w_normalized[0] * arrow_scale,
                y=b_val[1] + w_normalized[1] * arrow_scale,
                ax=b_val[0],
                ay=b_val[1],
                xref='x1',
                yref='y1',
                axref='x1',
                ayref='y1',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=3,
                arrowcolor='black'
            )
        ]
        
        frames.append(go.Frame(
            data=frame_data,
            name=str(step),
            layout=go.Layout(
                title_text=f"Training<br><sub>Step {step}, MSE = {train_loss:.4f}, W = {w_str}, b = {b_str}</sub>",
                annotations=frame_annotations
            )
        ))
    
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{"type": "scatter"}, {"type": "scatter"}]],
        horizontal_spacing=0.15,
        column_widths=[0.5, 0.5]
    )
    
    x_sample_0 = metrics['x_sample'][0]
    y_sample_0 = metrics['y_sample'][0]
    x_sample_0_np = x_sample_0.cpu().numpy() if x_sample_0.is_cuda or x_sample_0.device.type == 'mps' else x_sample_0.numpy()
    y_sample_0_np = y_sample_0.cpu().numpy() if y_sample_0.is_cuda or y_sample_0.device.type == 'mps' else y_sample_0.numpy()
    
    connection_lines_x_0 = []
    connection_lines_y_0 = []
    for i in range(len(x_sample_0_np)):
        connection_lines_x_0.extend([x_sample_0_np[i, 0], y_sample_0_np[i, 0], None])
        connection_lines_y_0.extend([x_sample_0_np[i, 1], y_sample_0_np[i, 1], None])
    
    fig.add_trace(
        go.Scatter(
            x=connection_lines_x_0,
            y=connection_lines_y_0,
            mode='lines',
            line=dict(color=connection_color, width=1, dash='dot'),
            name='Transformations',
            showlegend=False
        ),
        row=1, col=1
    )
    
    fig.add_trace(
        go.Scatter(
            x=x_sample_0_np[:, 0],
            y=x_sample_0_np[:, 1],
            mode='markers',
            marker=dict(opacity=0.8, size=6, color=input_color),
            name='Input Sample',
            showlegend=False
        ),
        row=1, col=1
    )
    
    fig.add_trace(
        go.Scatter(
            x=y_sample_0_np[:, 0],
            y=y_sample_0_np[:, 1],
            mode='markers',
            marker=dict(opacity=0.8, size=6, color=output_color),
            name='Output Points',
            showlegend=False
        ),
        row=1, col=1
    )
    
    fig.add_trace(
        go.Scatter(
            x=[0],
            y=[metrics['train_loss'][0]],
            mode='lines',
            name='Train Loss',
            line=dict(color=input_color, width=2),
            showlegend=True,
            legendgroup='loss'
        ),
        row=1, col=2
    )
    
    fig.add_trace(
        go.Scatter(
            x=[],
            y=[],
            mode='lines+markers',
            name='Eval Loss',
            line=dict(color=output_color, width=2),
            marker=dict(size=6, color=output_color),
            showlegend=True,
            legendgroup='loss'
        ),
        row=1, col=2
    )
    
    fig.add_trace(
        go.Scatter(
            x=[0],
            y=[metrics['train_loss'][0]],
            mode='markers',
            marker=dict(size=12, color=input_color, symbol='circle'),
            name='Current Step',
            showlegend=False
        ),
        row=1, col=2
    )
    
    W_0 = metrics['W'][0].cpu().numpy() if metrics['W'][0].is_cuda or metrics['W'][0].device.type == 'mps' else metrics['W'][0].numpy()
    b_0 = metrics['b'][0].cpu().numpy() if metrics['b'][0].is_cuda or metrics['b'][0].device.type == 'mps' else metrics['b'][0].numpy()
    w_0_vec = W_0.squeeze()
    w_0_str = f"[{w_0_vec[0]:.2f}, {w_0_vec[1]:.2f}]"
    b_0_str = f"[{b_0[0]:.2f}, {b_0[1]:.2f}]"
    
    fig.update_xaxes(title_text="x₁", range=axis_lim, row=1, col=1)
    fig.update_yaxes(title_text="x₂", range=axis_lim, scaleanchor="x", scaleratio=1, row=1, col=1)
    fig.update_xaxes(title_text="Step", row=1, col=2)
    fig.update_yaxes(title_text="Loss", row=1, col=2)
    
    fig.update_layout(
        title=f"Training<br><sub>Step 0, MSE = {metrics['train_loss'][0]:.4f}, W = {w_0_str}, b = {b_0_str}</sub>",
        title_y=0.92,
        width=900,
        height=600,
        hovermode='closest',
        legend=dict(x=0.55, y=0.95),
        updatemenus=[dict(
            type="buttons",
            buttons=[
                dict(label="Play", method="animate", args=[None, {"frame": {"duration": 100, "redraw": True}, "fromcurrent": True}]),
                dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}])
            ],
            x=0.1,
            y=-0.1
        )],
        sliders=[dict(
            active=0,
            steps=[dict(args=[[f.name], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
                       label=f.name,
                       method="animate") for f in frames],
            x=0.1,
            len=0.8,
            y=-0.1
        )]
    )
    
    fig.frames = frames
    fig.show()

Experiment 1: Equal importances, no sparsity#

In this first experiment, we train the linear model with equal importances on both features and no sparsity in the data.

The model converges to a solution where the line passes roughly through the center of the unit square. The angle of the line varies between training runs: because the two features are independent and have the same variance, every direction through the center gives essentially the same mean squared error. The key observation is that the reconstructed points all lie along a one-dimensional line in the two-dimensional output space.

The resulting MSE confirms what we expected from our earlier analysis: projecting 2D data onto a 1D line destroys information, and no amount of optimization can recover it with a purely linear model. The loss is substantial because the model is forced to make a compromise—it can’t simultaneously represent variation in both dimensions.
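For a rough idea of what loss to expect, here is a Monte Carlo sketch in plain Python. The sample size, seed, and angles are arbitrary choices, and `line_mse` is a hypothetical helper, not part of the notebook. It measures the reconstruction error for an idealized line through the center of the square at several angles and compares it with the value suggested by the variance of a uniform variable, \(1/12\).

```python
import math
import random

random.seed(0)
points = [(random.random(), random.random()) for _ in range(20_000)]


def line_mse(theta):
    """MSE when reconstructing from a projection onto a line at angle theta.

    The line passes through the center (0.5, 0.5) of the unit square, which
    corresponds to a unit-norm W and a bias b = (I - W^T W) * center.
    The average runs over points and both coordinates, like compute_loss.
    """
    w = (math.cos(theta), math.sin(theta))
    total = 0.0
    for x1, x2 in points:
        z1, z2 = x1 - 0.5, x2 - 0.5
        h = w[0] * z1 + w[1] * z2              # position along the line
        e1, e2 = z1 - w[0] * h, z2 - w[1] * h  # what the line cannot express
        total += e1 * e1 + e2 * e2
    return total / (2 * len(points))


for degrees in (0, 30, 45, 90, 135):
    print(degrees, round(line_mse(math.radians(degrees)), 4))

# A variance of 1/12 is lost in one direction, averaged over 2 coordinates.
print("theory:", 1 / 24)
```

Under these assumptions every angle lands near \(1/24 \approx 0.042\), which is a useful reference value when reading the training curve.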

t.manual_seed(44)
device = get_device()
importances = t.ones(DEF_INPUT_DIM).to(device)
x = generate_data().to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=1500, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)

Experiment 2: Unequal importances, no sparsity#

Now we train with unequal importances: one feature matters more for our loss function than the other.

The model makes a rational trade-off. It dedicates its single degree of freedom, the direction of \(W\), to representing the important feature as accurately as possible. The less important feature is handled by the bias term \(b\), which essentially sets it to its mean value across the dataset. Since the data is uniform on [0,1], the unimportant dimension is reconstructed as roughly 0.5.

This is still far from perfect. The model can represent one dimension well, but the other dimension shows no variation at all in the reconstruction. We’re still bound by the fundamental limitation: a one-dimensional representation can’t capture two-dimensional variation.
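The size of this trade-off can be estimated with a short calculation. The sketch below is an idealized comparison in plain Python, assuming the model reconstructs one feature exactly and replaces the other by its mean; the helper `loss_if_kept` is hypothetical, and the trained model only approaches these values.

```python
from fractions import Fraction

VAR_UNIFORM = Fraction(1, 12)                 # variance of a uniform variable on [0, 1]
importances = (Fraction(1), Fraction(1, 10))  # as in this experiment


def loss_if_kept(kept):
    """Weighted loss when feature `kept` (0 or 1) is reconstructed exactly
    and the other feature is replaced by its mean, 0.5.

    The expected squared error on the dropped feature equals its variance;
    the average runs over both coordinates, like compute_loss.
    """
    dropped = 1 - kept
    return importances[dropped] * VAR_UNIFORM / 2


print("keep x1:", float(loss_if_kept(0)))
print("keep x2:", float(loss_if_kept(1)))
print("ratio:", loss_if_kept(1) / loss_if_kept(0))
```

Keeping the important feature costs ten times less than keeping the unimportant one, which matches the direction the trained \(W\) settles on.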

t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 0.1]).to(device)
x = generate_data().to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=3300, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)

Introducing sparsity#

We’ll now introduce sparsity into our data. With a sparsity parameter \(s = 0.8\), each feature independently has an 80% chance of being exactly zero.

This changes the structure of the problem. When \(x_1 = 0\) and \(x_2 \neq 0\), the projection \(h = Wx\) only captures information from \(x_2\). There’s no interference from \(x_1\) in the down-projection. Similarly, when \(x_2 = 0\) and \(x_1 \neq 0\), only \(x_1\) contributes to \(h\).

The question is whether a linear model can exploit this structure to achieve better reconstruction. With 80% sparsity, most of the time at least one of the features is zero. In these cases, the projection is effectively one-dimensional, which seems like it should make the reconstruction problem easier. But can gradient descent discover a way to take advantage of this?
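To get a feel for how often each case occurs, the sketch below samples sparse pairs and counts the three situations. `sample_sparse_pair` is a hypothetical stand-in for the notebook's `generate_data`, written from the description in the data generation section rather than copied from its implementation.

```python
import random


def sample_sparse_pair(rng, sparsity):
    """Two independent features: each is 0 with probability `sparsity`,
    otherwise uniform on [0, 1]."""
    return tuple(0.0 if rng.random() < sparsity else rng.random()
                 for _ in range(2))


rng = random.Random(0)
n_samples = 100_000
labels = ("both zero", "exactly one active", "both active")
counts = {label: 0 for label in labels}
for _ in range(n_samples):
    x1, x2 = sample_sparse_pair(rng, sparsity=0.8)
    n_active = (x1 != 0.0) + (x2 != 0.0)  # 0, 1 or 2 active features
    counts[labels[n_active]] += 1

fractions = {label: count / n_samples for label, count in counts.items()}
print(fractions)
```

The expected fractions are \(0.8^2 = 0.64\), \(2 \times 0.8 \times 0.2 = 0.32\) and \(0.2^2 = 0.04\), so the two features can interfere with each other in only about 4% of samples.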

Experiment 3: Sparse data, unequal importances#

With sparse data and unequal importances, the model behaves similarly to the non-sparse case from Experiment 2. It uses the direction of \(W\) to represent the more important feature, and uses the bias \(b\) to approximate the less important one.

The main difference is in the value of the bias. In the non-sparse case, the bias was around 0.5, since that’s the mean of a uniform distribution on [0,1]. Here, with 80% sparsity, each feature is nonzero only 20% of the time, so its mean drops to \(0.2 \times 0.5 = 0.1\). The bias reflects this by taking a value much closer to 0.

The sparsity helps reduce the overall loss, but the linear model still can’t do anything fundamentally different from what it did without sparsity. It’s still making the same basic trade-off: represent one dimension well, approximate the other with its mean.

t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 0.1]).to(device)
x = generate_data(sparsity=0.8).to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=2500, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)
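Both effects described above, the smaller bias and the lower loss, follow from the moments of a sparse uniform feature. The sketch below computes them in closed form. `sparse_uniform_moments` is a hypothetical helper, and the variance is only a proxy for the loss on the unimportant feature, assuming that feature is reconstructed as its mean.

```python
def sparse_uniform_moments(sparsity):
    """Mean and variance of a feature that is 0 with probability
    `sparsity` and uniform on [0, 1] otherwise."""
    p_active = 1.0 - sparsity
    mean = p_active * 0.5            # E[x]   = p * E[U],   E[U]   = 1/2
    second_moment = p_active / 3.0   # E[x^2] = p * E[U^2], E[U^2] = 1/3
    return mean, second_moment - mean ** 2


dense_mean, dense_var = sparse_uniform_moments(0.0)
sparse_mean, sparse_var = sparse_uniform_moments(0.8)
print(f"dense:  mean={dense_mean:.3f}, variance={dense_var:.4f}")
print(f"sparse: mean={sparse_mean:.3f}, variance={sparse_var:.4f}")
```

At \(s = 0.8\) the mean falls from 0.5 to 0.1 and the variance from about 0.083 to about 0.057. The strategy is unchanged; there is simply less variance left to lose.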

Experiment 4: Sparse data, equal importances#

This configuration—sparse data with equal importances—turns out to be less stable than the others. The model sometimes gets stuck in local minima, and different training runs can produce quite different results.

The issue is that the model faces ambiguity. With equal importances, it has no clear signal about which dimension to prioritize. Should it commit fully to representing one dimension well, or should it try to hedge and partially represent both? The sparse structure of the data means that committing to one dimension could work well, but the symmetry in the loss function means there’s no gradient pushing it clearly in either direction.

As a result, you’ll often see solutions that don’t fully commit to either dimension, or training dynamics that oscillate between different strategies. This instability is itself informative: it shows that without additional structure (like unequal importances or, as we’ll see next, a nonlinearity), the model struggles to find a consistent solution even when the data has favorable properties like sparsity.

t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 1.0]).to(device)
x = generate_data(sparsity=0.8).to(device)
model = LinearMapping(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=10000, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)

Introducing nonlinearity#

We’ll now add a single nonlinearity to our model: a ReLU (Rectified Linear Unit) activation function. The ReLU is defined as \(\text{ReLU}(x) = \max(0, x)\)—it passes positive values through unchanged and clips negative values to zero.

This might seem like a small change, but it turns out to be crucial. The intuition is as follows: if we can arrange the parameters such that the two dimensions take opposite signs in the hidden representation \(h\), then the ReLU can act as a switch. When \(x_1 \neq 0\) and \(x_2 = 0\), we want \(h\) to reconstruct \(x_1\) correctly in its dimension while mapping to a negative value in the \(x_2\) dimension. The ReLU will then clip that negative value to zero, giving us the correct reconstruction. The same logic applies when \(x_2 \neq 0\) and \(x_1 = 0\).

This is sometimes called an “antipodal” configuration, because the two dimensions point in opposite directions in the hidden space. Whether gradient descent can discover this configuration depends on the structure of the data—specifically, on whether the sparsity is high enough that this strategy actually reduces the loss.

import einops
import torch as t
from torch import nn

class ReLUModel(nn.Module):
    def __init__(self, input_dim: int = DEF_INPUT_DIM, hidden_dim: int = DEF_HIDDEN_DIM):
        super().__init__()
        W = t.empty(hidden_dim, input_dim)
        b = t.empty(input_dim)
        nn.init.xavier_normal_(W)
        nn.init.zeros_(b)
        self.W = nn.Parameter(W)
        self.b = nn.Parameter(b)

    def compute_hidden(self, x):
        # Down-projection: h = W x
        return einops.einsum(x, self.W, "b f, h f -> b h")

    def compute_output(self, h):
        # Up-projection with the same (tied) weights: W^T h + b
        output = einops.einsum(h, self.W, "b h, h f -> b f") + self.b
        output = t.relu(output)
        return output

    def forward(self, x):
        h = self.compute_hidden(x)
        y = self.compute_output(h)
        return y

The ReLUModel class is nearly identical to our LinearMapping class. The only difference is in the compute_output method: after computing \(y = W^T h + b\), we apply the ReLU function before returning.

That single line—output = t.relu(output)—is the only change. Everything else about the model is exactly the same: same compression to one dimension, same reconstruction from one dimension back to two. But as we’ll see, this one additional operation changes what the model is capable of learning.

Experiment 5: Sparse data + ReLU#

This is the key result. With sparse data and a ReLU nonlinearity, the model converges to a configuration where \(W\) takes opposite signs for the two dimensions. For example, you might see \(W \approx [1, -1]\) or \(W \approx [0.7, -0.7]\) (the exact values depend on the random initialization).

Consider what happens in the forward pass. When \(x_1 \neq 0\) and \(x_2 = 0\):

  • The hidden value \(h = Wx\) has a component from \(x_1\) only

  • Reconstruction gives \(y = W^T h + b\). Writing \(W_1, W_2\) for the two entries of \(W\) and ignoring the bias, the \(x_1\) component is \(W_1^2 x_1 > 0\) and the \(x_2\) component is \(W_1 W_2 x_1 < 0\), because \(W_1\) and \(W_2\) have opposite signs

  • The ReLU clips the negative \(x_2\) component to zero

  • Result: \(x_1\) is reconstructed accurately, and \(x_2 = 0\) as desired

The same logic applies when \(x_2 \neq 0\) and \(x_1 = 0\), with the roles swapped: \(h\) now has the opposite sign, so the \(x_1\) output is the one that goes negative and gets clipped.
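The steps above can be traced by hand with a few concrete inputs. The sketch below uses plain Python and hand-picked antipodal parameters, \(W = [1, -1]\) and \(b = 0\). These values are an assumption for illustration, not the trained values, and `relu_reconstruct` is a hypothetical helper that mirrors the tied-weight forward pass.

```python
def relu_reconstruct(x, W, b):
    """Forward pass of a 2 -> 1 -> 2 model with tied weights:
    h = W . x, then y = ReLU(W^T h + b)."""
    h = W[0] * x[0] + W[1] * x[1]
    pre_activation = [W[0] * h + b[0], W[1] * h + b[1]]
    return h, [max(0.0, value) for value in pre_activation]


# Hand-picked antipodal parameters (assumed for illustration, not trained).
W = [1.0, -1.0]
b = [0.0, 0.0]

h_first, y_first = relu_reconstruct([0.5, 0.0], W, b)     # only x_1 active
h_second, y_second = relu_reconstruct([0.0, 0.25], W, b)  # only x_2 active
h_both, y_both = relu_reconstruct([0.5, 0.25], W, b)      # both active

print(h_first, y_first)    # 0.5 [0.5, 0.0]
print(h_second, y_second)  # -0.25 [0.0, 0.25]
print(h_both, y_both)      # 0.25 [0.25, 0.0]
```

With a single active feature the input is recovered exactly, and the sign of \(h\) encodes which feature was active. When both are active, only their difference survives, which is the compromise case.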

When both features are nonzero simultaneously, the reconstruction requires a compromise—the model can’t perfectly separate them. But with 80% sparsity, each feature is zero 80% of the time, which means that at least one of the two is zero in 96% of cases (calculated as \(1 - 0.2 \times 0.2 = 0.96\)). The case where both are nonzero is rare enough that the overall MSE is dramatically lower than what a linear model achieves.

We’ve successfully represented two dimensions in a single number using only sparsity, positivity constraints, and a simple nonlinearity.

t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 1.0]).to(device)
x = generate_data(sparsity=0.8).to(device)
model = ReLUModel(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=10000, lr=3e-4, importances=importances)
plot_training_dashboard(x, metrics)
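A rough sense of the gap between the two models can be had without training. The sketch below compares a hand-set antipodal ReLU reconstruction (\(W = [1, -1]\), \(b = 0\), assumed rather than learned) against the linear strategy from Experiments 2 and 3, which keeps \(x_1\) and replaces \(x_2\) by its mean. `sample_sparse_pair` and `antipodal_relu` are hypothetical helpers, and the errors are summed over both features, so they need not match the dashboard's MSE scale.

```python
import random


def sample_sparse_pair(rng, sparsity):
    """Each feature is 0 with probability `sparsity`, else uniform on [0, 1]."""
    x1 = 0.0 if rng.random() < sparsity else rng.random()
    x2 = 0.0 if rng.random() < sparsity else rng.random()
    return x1, x2


def antipodal_relu(x1, x2):
    """y = ReLU(W^T (W . x)) with hand-set W = [1, -1] and zero bias."""
    h = x1 - x2
    return max(0.0, h), max(0.0, -h)


rng = random.Random(0)
data = [sample_sparse_pair(rng, sparsity=0.8) for _ in range(100_000)]

# Linear baseline: reconstruct x_1 exactly, replace x_2 by its mean.
mean_2 = sum(x2 for _, x2 in data) / len(data)
linear_error = sum((x2 - mean_2) ** 2 for _, x2 in data) / len(data)

# Antipodal ReLU: errors arise only when both features are active.
relu_error = 0.0
for x1, x2 in data:
    y1, y2 = antipodal_relu(x1, x2)
    relu_error += (x1 - y1) ** 2 + (x2 - y2) ** 2
relu_error /= len(data)

print(f"linear baseline: {linear_error:.4f}")
print(f"antipodal ReLU:  {relu_error:.4f}")
```

Under these assumptions the linear baseline should land near 0.057 and the antipodal ReLU near 0.013, roughly a fourfold reduction. A trained model can also adjust the weight scale and bias, so its numbers will differ.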

Experiment 6: Dense data + ReLU#

Now we remove the sparsity and go back to uniform data on the full unit square. Even with the ReLU nonlinearity, the model can’t find the antipodal solution that worked so well in the sparse case.

The reason is straightforward: when both dimensions are frequently nonzero, the ReLU clipping trick doesn’t help us. If we tried to use an antipodal configuration like \(W \approx [1, -1]\), we’d end up clipping away meaningful signal. When both \(x_1\) and \(x_2\) are nonzero, the hidden value is roughly \(h = x_1 - x_2\), so the pre-activation output is positive in one dimension and negative in the other. Clipping the negative part to zero erases the smaller feature entirely, and the larger one is reconstructed only as the difference \(|x_1 - x_2|\). The clipping destroys information rather than helping us separate the dimensions.

The model falls back to a solution similar to what we saw with the purely linear model in Experiment 1. The ReLU alone isn’t sufficient—we need the sparse structure of the data for this approach to work.

t.manual_seed(44)
device = get_device()
importances = t.tensor([1.0, 1.0]).to(device)
x = generate_data(sparsity=0).to(device)
model = ReLUModel(input_dim=2, hidden_dim=1).to(device)
metrics = train(model, x, num_epochs=2500, lr=1e-3, importances=importances)
plot_training_dashboard(x, metrics)
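The dependence on sparsity can be made quantitative with two closed-form expressions. The sketch below is an illustration under stated assumptions, not a description of what training finds. It compares a hand-set antipodal ReLU model (\(W = [1, -1]\), \(b = 0\)) with a linear code that keeps one feature and replaces the other by its mean. Both function names are hypothetical, and errors are summed over the two features.

```python
def linear_error(sparsity):
    """Variance of one feature: the error left when a 1D linear code keeps
    one feature and replaces the other by its mean."""
    p_active = 1.0 - sparsity
    return p_active / 3.0 - (p_active / 2.0) ** 2


def antipodal_error(sparsity):
    """Expected error of y = ReLU(W^T W x) with hand-set W = [1, -1], b = 0.
    Errors occur only when both features are active, where the squared
    error is 2 * min(x_1, x_2)^2 and E[min^2] = 1/6 for two uniforms."""
    p_active = 1.0 - sparsity
    return p_active * p_active * 2.0 / 6.0


for s in (0.0, 0.8):
    print(f"s={s}: linear={linear_error(s):.4f}, "
          f"antipodal={antipodal_error(s):.4f}")

# Setting the two expressions equal, p^2 / 3 = p / 3 - p^2 / 4 with p = 1 - s,
# gives p = 4/7, i.e. s = 3/7.
crossover = 3.0 / 7.0
```

On dense data the antipodal configuration costs about 0.33 against about 0.083 for the linear code, so there is nothing to gain from it. It only pays off, in this simplified comparison, once the sparsity exceeds roughly 0.43.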

Conclusion#

We started with an observation that seemed to violate basic principles of linear algebra: you can compress two independent dimensions into one number and still reconstruct both dimensions separately. We’ve now seen that this is possible when three ingredients are present: sparsity, positivity, and nonlinearity.

The key insight is that sparse data has structure that can be exploited. When features are zero most of the time, we can use different regions of the hidden space to represent different features. The ReLU nonlinearity acts as a switch, allowing us to map some features to negative values (which get clipped to zero) while preserving others. This only works because we know the data is non-negative—if features could be negative, we couldn’t use the ReLU clipping to separate them.

This has important implications for understanding neural networks. ReLU activations are ubiquitous in deep learning, and learned representations are often sparse. This means that neural networks can represent more features than they have neurons: multiple features can share the same neuron as long as they are rarely active at the same time. This is the phenomenon of superposition, and it helps explain how neural networks pack so many features into a limited number of neurons.

Anthropic’s Toy Models of Superposition paper explores this phenomenon in much greater depth, examining cases with many more features and dimensions, and investigating the geometric structure of how features are represented. This notebook illustrates the core principle with the simplest possible case: two features compressed to one dimension. The same logic extends to more complex scenarios, but the fundamental insight remains the same: simple building blocks (linear maps and ReLUs) can implement sophisticated compression schemes when the data has the right structure.