Frequency-Domain Analysis and Fourier Transforms#

This notebook contains a few basic exercises to illustrate how Fourier-related transforms can help us analyze patterns in the frequency domain.

Almost all of the helper functions have been taken from Neel Nanda’s contribution to the ARENA materials with his study on grokking modular arithmetic.

Setup#

Hide code cell content
# If necessary, install requirements from repository root
# !pip install -r ../requirements.txt
import torch as t
from torch import Tensor
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import os
import sys
import plotly.express as px
import plotly.graph_objects as go
from functools import *
from typing import List, Tuple, Union, Optional, Callable
from fancy_einsum import einsum
import einops
from jaxtyping import Float, Int
from tqdm import tqdm
from transformer_lens import utils
import pandas as pd

def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    '''
    Plot a single line (or the columns of a 2-D array) with plotly express.

    `x` and `y` may be torch tensors; these are flattened and converted to
    numpy. Anything else is passed to `px.line` unchanged. `hover` is
    forwarded as `hover_name`, `xaxis`/`yaxis` set the axis titles, and the
    remaining keyword arguments (e.g. `title`) go to `px.line`.
    The legend is hidden when the data is one-dimensional.
    '''
    if isinstance(y, t.Tensor):
        y = utils.to_numpy(y.flatten())
    if isinstance(x, t.Tensor):
        x = utils.to_numpy(x.flatten())
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    # np.ndim also works for lists and other array-likes with no `.ndim`
    if np.ndim(x) == 1:
        fig.update_layout(showlegend=False)
    fig.show()


def scatter(x, y, title="", xaxis="", yaxis="", colorbar_title="", **kwargs):
    '''
    Scatter plot of `y` against `x` (both flattened and converted to numpy).

    `xaxis`/`yaxis` set the axis titles and `colorbar_title` labels the
    colour bar. The optional keyword arguments `xaxis_range`/`yaxis_range`
    fix the axis ranges. They are consumed here and not forwarded, because
    `px.scatter` does not accept them. All other keyword arguments go to
    `px.scatter`.
    '''
    # Pop the range options before forwarding kwargs: px.scatter raises a
    # TypeError on unknown keyword arguments.
    xaxis_range = kwargs.pop("xaxis_range", None)
    yaxis_range = kwargs.pop("yaxis_range", None)
    fig = px.scatter(
        x=utils.to_numpy(x.flatten()),
        y=utils.to_numpy(y.flatten()),
        title=title,
        labels={"color": colorbar_title},
        **kwargs,
    )
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    if xaxis_range is not None:
        fig.update_xaxes(range=xaxis_range)
    if yaxis_range is not None:
        fig.update_yaxes(range=yaxis_range)
    fig.show()


def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    '''
    Plot several sequences as separate traces on a single figure.

    `lines_list` is either a list of 1-D sequences or a tensor whose first
    dimension indexes the traces. `x` holds the shared x coordinates; it
    defaults to `0 .. len(first sequence) - 1`, and tensors are converted to
    numpy. `labels` optionally names each trace (default: its index).
    `mode` is the plotly trace mode (e.g. 'lines', 'markers'), `hover` is
    passed as each trace's `hovertext`, and `log_y` switches the y axis to a
    log scale. Extra keyword arguments are forwarded to every `go.Scatter`.
    '''
    if isinstance(lines_list, t.Tensor):
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x = np.arange(len(lines_list[0]))
    elif isinstance(x, t.Tensor):
        x = utils.to_numpy(x)
    fig = go.Figure(layout={'title': title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    # `trace` rather than `line` avoids shadowing the module-level `line` helper
    for c, trace in enumerate(lines_list):
        if isinstance(trace, t.Tensor):
            trace = utils.to_numpy(trace)
        label = labels[c] if labels is not None else c
        fig.add_trace(go.Scatter(x=x, y=trace, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def line_marker(x, **kwargs):
    '''Plot the single sequence `x` with both connecting lines and point markers.

    All keyword arguments are forwarded to `lines`.
    '''
    single_trace = [x]
    lines(single_trace, mode='lines+markers', **kwargs)


def animate_lines(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, xaxis='x', yaxis='y', title='', **kwargs):
    '''
    Animate a sequence of lines: frame `i` shows `lines_list[i]`.

    `lines_list` is a 2-D tensor/array of shape (frames, points), or a list
    of 1-D tensors that gets stacked along a new first dimension.
    `snapshot_index` labels the frames (default 0, 1, ...). `hover` holds
    one hover label per point and is repeated for every frame. The y range
    is fixed to the global min/max so that frames are comparable.
    '''
    if type(lines_list) == list:
        lines_list = t.stack(lines_list, axis=0)
    data = utils.to_numpy(lines_list)
    n_frames, n_points = data.shape[0], data.shape[1]

    if snapshot_index is None:
        snapshot_index = np.arange(n_frames)
    if hover is not None:
        # One copy of the per-point labels for each frame, matching row order
        hover = [label for _ in range(len(snapshot_index)) for label in hover]

    # Long-format table: one row per (frame, point)
    rows = [
        [data[f][p], snapshot_index[f], p]
        for f in range(n_frames)
        for p in range(n_points)
    ]
    df = pd.DataFrame(rows, columns=[yaxis, snapshot, xaxis])
    fig = px.line(
        df,
        x=xaxis,
        y=yaxis,
        title=title,
        animation_frame=snapshot,
        range_y=[data.min(), data.max()],
        hover_name=hover,
        **kwargs,
    )
    fig.show()

def imshow(tensor: t.Tensor, xaxis=None, yaxis=None, animation_name='Snapshot', vline_positions=(), vline_labels=(), hline_positions=(), hline_labels=(), animation_labels=(), **kwargs):
    '''
    Show a 2-D tensor (or, with `animation_frame=<dim>` in kwargs, an
    animated stack of them) as a heatmap.

    Args:
        tensor: data to plot. Singleton dimensions are squeezed away.
        xaxis, yaxis: axis labels.
        animation_name: label for the animation dimension.
        vline_positions, vline_labels: vertical lines drawn on the boundary
            just before each given column, with matching annotation texts.
        hline_positions, hline_labels: same for horizontal lines and rows.
        animation_labels: replacement labels for the animation slider steps.
        **kwargs: forwarded to `px.imshow`.

    Defaults are immutable tuples rather than lists, so they cannot be
    shared and mutated across calls.
    '''
    tensor = t.squeeze(tensor)
    fig = px.imshow(utils.to_numpy(tensor), labels={'x': xaxis, 'y': yaxis, 'animation_frame': animation_name}, **kwargs)
    for i, label in enumerate(animation_labels):
        fig.layout.sliders[0].steps[i]["label"] = label
    # Offset by 0.5 so each line sits on the cell boundary, not the cell centre
    for x, text in zip(vline_positions, vline_labels):
        fig.add_vline(x=x-0.5, line_width=1, annotation_text=text, annotation_position="top left")
    for y, text in zip(hline_positions, hline_labels):
        fig.add_hline(y=y-0.5, line_width=1, annotation_text=text, annotation_position="top left")
    # The two non-animation dimensions give the image height and width
    n_rows, n_cols = [s for i, s in enumerate(tensor.shape) if i != kwargs.get("animation_frame", None)]
    # Pin the axes to the image extent. The y range is reversed so that row 0
    # is drawn at the top.
    fig.update_yaxes(range=[n_rows-0.5, 0-0.5], autorange=False)
    fig.update_xaxes(range=[0-0.5, n_cols-0.5], autorange=False)
    fig.show()
# Rebind `imshow` so every call defaults to the sequential 'Blues' colour
# scale. Callers can still override `color_continuous_scale` at call time.
imshow = partial(imshow, color_continuous_scale='Blues')
# Variant for divergent data (both positive and negative values): the 'RdBu'
# scale centred on 0.0, so that 0 sits in the middle (white) of the scale.
# NOTE: this wraps the already-rebound `imshow` above, and its keywords
# override the 'Blues' default. Keep these two statements in this order.
imshow_div = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)

def imshow_fourier(tensor, title='', animation_name='snapshot', facet_labels=(), animation_labels=(), xlim=None, ylim=None, basis_names: Optional[List[str]] = None, **kwargs):
    '''
    Plot a tensor that is already expressed in the 2D Fourier basis.

    Rows are labelled as vertical components and columns as horizontal
    components. A divergent colour scale centred on 0 is used.

    Args:
        tensor: 2D Fourier coefficients. Singleton dims are squeezed away.
        title: figure title.
        animation_name: label for an `animation_frame` dimension.
        facet_labels: replacement texts for the facet annotations, in order.
        animation_labels: replacement labels for the animation slider steps.
        xlim, ylim: optional `[min, max]` axis ranges.
        basis_names: tick labels for both axes. Defaults to the module-level
            `fourier_basis_names`, which must then match the tensor's size.
        **kwargs: forwarded to `px.imshow`.
    '''
    if basis_names is None:
        basis_names = fourier_basis_names
    tensor = t.squeeze(tensor)
    fig = px.imshow(
        utils.to_numpy(tensor),
        x=basis_names,
        y=basis_names,
        labels={'x': 'Horizontal Component',
                'y': 'Vertical Component',
                'animation_frame': animation_name},
        title=title,
        color_continuous_midpoint=0.,
        color_continuous_scale='RdBu',
        **kwargs,
    )
    fig.update(data=[{'hovertemplate': "%{x}x * %{y}y<br>Value:%{z:.4f}"}])
    for i, label in enumerate(facet_labels):
        fig.layout.annotations[i]['text'] = label
    for i, label in enumerate(animation_labels):
        fig.layout.sliders[0].steps[i]["label"] = label
    # Treat both axes the same way: an explicit range disables autoranging
    if ylim is not None:
        fig.update_yaxes(range=ylim, autorange=False)
    if xlim is not None:
        fig.update_xaxes(range=xlim, autorange=False)
    fig.show()


def animate_multi_lines(lines_list, y_index=None, snapshot_index = None, snapshot='snapshot', hover=None, swap_y_animate=False, **kwargs):
    '''
    Animated line plot with several lines drawn in each frame.

    After stacking (if a list is given) and converting to numpy, the 3-D
    input is permuted with `transpose(2, 0, 1)`:
      - the input's last dimension becomes the animation frame,
      - its first dimension becomes the set of lines,
      - its middle dimension is the x axis.
    With `swap_y_animate=True` the frame and line dimensions are exchanged.
    `y_index` names the lines (default "0", "1", ...), `snapshot_index`
    labels the frames, and `hover` gives one label per x position.
    '''
    if type(lines_list) == list:
        lines_list = t.stack(lines_list, axis=0)
    frames = utils.to_numpy(lines_list).transpose(2, 0, 1)
    if swap_y_animate:
        frames = frames.transpose(1, 0, 2)
    n_frames, n_lines, n_x = frames.shape

    if snapshot_index is None:
        snapshot_index = np.arange(n_frames)
    if y_index is None:
        y_index = [str(line_idx) for line_idx in range(n_lines)]
    if hover is not None:
        # Repeat the per-x labels once per frame, matching the row order below
        hover = [label for _ in range(len(snapshot_index)) for label in hover]

    # Wide-format table: one row per (frame, x) with one column per line
    rows = [
        list(frames[f, :, pos]) + [snapshot_index[f], pos]
        for f in range(n_frames)
        for pos in range(n_x)
    ]
    df = pd.DataFrame(rows, columns=y_index + [snapshot, 'x'])
    fig = px.line(
        df,
        x='x',
        y=y_index,
        animation_frame=snapshot,
        range_y=[frames.min(), frames.max()],
        hover_name=hover,
        **kwargs,
    )
    fig.show()


def animate_scatter(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, yaxis='y', xaxis='x', color=None, color_name = 'color', **kwargs):
    '''
    Animated scatter plot with one frame per snapshot.

    `lines_list` has shape (snapshot, 2, point), or is a list of (2, point)
    tensors that gets stacked along a new first dimension. Index 0 of the
    middle dimension holds x coordinates; index 1 holds y coordinates.

    `color` may be:
      - None, in which case every point gets the value 1;
      - a 1-D sequence with one value per point, reused for every snapshot;
      - a 2-D (snapshot, point) array.
    Tensors, numpy arrays and plain lists/tuples are all accepted.
    `hover` gives one hover label per point. The axis ranges are fixed to
    the global min/max of each coordinate.
    '''
    if type(lines_list)==list:
        lines_list = t.stack(lines_list, axis=0)
    lines_list = utils.to_numpy(lines_list)
    n_snapshots, n_points = lines_list.shape[0], lines_list.shape[2]
    if snapshot_index is None:
        snapshot_index = np.arange(n_snapshots)
    if hover is not None:
        # Repeat the per-point labels once per snapshot to match the rows below
        hover = [h for _ in range(len(snapshot_index)) for h in hover]
    if color is None:
        color = np.ones(lines_list.shape[-1])
    if isinstance(color, t.Tensor):
        color = utils.to_numpy(color)
    # np.asarray lets plain lists/tuples through (they have no `.shape`)
    color = np.asarray(color)
    if color.ndim == 1:
        color = einops.repeat(color, 'x -> snapshot x', snapshot=n_snapshots)
    rows = [
        [lines_list[i, 0, j].item(), lines_list[i, 1, j].item(), snapshot_index[i], color[i, j]]
        for i in range(n_snapshots)
        for j in range(n_points)
    ]
    df = pd.DataFrame(rows, columns=[xaxis, yaxis, snapshot, color_name])
    xs, ys = lines_list[:, 0], lines_list[:, 1]
    px.scatter(
        df,
        x=xaxis,
        y=yaxis,
        animation_frame=snapshot,
        range_x=[xs.min(), xs.max()],
        range_y=[ys.min(), ys.max()],
        hover_name=hover,
        color=color_name,
        **kwargs,
    ).show()


# Pick the compute device, checking MPS first, then CUDA, then falling back
# to the CPU.
device = t.device(
    'mps' if t.backends.mps.is_available()
    else 'cuda' if t.cuda.is_available()
    else 'cpu'
)

Introduction: the Fourier transform and the frequency domain#

The Fourier transform tells us what frequencies make up a given sequence. We can think of it as a basis change—it takes a sequence of length N, applies a linear transformation, and represents it as N new numbers, which we call Fourier coefficients. The Fourier coefficients represent the amplitudes of a series of periodic functions, with frequencies at integer multiples of \(\frac{2\pi}{N}\), that can fully reconstruct the original sequence.

Mathematically speaking, the Fourier transform of sequence \(x\) of length \(N\) is a sequence \(X\) of the same length, for which each element \(X_k\) is the result of projecting \(x\) onto the \(k\)-th basis vector \(b^k\), where \(b^k\) is a complex sinusoid at frequency \(\omega_k = 2\pi \frac{k}{N}\).

\[\begin{split} b^k_n = e^{-j2\pi \frac{k}{N}n} = \cos(2\pi \frac{k}{N}n) - j \sin(2\pi \frac{k}{N}n) \\[6pt] X_k = \sum_{n=0}^{N-1}x_n \cdot b_n^k = \sum_{n=0}^{N-1}x_n \cdot e^{-j2\pi \frac{k}{N}n} \end{split}\]

Using complex sinusoids as basis vectors leads the Fourier transform to exhibit a number of very interesting properties, such as turning input shifts into phase shifts and convolutions into element-wise products. You can read about them at length here.

For the purposes of this notebook, however, we’re mainly interested in the broader notion of using a basis of periodic vectors to go from an “input domain” to a “frequency domain”. In other words, our key takeaway is that we can measure how strongly an input sequence \(x\) is periodic at a set of frequencies \(\omega_k\) by defining a set of basis vectors \(b^k\), each of which is a periodic sequence at frequency \(\omega_k\), and projecting \(x\) onto the basis vectors \(b^k\). This is captured by the more general expression below, where the vectors \(b^k\) are not necessarily complex sinusoids, but rather some set of periodic sequences at frequencies \(\omega_k\).

\[ X_k = \sum_{n=0}^{N-1}x_n \cdot b_n^k \]

We can think of this sum of element-wise products as a matrix product, and rewrite the transform as:

\[ X = B \cdot x^T \]

Where \(x\) is a row vector containing the input sequence, and \(B\) is a matrix whose rows are the basis terms \(b^k\):

\[\begin{split}x = \begin{bmatrix} x_0 & \dots & x_{N-1} \end{bmatrix} \\[6pt] B = \begin{bmatrix} b^0\\ \vdots\\ b^{N-1} \end{bmatrix} = \begin{bmatrix} b^0_0 & \dots & b^0_{N-1} \\ b^1_0 & \dots & b^1_{N-1} \\ \vdots & \ddots & \vdots \\ b^{N-1}_0 & \dots & b^{N-1}_{N-1} \end{bmatrix} \end{split}\]

The \(k\)-th element of \(X\) will be the dot product between the \(k\)-th row of \(B\) (that is, \(b^k\)) and the input sequence \(x\).

In the rest of this notebook, we’ll be referring to \(B\) as the “Fourier basis” and to the vectors \(b^k\) as “Fourier basis terms” even when they are not the complex sinusoids used in the actual formulation of the Fourier transform, but rather some other set of periodic sequences of increasing frequency. We’ll also refer to the resulting set of coefficients \(X = B \cdot x^T\) as the “Fourier coefficients”, and to the vector space they occupy as the “Fourier domain”. This is an abuse of terminology, but it helps to simplify our explanation and still suggests the correct intuition for the purposes of understanding frequency-domain analysis.

In the next section, we’ll see how we can create a new set of Fourier basis vectors using sine and cosine functions.

Frequency analysis in 1D#

Defining a basis of cosine and sine terms#

A set of basis vectors that still performs the desired task of capturing periodic patterns at increasing frequencies, but that saves us the inconvenience of dealing with complex numbers, can be defined as follows:

\[\begin{split} B (N \times N) = \begin{bmatrix} \leftarrow 1 \rightarrow\\ \leftarrow \cos(1 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \sin(1 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \cos(2 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \sin(2 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] ...\\[6pt] \leftarrow \cos(\frac{N}{2} \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \sin(\frac {N}{2} \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \end{bmatrix} \end{split}\]

The matrix has \(N\) rows, each of length \(N\). The first row is the constant value 1, which captures the constant component of the input sequence. The remaining rows alternate between cosine and sine terms, each of length \(N\), at frequencies \(\omega_k = k \cdot 2\pi\frac{1}{N}\) for integers \(k\) between \(1\) and \(\frac{N}{2}\).

The code below generates this 1D Fourier basis for a given sequence length N. As explained above, the resulting matrix can be thought of as a constant term plus N//2 sine and cosine terms respectively, each of length N, with frequencies between 1 and N//2. Here, N//2 refers to N/2 rounded down, e.g. 12//2 = 6 and 13//2 = 6. Every row is divided by its own norm in order to preserve the scale of the sequences we’ll project onto it.

The interactive plot allows you to view each of the basis terms, that is, each of the length-N rows of the basis matrix. Note that the play button will animate the transitions between basis terms in weird ways, so it might be easiest to simply drag the slider.

N = 128


def make_fourier_basis(N: int) -> Tuple[Tensor, List[str]]:
    '''
    Build the real-valued 1D "Fourier" basis used throughout this notebook.

    Returns a pair `fourier_basis, fourier_basis_names`:
      - `fourier_basis` is an `(N, N)` tensor whose rows are unit-norm basis
        vectors: a constant row, then alternating cos/sin rows of frequency
        1, 2, ..., N // 2.
      - `fourier_basis_names` is a list of length `N` naming each row, e.g.
        `["Const", "cos 1", "sin 1", ...]`.

    The returned tensor is moved to the module-level `device`.
    '''
    n = t.arange(N)
    fourier_basis = t.ones(N, N)
    fourier_basis_names = ['Const']

    for i in range(1, N // 2 + 1):
        # cos term at frequency i goes in row 2i - 1
        fourier_basis[2*i-1] = t.cos(2*t.pi*n*i/N)
        fourier_basis_names.append(f'cos {i}')

        # sin term at frequency i goes in row 2i. When N is even, the final
        # sin term (i == N / 2) is sin(pi * n), which is zero at every integer
        # n, so it is skipped. This also keeps the basis at exactly N rows.
        if 2*i < N:
            fourier_basis[2*i] = t.sin(2*t.pi*n*i/N)
            fourier_basis_names.append(f'sin {i}')

    # Scale every row to unit L2 norm
    fourier_basis /= fourier_basis.norm(dim=1, keepdim=True)
    return fourier_basis.to(device), fourier_basis_names


fourier_basis, fourier_basis_names = make_fourier_basis(N)

# One animation frame per basis row, with each frame labelled by the row's name
animate_lines(fourier_basis,
              snapshot_index=fourier_basis_names,
              snapshot='Fourier Component',
              title=f'Fourier basis terms for N={N}')

A side-note on rounding down N//2#

This rounding down introduces an annoying little problem for us to work around. Under our particular definition of the transform, there will in theory be one constant term, N//2 sine terms, and N//2 cosine terms. This means that there will be exactly N terms if N is odd, but N+1 terms if N is even. For example:

  • If N = 129, there will be one constant term, N//2 = 64 cosine terms, and N//2 = 64 sine terms, for a total of N = 129 basis terms.

  • If N = 128, there will be one constant term, N//2 = 64 cosine terms, and N//2 = 64 sine terms, for a total of N+1 = 129 basis terms.

In practice, as you can see by inspecting the code above, if N is even, we discard the final sine term. Why is this?

The final sine term is \(\sin(\frac{N}{2}\cdot 2\pi \frac{1}{N} n)\), which simplifies down to \(\sin(\pi n)\). This is seemingly a perfectly valid function, much like its corresponding cosine term, \(\cos(\pi n)\), which we’re not discarding—they are both sinusoidal sequences that complete a whole oscillation every two elements of the sequence. Why would we discard the sine term?

If we were dealing with continuous sequences, i.e. if we were calculating \(\sin(\pi x)\) and \(\cos(\pi x)\) for a continuous \(x\) taking every real value between \(0\) and \(N\), including for example \(0.12345\), this would be fine. However, we’re dealing with \(\sin(\pi n)\), where \(n\) is a discrete variable that takes integer values between \(0\) and \(N-1\). This means, effectively, that we’re sampling the functions \(\sin(\pi x)\) and \(\cos(\pi x)\) at integer values of \(x\). The cosine of integer multiples of \(\pi\) will alternate between the values \(1\) and \(-1\) (as you can see in the plot for the final cosine term above), but the sine of integer multiples of \(\pi\) will always be zero.

In other words: it’s not that we arbitrarily decide that we don’t care about the final sine term, nor is it the case that \(\sin(\pi x)\) is always zero. However, it is true that \(\sin(\pi x)\) is always zero if we only evaluate it at integer values of \(x\). It’s very unfortunate for the final sine term that we always happen to measure it when it’s crossing zero, but it’s actually handy for us, since it means that we still get exactly N basis terms, even if N is even.

The plot below illustrates this. It shows the functions that the final cosine and sine terms sample from, i.e. \(\sin(\pi x)\) and \(\cos(\pi x)\), zoomed in to the range \((0,10)\) in order to show a few cycles. As one would expect from cosine and sine functions, they are identical except for a \(\frac{\pi}{2}\) phase offset, which for these particular functions corresponds to a shift of \(\frac{1}{2}\) along the \(x\) axis.

Because of this offset, if we sample both functions at integer values of \(x\), we will always be sampling the cosine term at its maximum and minimum values of \(1\) and \(-1\), whereas we will always be sampling the sine term when its value is zero.

x_density = 100
# Use N * x_density + 1 points so that consecutive samples are exactly
# 1 / x_density apart, covering [0, N] inclusive.
x = t.linspace(0, N, x_density*N + 1)
k = N//2
cosx = t.cos(k * 2 * t.pi / N * x)
sinx = t.sin(k * 2 * t.pi / N * x)
n_lim = 10
# +1 so the zoomed-in slice ends exactly at x = n_lim
x_lim = n_lim * x_density + 1

lines(
    [cosx[:x_lim], sinx[:x_lim]],
    x=x[:x_lim],
    labels=['Final cosine term', 'Final sine term'],
    xaxis='x',
    yaxis='Amplitude',
    title='Final cosine and sine terms for continuous variables'
)

# The same functions sampled only at the integers 0 .. n_lim
cosn = t.cos(k * 2 * t.pi / N * t.arange(n_lim+1))
sinn = t.sin(k * 2 * t.pi / N * t.arange(n_lim+1))

lines(
    [cosn, sinn],
    labels=['Final cosine term', 'Final sine term'],
    xaxis='n',
    mode='markers',
    yaxis='Amplitude',
    title='Final cosine and sine terms, sampled at discrete integer variables'
)

Calculating the 1D Fourier transform#

Recall that we can express the Fourier transform as a matrix product:

\[ X = B \cdot x^T \]

Where \(x\) is a row vector containing the input sequence, and \(B\) is a matrix whose rows are the basis terms \(b^k\):

\[\begin{split}x = \begin{bmatrix} x_0 & \dots & x_{N-1} \end{bmatrix} \\[6pt] B (N \times N) = \begin{bmatrix} \leftarrow 1 \rightarrow\\ \leftarrow \cos(1 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \sin(1 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \cos(2 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \sin(2 \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \vdots\\[6pt] \leftarrow \cos(\frac{N}{2} \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \leftarrow \sin(\frac{N}{2} \cdot 2\pi\frac{1}{N}n) \rightarrow\\[6pt] \end{bmatrix} \end{split}\]

This makes calculating the Fourier coefficients deceptively simple! The function below implements it in just two lines.

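Note: we’re using the name “FFT” very loosely here. The function below is a plain matrix product with our cosine/sine basis, not the fast Fourier transform algorithm.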

def fft1d(x: t.Tensor) -> t.Tensor:
    '''
    Returns the coefficients of `x` in the 1D Fourier basis built by
    `make_fourier_basis`.

    `x` can be a single vector or a batch of vectors, with the sequence
    dimension last: x.shape = (..., p). The output has the same shape.
    Entry `[..., k]` is the dot product of `x[..., :]` with basis row `k`.
    Integer inputs are converted to floating point before projecting.
    The result lives on the module-level `device`.
    '''
    x = x.to(device)
    if not x.is_floating_point():
        x = x.float()
    basis, _ = make_fourier_basis(x.shape[-1])
    # Contract along the last dimension: (..., p) @ (p, p) -> (..., p).
    # `basis @ x` is only correct for 1-D inputs; for a batch it would
    # contract the basis with the wrong dimension.
    return x @ basis.to(device=x.device, dtype=x.dtype).T

Let’s create a test sequence and see what the squared magnitude of its Fourier transform looks like.

n = t.arange(N)
const = 0.5
a1 = 0.3
w1 = 20
a2 = 0.6
w2 = 13
# Constant offset + a cosine at frequency w1 + a sine at frequency w2
x = const + a1 * t.cos(w1 * 2*t.pi/N * n) + a2 * t.sin(w2 * 2*t.pi/N * n)

lines(
    [x],
    labels=['Test sequence'],
    xaxis='n',
    yaxis='Amplitude',
    title='Test sequence in the input domain'
)

# fft1d gives projections onto unit-norm basis rows. Squaring them gives the
# power carried by each basis term.
line(
    fft1d(x).pow(2),
    hover=fourier_basis_names,
    xaxis='k',
    yaxis='Squared coefficient',
    title='Test sequence in the frequency domain'
)

As you can see, a sequence that is visibly periodic but somewhat messy-looking in the input domain results in a highly sparse and clean sequence of (squared) Fourier coefficients. Indeed, by hovering over the peaks you can see that they capture exactly the functions and frequencies we used to generate the test sequence.

A musical example#

Out of all the different kinds of things in the world that have distinct frequency components, musical notes are by far the best! Let’s use them to develop an intuition for what the Fourier transform captures.

We’ll start by loading a recording of a flute, which was posted on Freesound by the user juskiddink. The description reads “two clay bird flutes (different sizes) placed in the mouth and blown together, giving two notes”. Let’s see if we can find two Fourier peaks, one for each note!

Note that, in this case, “two notes” can be ambiguous, since the flutes are actually playing two notes simultaneously, and then doing it again. What interests us is the fact that two notes are played simultaneously, i.e. that there are two flutes being played at the same time, each at a different pitch. We’ll use the Fourier transform to identify those pitches.

import os

# NOTE(review): torchaudio appears to read this variable when it is first
# imported (that is when it selects its I/O backend), so it is set before
# `import torchaudio` here. Setting it afterwards has no effect. Confirm this
# against the installed torchaudio version.
os.environ['TORCHAUDIO_USE_BACKEND_DISPATCHER'] = '0'

import torchaudio
import IPython.display as ipd

wav, sr = torchaudio.load('flute.wav')
ipd.Audio(wav.detach().numpy(), rate=sr)

The Fourier transform implicitly assumes that the input sequence is stationary, that is, that the intensity of each frequency component does not change throughout the length of the sequence. For this reason, and in order to limit the size of our operands (at the sample rate of this file, one second of audio contains \(44,100\) discrete floating-point values), we will take the Fourier transform of a small 50-millisecond slice of the sound. By trial and error, we found that the fragment between the 150ms and 200ms timestamps seems to be fairly stationary.

# Left channel only, from 150 ms to 200 ms into the recording
x_min, x_max = int(sr*0.15), int(sr*0.2)
wav_sample = wav[0, x_min:x_max]

lines([wav_sample],
      labels=['Test sequence'],
      xaxis='n',
      yaxis='Amplitude',
      title='Input waveform')

ipd.Audio(wav_sample.detach().numpy(), rate=sr)

Let’s take its Fourier transform and see what we find!

wav_N = wav_sample.shape[0]
# Built here for the term names used as hover labels
wav_basis, wav_basis_names = make_fourier_basis(wav_N)
# Squared coefficients, i.e. the power in each basis term
wav_spectrum = fft1d(wav_sample).pow(2)
# Only the lowest-frequency terms are shown
max_k = 300
line(
    wav_spectrum[:max_k],
    hover=wav_basis_names[:max_k],
    xaxis='Fourier basis term',
    yaxis='Squared coefficient',
    title='Fourier transform of the input waveform'
)

That’s pretty clean! Two frequencies, as promised.

Without going into too much detail about sampling rates, we can calculate the frequencies that correspond to the Fourier indices of these peaks using the formula \(f = k \frac{F_s}{N}\), where \(F_s\) is the sampling rate of our audio signal, in this case \(44,100\), and \(k\) is the index of a given Fourier peak.

Considering our peaks occur for the cosine term at k=25 and the sine and cosine terms at k=32, let’s calculate what notes the flutes were playing!

Note: the frequency_to_note is a helper function that isn’t essential for you to understand—it’s based on the way musical notes map to frequencies along a logarithmic scale.

import math


def frequency_to_note(frequency):
    '''
    Return the name of the equal-tempered note closest to `frequency` (Hz),
    in scientific pitch notation, e.g. 440 -> 'A4', 261.63 -> 'C4'.

    Raises:
        ValueError: if `frequency` is not positive.
    '''
    if frequency <= 0:
        raise ValueError(f'frequency must be positive, got {frequency}')
    # Reference frequency for A4
    A4_frequency = 440
    # Number of semitones away from A4, rounded to the nearest note
    semitones = round(12 * math.log2(frequency / A4_frequency))
    note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    # Octave numbers change at C, not at A, and `note_names` starts at C, so
    # re-count the offset from C4 (A4 is 9 semitones above C4) before
    # splitting it into an octave and a note within the octave.
    semitones_from_c4 = semitones + note_names.index('A')
    octave = 4 + semitones_from_c4 // 12
    note_index = semitones_from_c4 % 12
    return note_names[note_index] + str(octave)

# Convert the two peak indices to Hz with f = k * sr / N, then name the notes
f1, f2 = 25 * sr / wav_N, 32 * sr / wav_N
note1, note2 = map(frequency_to_note, (f1, f2))

print(f'F1 = {f1} Hz ({note1})')
print(f'F2 = {f2} Hz ({note2})')
F1 = 500.0 Hz (B4)
F2 = 640.0 Hz (D#5)

Very nice! If you play music and have an instrument nearby, try playing a B4 and a D#5, the notes closest to 500 Hz and 640 Hz (a nice and bright major third), and you’ll see that it’s quite close to the recording. We’re already halfway to building a tuner :)

Frequency analysis in 2D#

By now, we have a decent understanding of what the 1D Fourier transform does: it captures the amplitudes of periodic components in the input sequence, at different frequencies. But 2D sequences, which we can think of as images, can also exhibit periodic properties, right?

The 2D Fourier transform can be understood by analogy to the 1D Fourier transform:

  • Instead of 1-D vectors \(x_n\) of length \(N\), our inputs will be 2-D tensors \(x_{n,m}\) of shape \((N,N)\). For simplicity, we’ll refer to the first dimension (rows, indexed by \(n\)) as the “vertical” direction and to the second dimension (columns, indexed by \(m\)) as the “horizontal” direction.

  • Instead of happening only along one input dimension \(n\), periodic patterns can occur along two dimensions. We’ll therefore have, instead of a 1-D vector of Fourier coefficients \(X_k\), a 2-D matrix of Fourier coefficients \(X_{k,l}\).

  • In the 1-D transform, \(X_k\) captured the periodic component of the input at frequency \(k\). In the 2-D Fourier transform, \(X_{k,l}\) captures the periodic component of input \(x_{n,m}\) at the 2-D frequency given by \(k\) in the \(m\) (horizontal) dimension and \(l\) in the \(n\) (vertical) dimension.

  • In the 1-D Fourier transform, for each frequency \(k\) we wanted to capture, we had a basis term \(b^k\) of the same length as the input \(x\), which was a periodic function at frequency \(k\). The coefficient \(X_k\) was the dot product between the input \(x\) and basis term \(b^k\). Similarly, in the 2-D Fourier transform, for each frequency pair \((k, l)\) we want to capture, we will have a basis term \(b^{k, l}\) of the same shape as the input, which will be a periodic function at frequencies \(k\) (horizontal) and \(l\) (vertical). The coefficient \(X_{k,l}\) will be the dot product between the input \(x\) and the basis term \(b^{k,l}\).

Having understood this, and using our implementation of the 1-D Fourier transform, the implementation is fairly straightforward:

  • For every pair \((k,l)\) of horizontal and vertical frequencies, we generate a 2-D basis term as the outer product of the 1-D Fourier basis terms for \(k\) (as a row vector) and \(l\) (as a column vector).

  • We compute the 2-D FFT at position \((k,l)\) as the dot product between our 2-D input and the 2-D basis term generated for \((k, l)\).

As you can see from the code below, we can actually use a single line of einsum magic to perform both steps in one. We’ll keep the function to generate 2-D basis terms nevertheless, as it’ll come in handy to generate test images.

def fourier_2d_basis_term(fourier_basis, k: int, l: int) -> Float[Tensor, "N N"]:
    '''
    Build one 2D basis term as the outer product of two 1D basis rows.

    Row `k` of `fourier_basis` varies along the horizontal (column) axis and
    row `l` varies along the vertical (row) axis, so the result satisfies
    `result[r, c] == fourier_basis[l][r] * fourier_basis[k][c]`.

    Returns a CPU tensor of shape `(N, N)`.
    '''
    basis_cpu = fourier_basis.to('cpu')
    vertical = basis_cpu[l].unsqueeze(1)    # column vector, (N, 1) for 1-D rows
    horizontal = basis_cpu[k].unsqueeze(0)  # row vector, (1, N) for 1-D rows
    return vertical * horizontal

def fft2d(tensor: t.Tensor) -> t.Tensor:
    '''
    Returns the components of `tensor` in the 2D Fourier basis.

    Assumes that the input has shape `(N, N, ...)`, where the trailing
    dimensions (if present) are batch dims. The output has the same shape as
    the input and is returned on the CPU:

        out[k, l, ...] = sum_{n, m} tensor[n, m, ...] * B[k, n] * B[l, m]

    where `B` is the 1D basis from `make_fourier_basis(N)`. So `k` indexes
    frequencies along the first (row) dimension and `l` along the second
    (column) dimension.

    Raises:
        ValueError: if `tensor` has fewer than two dimensions or its first
            two dimensions differ.
    '''
    if tensor.ndim < 2 or tensor.shape[0] != tensor.shape[1]:
        raise ValueError(
            f'fft2d expects an input of shape (N, N, ...), got {tuple(tensor.shape)}'
        )
    N = tensor.shape[0]
    fourier_basis, _ = make_fourier_basis(N)
    basis = fourier_basis.cpu()
    # Project the first input dim onto basis rows k and the second onto rows l
    return einops.einsum(
        tensor.cpu(), basis, basis, "pn pm ..., k pn, l pm -> k l ..."
    )

We’ll start by having a look at some of the basis terms:

N = 64
fourier_basis, fourier_basis_names = make_fourier_basis(N)

# (horizontal k, vertical l) frequency pairs to visualise, in order
for k, l in [(0, 1), (1, 0), (1, 1), (7, 3)]:
    imshow(fourier_2d_basis_term(fourier_basis, k, l), title=f"2-D Fourier basis term ({k}, {l})")

Let’s test our 2-D FFT by generating a sample 2D input as the combination of a handful of basis terms:

N = 64
fourier_basis, fourier_basis_names = make_fourier_basis(N)
# Weighted sum of three 2-D basis terms, given as (horizontal k, vertical l):
# 4 * (4, 6) + 7 * (14, 46) + 8 * (30, 50)
example_fn = sum([
    4 * fourier_2d_basis_term(fourier_basis, 4, 6),
    7 * fourier_2d_basis_term(fourier_basis, 14, 46),
    8 * fourier_2d_basis_term(fourier_basis, 30, 50),
])

# Plot without transposing, like the basis-term plots above: rows are the
# vertical direction and columns the horizontal one. This is also the
# orientation `imshow_fourier` uses for its Vertical/Horizontal axes.
imshow(example_fn, title="Example periodic function")

imshow_fourier(
    fft2d(example_fn),
    title='Example periodic function in 2D Fourier basis'
)

As you can see, what is a dense and fairly complex 2D pattern in the input domain becomes an incredibly sparse pattern in the 2D Fourier domain. We can easily spot the horizontal and vertical frequencies that generated the image as the only non-zero coefficients in the Fourier domain.

Conclusion#

That’s a wrap! Hopefully this notebook has helped solidify some notions of what the Fourier transform is, how it is calculated, what it represents, and how it can be used to analyze sequences and images that exhibit periodic patterns.

If you’d like to learn more about the Fourier transform, I’d highly recommend checking out this incredible video by the legendary 3Blue1Brown.