Structure and Execution of Language Models

This textbook accounts for every mathematical operation executed by a language model, expressed in both mathematical notation and code. Beyond the text itself, this project has two software components:

  1. Catform, a domain-specific language for expressing tensor computations
  2. Pianola, an engine that executes catform programs

AI programs today broadly have two layers—an outer scaffolding layer, and an inner layer, which is the language model itself. The book’s first chapter Inference covers in full detail the core logic of scaffolding—converting context into tokens (a unit of language) and then autoregressively sampling the language model to generate a response, one token at a time. The chapter also references Pianola’s implementation in src/pianola/inference.py line by line.

The language model itself is a purely deterministic mathematical function—it can be described entirely as a sequence of tensor operations. In this text, we write language models in catform—short for categorical form, inspired by category theory—a notation designed from first principles to mirror the algebra of these tensor operations. Writing the model in a custom language, rather than in a Python tensor framework, isolates its mathematical description as a standalone artifact: a single .cat file.

Catform is a minimal language: it uses just six primitive tensor operations to express the full computation of any modern transformer-based language model—and the resulting specification is no more verbose than the Python it replaces. The book’s second chapter Tensors is an interlude on the mathematics of tensor computations, expressed as executable catform. This set of operations is also closed under differentiation: the derivative of any catform program is itself a catform program. The upcoming Gradients chapter develops this property by implementing this transformation.

The third chapter Models walks through the complete mathematical description of Qwen3, a popular and representative open-weight language model. The description is expressed as a single .cat file, which Pianola can execute with either PyTorch or JAX without any modification to the model code.

Modern language models share the same core transformer architecture, varying in only a handful of components. The upcoming Architecture chapter will show how variants—mixture of experts, compressed latent attention, linear attention—differ from the baseline by substituting a small number of functions in the .cat file.

A model is determined by both its architecture—its sequence of operations—and by its weights, numerical parameters (counting in the billions, and more recently trillions) that parameterize these operations. These weights are not programmed, but rather discovered by searching for the model that minimizes a loss function—a measure of its prediction error on a giant corpus of data. An optimizer performs the search, stepping through the space in the direction of the loss gradient—computed in the aforementioned Gradients chapter. The upcoming Training chapter walks through this process.

Philosophy

Our title is a direct homage to the seminal programming textbook Structure and Interpretation of Computer Programs. The debt runs deeper than the name:

The general technique of isolating the parts of a program that deal with how data objects are represented from the parts of a program that deal with how data objects are used is a powerful design methodology called data abstraction.

— SICP, Section 2.1

The relationship between catform and Pianola is an instance of what SICP calls data abstraction. The .cat file is a declarative artifact: a piece of data that describes what to compute, without prescribing how. The execution layer—which can lower to PyTorch, JAX, or in principle to any target hardware backend—consumes that data and runs it. Portability across frameworks is not an independent feature but a consequence of the abstraction: neither layer knows the other’s details. This is precisely the separation named in the title: the structure is the catform specification, and its execution is Pianola playing it like a piano roll.

Organization

Both the book and the code are focused entirely on correctness and clarity, without attention to performance. Future chapters may focus on making all of these systems fast.

All six chapters were described above. The first three are available now.

  1. Inference covers the outer loop for running inference on a model
  2. Tensors is a mathematical interlude on tensor computations
  3. Models covers the complete mathematical description of a modern language model (Qwen3)

The subsequent three will be released soon.

  1. Architecture will survey architectural variants—mixture of experts (Qwen3 MoE), compressed latent attention (DeepSeek V3), and linear attention—showing how each modifies the baseline transformer
  2. Gradients will be a mathematical interlude on computing the derivative of a tensor computation
  3. Training will cover the components of training—the training loop, loss functions, and optimizers—in both supervised and reinforcement learning settings

The dependency graph of the chapters is as follows.

Audience

This book is primarily for two kinds of reader. The first is the programmer—especially one working in or around AI—who wants a complete understanding of the mathematical structure of language models. The second is the student of mathematics, who may be distant from programming, but desires an entry point into AI that speaks in their language and meets their standard of precision.

The primary programming prerequisite for this book is a basic understanding of Python: specifically functions, types (including enums, dataclasses, and containers like tuples and dicts), and basic control flow. Python is universally used in AI engineering, and we follow suit. The primary mathematical prerequisite is a basic understanding of linear algebra: not a full course, just familiarity with vectors, linear maps, matrices, and the notion of dimension. Later on, in the chapter on gradients, the reader will want familiarity with differential calculus; the single-variable case suffices.

Both traditions share the same elemental concepts: things (terms or elements), collections of things (types or sets), and transformations between them (functions). Since this spine is shared, we need only establish conventions: mathematicians speak of sets and their elements; programmers speak of types and their terms. We use the latter throughout, noting that for our purposes the theoretical distinctions between the two foundations are not relevant. Throughout this text, most concepts are stated twice: once as a mathematical expression and once as a snippet of working code. This is not redundancy—it is the point. The two notations say the same thing to different readers—including the one we call a computer—and seeing them side by side reveals that the distance between blackboard and terminal is shorter than it appears.

Notation

In mathematical notation, we declare that a term $x$ is of type $X$ by writing:

$$x : X$$

In Python and catform alike, a type annotation looks like:

x: X

Given types $X$ and $Y$, a function $f$ takes an input $x : X$ and returns an output $y : Y$. In mathematical notation:

$$f : X \to Y$$

In Python:

def f(x: X) -> Y: ...

In catform:

f(x: X) -> (y: Y) {...}

Given two functions $f : X \to Y$ and $g : Y \to Z$, their composition is the function that applies $f$ first and then $g$. In mathematical notation, we represent this either as an arrow diagram:

$$X \xrightarrow{\;f\;} Y \xrightarrow{\;g\;} Z$$

or as a binary operator. The classical notation writes composition as $g \circ f$, read aloud as “$g$ of $f$”, reading right to left. We often prefer to match the diagrammatic order and use forward composition, written $f \mathbin{;} g$ and read aloud as “$f$ then $g$”.

In Python, our convention for readability is to express composition line by line:

y = f(x)
z = g(y)

while in catform it is the only way one can express it—each line is always precisely one single value assignment to the output of an operation.

y: Y = f(x)
z: Z = g(y)

Setup

The book is more useful—and more fun—when you can see the math actually running on your computer. Setup instructions are in the README.

Inference

What exactly happens when an AI runs? AI systems (at least those based on language models) may have all sorts of extra scaffolding that turns them into chatbots, structured workflows, or autonomous agents. At their core, however, all of these systems are performing inference, i.e. making calls to a single function that takes in some context and generates a response. Contexts and responses may carry all sorts of structure (as we will see in the later Templates section), but in their simplest form they are both just text, and hence inference is just a function that takes text input and generates text output:

$$\mathrm{generate} : \mathrm{Text} \to \mathrm{Text}$$

The core of this book—its first three chapters—will successively deconstruct this function until we have a completely unambiguous mathematical description of the entire inference process. As mentioned in the introduction, the book will parallel the code in Pianola, and in this chapter we will follow the specific module src/pianola/inference.py, which implements the above function:

def generate(
    tokenizer: Tokenizer,
    sampler: Sampler,
    model: Model,  
    context: str,
) -> str:
    """generate = encode ; complete ; decode."""
    tokens = tokenizer.encode(context).ids        # encode  : str -> list[int]
    completion = complete(sampler, model, tokens) # complete: list[int] -> list[int]
    return tokenizer.decode(completion)           # decode  : list[int] -> str

The first thing one will notice is that this function has several extra arguments beyond the context string. If one fills in tokenizer, sampler, and model with actual values, then one gets a function as above: i.e. one that takes a context and outputs a response. By far the most central of these is the model—it is the AI—the rest is just plumbing. In this chapter, we will deconstruct all of this plumbing. Then, after a mathematical interlude on tensor algebra in the next chapter, we will spend the third chapter on the model itself.

Tokenization

What exactly is the set $\mathrm{Text}$? It is natural to think of a text as a list of items of some given unit, which we call a token. To make this conversion, we will need a tokenizer, which consists of several components. First, it has a vocabulary, i.e. an enumerated list of tokens, along with an encoding function that decomposes text into a list of token indices:

$$\mathrm{encode} : \mathrm{Text} \to \mathrm{List}[\mathrm{Token}]$$

We call the number of tokens the vocabulary size and typically denote it with the symbol $V$. In principle, one could make many different choices of tokenization scheme. The main tradeoff to consider is that between vocabulary size and the size of the tokenized output. Two naive choices for a vocabulary are characters and words. Using characters yields a small vocabulary, but produces very long token sequences. Using words shortens sequences, but leads to an impractically large (effectively unbounded) vocabulary.

Nearly all modern language models—including every model in this book—use a middle-ground approach called byte pair encoding (BPE). The procedure for training a BPE tokenizer is as follows. The vocabulary is initialized with individual bytes—a level below characters, since a single character may comprise multiple bytes under UTF-8 encoding. Then, using a reference text corpus, the most frequently occurring adjacent pair of tokens is iteratively merged into a single new token until the vocabulary reaches a target size. We save both the vocabulary and the ordered sequence of merge rules. This core loop is simple enough to describe in a few lines of pseudocode:

vocabulary = {all single bytes}
merge_rules = []
while len(vocabulary) < target_size:
    (a, b) = most_frequent_adjacent_pair(corpus)
    corpus = merge_pair(corpus, a, b)  # replace every adjacent (a, b) with a + b
    vocabulary.add(a + b)
    merge_rules.append((a, b))
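The training loop above can be made concrete in a few lines of Python. What follows is a minimal sketch rather than the procedure any production tokenizer uses: the name `train_bpe` is hypothetical, the corpus is treated as one raw byte string (real tokenizers count pairs within pre-split chunks), and ties between equally frequent pairs are broken arbitrarily.

```python
from collections import Counter


def train_bpe(corpus: bytes, target_size: int) -> tuple[list[bytes], list[tuple[bytes, bytes]]]:
    """Toy BPE trainer: repeatedly merge the most frequent adjacent pair."""
    tokens = [bytes([b]) for b in corpus]          # the corpus as single-byte tokens
    vocabulary = [bytes([b]) for b in range(256)]  # start from all single bytes
    merge_rules: list[tuple[bytes, bytes]] = []
    while len(vocabulary) < target_size:
        pairs = Counter(zip(tokens, tokens[1:]))   # count adjacent pairs
        if not pairs:
            break                                  # nothing left to merge
        (a, b), _ = pairs.most_common(1)[0]
        merged, i = [], 0
        while i < len(tokens):                     # rewrite the corpus with the new token
            if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
        vocabulary.append(a + b)
        merge_rules.append((a, b))
    return vocabulary, merge_rules


# Three merges on a tiny corpus: the frequent word "low" becomes a single token.
vocab, merges = train_bpe(b"low lower lowest low low", target_size=259)
```

The key detail is that the corpus is rewritten after every merge, so that later merges can build on earlier ones: here the byte pair making up part of "low" is merged first, and the full word is assembled one merge later.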

Given a trained tokenizer, the function $\mathrm{encode}$ is implemented as the following sequence:

  1. the input text is first normalized to a canonical unicode form
  2. to preserve semantic boundaries, it is then split into chunks by a regex
  3. the merge rules are applied in order within each chunk
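The three steps above can be sketched as follows. This is an illustrative sketch under stated assumptions, not the Hugging Face implementation: the function names, the toy vocabulary, and in particular the splitting regex are all hypothetical stand-ins (a real tokenizer ships its own pattern in tokenizer.json).

```python
import re
import unicodedata


def encode_chunk(chunk: bytes, merge_rules: list[tuple[bytes, bytes]]) -> list[bytes]:
    """Apply the merge rules, in training order, to one chunk of bytes."""
    tokens = [bytes([b]) for b in chunk]
    for a, b in merge_rules:
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens


def encode(text: str, vocabulary: list[bytes], merge_rules: list[tuple[bytes, bytes]]) -> list[int]:
    index = {token: i for i, token in enumerate(vocabulary)}
    text = unicodedata.normalize("NFC", text)   # 1. normalize to a canonical unicode form
    chunks = re.findall(r" ?\S+|\s+", text)     # 2. split into chunks (hypothetical regex)
    ids: list[int] = []
    for chunk in chunks:                        # 3. apply merges within each chunk
        ids.extend(index[t] for t in encode_chunk(chunk.encode("utf-8"), merge_rules))
    return ids


# A hand-written toy tokenizer: 256 single bytes plus three merged tokens.
toy_vocabulary = [bytes([b]) for b in range(256)] + [b"lo", b"low", b" low"]
toy_merges = [(b"l", b"o"), (b"lo", b"w"), (b" ", b"low")]
ids = encode("low low", toy_vocabulary, toy_merges)                  # [257, 258]
roundtrip = b"".join(toy_vocabulary[i] for i in ids).decode("utf-8")  # "low low"
```

Because merges never cross chunk boundaries, the leading space stays attached to the word that follows it, which is the behavior visible in real tokenizations.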

Different model families train their tokenizers on different reference corpora and to different vocabulary sizes, but the underlying algorithm is the same. As mentioned in the introduction, we will follow Qwen3 as a running example of a modern language model. Qwen3’s tokenizer—whose vocabulary has 151,665 tokens—produces the following example tokenizations:

Input                    Tokens
"Hello, world!"          Hello , ·world !
"The quick brown fox"    The ·quick ·brown ·fox
"tokenization"           token ization
"transformer"            transform er
"x = 3.14"               x ·= · 3 . 1 4

Note the · marks: these denote leading spaces. Since BPE operates on bytes, spaces are included just like any other bytes. Common words survive as single tokens; rarer words are split into recognizable subwords; numbers and punctuation are handled byte by byte.

We can then define a decoding function

$$\mathrm{decode} : \mathrm{List}[\mathrm{Token}] \to \mathrm{Text}$$

by simply reversing the above process:

  1. each index is mapped back to its vocabulary entry
  2. the entries are concatenated
  3. the resulting bytes are decoded as text

The round-trip recovers the original text up to unicode normalization; that is, for every text $t$:

$$\mathrm{decode}(\mathrm{encode}(t)) = \mathrm{normalize}(t)$$

In current practice, all of the rules for both encoding and decoding are packaged in a single JSON file—tokenizer.json—distributed alongside each model. The reader is encouraged to open models/qwen3/tokenizer.json and inspect its top-level keys directly:

  • version—the format version (1.0)
  • truncation, padding—rules for trimming or padding token sequences to a fixed length; unused by Qwen3 (both null)
  • normalizer—a string code naming the unicode normalization form (NFC)
  • pre_tokenizer—the regex pattern that splits text into chunks before merging
  • post_processor—rules for inserting special tokens after encoding
  • decoder—a string code specifying how vocabulary entries are mapped back to raw bytes (ByteLevel)
  • added_tokens—special tokens beyond the BPE vocabulary; includes the end-of-sequence token <|endoftext|> and template delimiters discussed in the Templates section
  • model—the core: the BPE vocab and merges trained above

Instead of implementing encoding and decoding by hand, we use the industry-standard tokenizers library from Hugging Face, which provides a compiled Rust binary that reads a tokenizer.json file and exposes encode and decode functions.

We have now described the first and last lines of the generate function:

"""generate = encode ; complete ; decode."""
tokens = tokenizer.encode(context).ids        # encode  : str -> list[int]
completion = complete(sampler, model, tokens) # complete: list[int] -> list[int]
return tokenizer.decode(completion)           # decode  : list[int] -> str

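This means that the problem of defining a function from text to text has been reduced to defining a so-called completion function: one that takes a list of tokens and produces a list of tokens:

$$\mathrm{complete} : \mathrm{List}[\mathrm{Token}] \to \mathrm{List}[\mathrm{Token}]$$

In mathematical notation, we can depict this composition in the following commutative diagram:

$$
\begin{array}{ccc}
\mathrm{Text} & \xrightarrow{\;\mathrm{generate}\;} & \mathrm{Text} \\
{\scriptstyle\mathrm{encode}}\,\big\downarrow & & \big\uparrow\,{\scriptstyle\mathrm{decode}} \\
\mathrm{List}[\mathrm{Token}] & \xrightarrow{\;\mathrm{complete}\;} & \mathrm{List}[\mathrm{Token}]
\end{array}
$$

This diagram visually expresses that computing $\mathrm{generate}$ by going straight across from $\mathrm{Text}$ to $\mathrm{Text}$ is equal to going down, across, and then back up via $\mathrm{encode}$, $\mathrm{complete}$, and then finally $\mathrm{decode}$. We now turn to defining the completion function.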

Completion

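While there are other sorts of language models (e.g. diffusion language models), the vast majority of those in circulation at the time of writing (2026) are auto-regressive, meaning that they complete the response sequence one token at a time. Thus the function $\mathrm{complete}$ might be thought of as iterative application of a more elemental function which computes the next token:

$$\mathrm{next} : \mathrm{List}[\mathrm{Token}] \to \mathrm{Token}$$

We can then define the completion of a token sequence $c$ as the successive applications of $\mathrm{next}$, performed while some “keep generating” condition holds, with $\mathbin{+\!\!+}$ denoting list concatenation:

$$\mathrm{complete}(c) = [r_1, \ldots, r_n], \qquad r_{t+1} = \mathrm{next}(c \mathbin{+\!\!+} [r_1, \ldots, r_t])$$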

In pseudocode one may write this as the loop:

response = []
while keep_generating:
    token = next_token(context+response)
    response.append(token)
return response

The above however is only partially correct. You may have heard that language models are, in contrast to classic computation, “probabilistic” rather than “deterministic”. Technically this is not exactly correct: there is a clean separation of the deterministic component—which, somewhat ironically, is the part we actually refer to as the model—and the probabilistic component, which samples from the model output. This leads us to the actual implementation of the complete function:

def complete(
    sampler: Sampler,
    model: Model,
    context: list[int]
) -> list[int]:
    """Autoregressive completion"""
    response = []
    while sampler.keep_generating(response):
        logits = model(context + response)
        token = sampler.sample(logits)
        response.append(token)
    return response

The function body is almost identical to the above loop code, with the exception of the presence of sampler, which packages the keep_generating condition above along with a function to actually sample from the model. The function model (which, to reiterate, is the language model itself) takes as input the list of tokens in the context, and outputs a numerical score, called a logit, for each token in the vocabulary. For the time being, we can represent these output logits as having the type of a real-valued array of size $V$ (the vocabulary size), where the value of the array at index $i$ is the logit associated to the $i$-th token.
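To see the division of labor between the deterministic model and the sampler, here is a small self-contained sketch that runs the same loop with stand-ins. Everything except the loop itself is an assumption made for illustration: `GreedySampler`, `toy_model`, the four-token vocabulary, and the choice of index 0 as the end-of-sequence token are hypothetical and are not Pianola's classes.

```python
from dataclasses import dataclass

EOS = 0  # hypothetical end-of-sequence token index


@dataclass
class GreedySampler:
    """Stand-in sampler: always picks the highest logit, stops on EOS or a length cap."""
    max_tokens: int = 8

    def keep_generating(self, response: list[int]) -> bool:
        if response and response[-1] == EOS:
            return False
        return len(response) < self.max_tokens

    def sample(self, logits: list[float]) -> int:
        return max(range(len(logits)), key=logits.__getitem__)


def toy_model(tokens: list[int]) -> list[float]:
    """Stand-in model over 4 tokens: scores (last token - 1) highest."""
    target = (tokens[-1] - 1) % 4
    return [1.0 if i == target else 0.0 for i in range(4)]


def complete(sampler, model, context: list[int]) -> list[int]:
    response: list[int] = []
    while sampler.keep_generating(response):
        logits = model(context + response)  # deterministic: same input, same logits
        token = sampler.sample(logits)      # the only place a choice is made
        response.append(token)
    return response


print(complete(GreedySampler(), toy_model, [3]))  # [2, 1, 0]
```

Swapping `GreedySampler` for a sampler that draws randomly would change the output without touching `toy_model`, which is exactly the separation described above.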

This function is parameterized by a large collection of real numbers called weights. A model’s advertised size—0.6B, 30B, 235B parameters—refers to the count of these values. The weights are not programmed—they are discovered, by a procedure called training that searches for the values that minimize prediction error on a vast corpus of text. As Karpathy puts it in Software 2.0: “No human is involved in writing this code.” Once training is complete, the weights are frozen, and the model becomes a fixed, deterministic mathematical function.

Softmax

We now turn to the question of how we sample from the logits array $z$. The first matter is to turn the logits into an actual probability distribution. Recall that a probability distribution $p$ assigns a non-negative number $p_i$ to each outcome $i$ with the requirement that these numbers sum to $1$:

$$p_i \geq 0, \qquad \sum_i p_i = 1$$

The relevant distribution for us is the probability $p_i$ of the next token being $i$. The logits are not at all guaranteed to sum to $1$, so we cannot just set $p_i = z_i$. Naively, we could enforce this by simply dividing each logit by the total sum of all of the logits:

$$p_i = \frac{z_i}{\sum_j z_j}$$

This, however, only produces a valid probability distribution when all logits are non-negative, which is also not guaranteed. To fix this, we can observe that we can generalize the above by first applying a function $f$ to all the logits and then normalizing by the sum:

$$p_i = \frac{f(z_i)}{\sum_j f(z_j)}$$

The question now is how to choose the $f$. At the very least, we want any choice to satisfy the following two properties:

  • non-negativity: $f(z) \geq 0$ for all $z$
  • monotonicity: if $z < z'$ then $f(z) < f(z')$

The first property will guarantee that the output gives us a valid probability distribution. The second property respects the meaning of logits: a token with a higher score should be sampled with a higher probability. A natural choice satisfying both criteria is the exponential function, parameterized by a scalar $T > 0$:

$$f(z) = e^{z/T}$$

Taken together, we have just defined the softmax function, which takes an array of real numbers and outputs a probability distribution on them:

$$\mathrm{softmax}_T(z)_i = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}$$

We call the parameter $T$ the temperature, and use it to modulate the entropy, that is, the level of randomness, of the resulting distribution. In the limit $T \to 0$, this yields a deterministic output where we simply select the token with the highest logit. In the limit $T \to \infty$, this yields the uniform distribution across tokens. In practice, the temperature is a tunable sampling parameter, and $T = 1$ leaves the logits unscaled.
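A direct Python transcription of softmax with temperature may help make the two limits tangible. This is a sketch and not the code in inference.py; the subtraction of the maximum is a standard numerical-stability trick that leaves the result unchanged, since it multiplies numerator and denominator by the same constant.

```python
import math


def softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    """Exponentiate temperature-scaled logits and normalize them to sum to 1."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                           # shift by the max to avoid overflow in exp
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 3.0]
sharp = softmax(logits, temperature=0.1)      # nearly all mass on the largest logit
plain = softmax(logits, temperature=1.0)
flat = softmax(logits, temperature=1000.0)    # nearly uniform
```

Note that negative logits pose no problem here, which is the reason for exponentiating rather than normalizing the raw scores.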

Sampling Strategies

Beyond temperature, two common filtering strategies restrict the candidate set before sampling. Top-$k$ keeps only the $k$ highest-probability tokens, zeroing out everything else. Top-$p$, or nucleus sampling, keeps the smallest set of tokens whose cumulative probability exceeds a threshold $p$. Both cut the long tail of low-probability tokens and are typically used independently.
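Both filters are short enough to sketch directly. The functions below are illustrative and not the implementation in inference.py; the names are hypothetical, and renormalizing the surviving probabilities so that they again sum to one is an assumption about how the filtered distribution is used.

```python
def renormalize(probs: list[float]) -> list[float]:
    total = sum(probs)
    return [p / total for p in probs]


def top_k(probs: list[float], k: int) -> list[float]:
    """Keep the k most probable tokens, zero out the rest, renormalize."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(ranked[:k])
    return renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])


def top_p(probs: list[float], threshold: float) -> list[float]:
    """Keep the smallest high-probability set whose cumulative mass exceeds threshold."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in ranked:
        keep.add(i)
        cumulative += probs[i]
        if cumulative > threshold:
            break
    return renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])


probs = [0.5, 0.3, 0.15, 0.05]
k_filtered = top_k(probs, 2)     # only the first two tokens survive
p_filtered = top_p(probs, 0.9)   # the first three tokens are needed to exceed 0.9
```

The difference between the two is that top-$k$ fixes the number of candidates, while top-$p$ lets that number adapt to how concentrated the distribution is.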

The “keep generating condition” mentioned in complete is simple: generation stops when either the end-of-sequence token—<|endoftext|> in Qwen3’s tokenizer.json—is produced or a maximum token count is reached. These are the two conditions checked by sampler.keep_generating(response).

The full implementation of Sampler in inference.py is straightforward Python packaging the above: temperature scaling, top-$k$/top-$p$ filtering, and the stopping condition. The reader can consult it directly for the details.

Templates

At the beginning of the chapter, we made the simplifying assumption that the context is raw text. In practice, the context passed to inference is not a bare string but a structured object. A template converts this structured context into the flat text that the model actually sees, and the inverse operation, parsing, processes the model’s raw text output back into structured form. How much structure parsing extracts depends on the scaffolding: it may extract and route data to other parts of the system, whereas our inference.py simply appends the raw response as an assistant message. The template is model-specific: it is distributed alongside the weights and tokenizer. Crucially, the model is trained on data formatted with its template, so the template is not an arbitrary formatting choice; it is baked into the model’s behavior and cannot be swapped freely.

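This gives us one more layer of the same pattern. If we let $\mathrm{Context}$ denote the type of structured context objects and define an $\mathrm{infer}$ function that operates on them, then we have the following commutative diagram:

$$
\begin{array}{ccc}
\mathrm{Context} & \xrightarrow{\;\mathrm{infer}\;} & \mathrm{Context} \\
{\scriptstyle\mathrm{template}}\,\big\downarrow & & \big\uparrow\,{\scriptstyle\mathrm{parse}} \\
\mathrm{Text} & \xrightarrow{\;\mathrm{generate}\;} & \mathrm{Text}
\end{array}
$$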

In code, this is the same nesting pattern as generate wrapping complete:

def infer(
    template: Template,
    tokenizer: Tokenizer,
    sampler: Sampler,
    model: Model,
    context: Context,
) -> Context:
    """infer = template ; generate ; parse."""
    text = template.apply(context)                       # template : Context -> str
    response = generate(tokenizer, sampler, model, text) # generate : str -> str 
    return template.parse(context, response)             # parse    : str -> Context

Messages

The simplest and most universal structured context is a sequence of messages alternating between roles. Every model family supports at least three roles: system for persistent instructions, appearing once at the start; user for the human’s input; and assistant for the model’s output. A simple conversation in this format:

[
  {"role": "system",    "content": "You are legendary mathematician Alexander Grothendieck."},
  {"role": "user",      "content": "What is your favorite prime?"},
  {"role": "assistant", "content": "57"}
]

Qwen3’s template converts this to the following flat text, using special tokens to delimit each message:

<|im_start|>system
You are legendary mathematician Alexander Grothendieck.<|im_end|>
<|im_start|>user
What is your favorite prime?<|im_end|>
<|im_start|>assistant
57<|im_end|>

The entire conversation—system prompt, user question, and assistant reply—is one flat token sequence. The special tokens <|im_start|> and <|im_end|>—listed in the added_tokens field of tokenizer.json—are the only structure the model sees; everything between them is ordinary text. This particular delimiter convention is called ChatML, and is the template used by Qwen3. Other model families use different delimiters but the same principle: role-tagged messages flattened into a single token sequence.
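The flattening step is simple enough to sketch in a few lines. The function below is a simplified illustration that reproduces the text shown above for plain messages; it is not Qwen3's actual template, which also handles details such as generation prompts, thinking, and tools. The name `render_chatml` is hypothetical.

```python
def render_chatml(messages: list[dict[str, str]]) -> str:
    """Flatten role-tagged messages into ChatML-delimited text."""
    return "\n".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>" for m in messages
    )


conversation = [
    {"role": "system", "content": "You are legendary mathematician Alexander Grothendieck."},
    {"role": "user", "content": "What is your favorite prime?"},
    {"role": "assistant", "content": "57"},
]
flat_text = render_chatml(conversation)
```

The role names and message contents pass through as ordinary text; only the two delimiters carry structure.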

Thinking

Recent models support extended thinking—generating a chain of reasoning tokens, as in Chain-of-Thought Prompting, before producing a visible answer. In Qwen3, the model demarcates thinking with <think></think> markers. There is no separate mechanism: thinking and visible tokens are generated in one continuous stream. The template extracts the thinking content when re-rendering the conversation for the next turn.
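Since thinking is delimited purely by text markers, separating it from the visible answer is a string-processing task. The sketch below assumes a response containing at most one `<think>...</think>` span; the helper name `split_thinking` is hypothetical and this is not the parsing code of any particular template.

```python
import re


def split_thinking(response: str) -> tuple[str, str]:
    """Return (thinking, visible) parts of a raw model response."""
    match = re.search(r"<think>(.*?)</think>", response, flags=re.DOTALL)
    if match is None:
        return "", response.strip()          # no thinking span present
    thinking = match.group(1).strip()
    visible = (response[: match.start()] + response[match.end():]).strip()
    return thinking, visible


raw = "<think>\n19 * 3 = 57\n</think>\n\n57"
thinking, visible = split_thinking(raw)      # ("19 * 3 = 57", "57")
```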

Tools

Language models can invoke external functions via tool use, also called function calling. The context carries a list of tool definitions, which are passed to the template and injected into the flat text, and when the model decides to call one, it generates a tool call as ordinary output tokens. The application-level scaffolding (code beyond what inference.py implements) evaluates the call and inserts the result back into the conversation. For example, given a tool that performs arithmetic:

{
  "name": "calculate",
  "parameters": {
    "operation": {"type": "string", "enum": ["add", "multiply", "subtract", "divide"]},
    "left":      {"type": "number"},
    "right":     {"type": "number"}
  }
}

the model might respond to “What is 19 × 3?” by generating the token sequence calculate(operation="multiply", left=19, right=3). The scaffolding evaluates it, returns 57 as a tool result message, and the model continues generating with that result in context. As with thinking, tool calls and results are demarcated by text markers—the model produces them as ordinary tokens, and the template layer interprets the structure.
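The scaffolding side of this exchange can be sketched as a small dispatcher. This is a hypothetical illustration: it assumes the tool call has already been parsed out of the model's text into a dictionary, and the `TOOLS` registry, the `run_tool_call` helper, and the `"tool"` role name are assumptions rather than part of inference.py.

```python
import operator

OPERATIONS = {
    "add": operator.add,
    "multiply": operator.mul,
    "subtract": operator.sub,
    "divide": operator.truediv,
}


def calculate(operation: str, left: float, right: float) -> float:
    """The arithmetic tool described by the definition above."""
    return OPERATIONS[operation](left, right)


TOOLS = {"calculate": calculate}  # hypothetical registry: tool name -> Python function


def run_tool_call(call: dict) -> dict:
    """Evaluate a parsed tool call and wrap the result as a message."""
    result = TOOLS[call["name"]](**call["arguments"])
    return {"role": "tool", "content": str(result)}


call = {"name": "calculate", "arguments": {"operation": "multiply", "left": 19, "right": 3}}
print(run_tool_call(call))  # {'role': 'tool', 'content': '57'}
```

The resulting message is appended to the conversation, re-rendered by the template, and the model resumes generation with the result in context.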

Messages, thinking, and tools are the three content types common to virtually every language model provider today. Model creators can extend the template with additional structured types as the landscape evolves.

The Matryoshka

Let’s now step back and see the full structure at once. Each layer of the system is the layer below it, wrapped with one additional concern. Templates handle structured context. Tokenization handles text representation. Completion handles the autoregressive loop and sampling. What remains, $\mathrm{model}$, is, as mentioned above, the subject of the third chapter.

At the code level, this nesting is visible in the function signatures themselves. Each layer is the one below with one more configuration argument:

infer(template,tokenizer,sampler,model,context)
generate(      tokenizer,sampler,model,context)
complete(                sampler,model,context)
                                 model(context)

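As a mathematical diagram, the full decomposition stacks the two commutative squares we built earlier:

$$
\begin{array}{ccc}
\mathrm{Context} & \xrightarrow{\;\mathrm{infer}\;} & \mathrm{Context} \\
{\scriptstyle\mathrm{template}}\,\big\downarrow & & \big\uparrow\,{\scriptstyle\mathrm{parse}} \\
\mathrm{Text} & \xrightarrow{\;\mathrm{generate}\;} & \mathrm{Text} \\
{\scriptstyle\mathrm{encode}}\,\big\downarrow & & \big\uparrow\,{\scriptstyle\mathrm{decode}} \\
\mathrm{List}[\mathrm{Token}] & \xrightarrow{\;\mathrm{complete}\;} & \mathrm{List}[\mathrm{Token}]
\end{array}
$$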

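We can now write the entire computation from structured context to structured response, autoregressive loop and all. Writing $x \sim p$ for a random draw $x$ from a distribution $p$:

$$\mathrm{infer} = \mathrm{template} \mathbin{;} \mathrm{encode} \mathbin{;} \mathrm{complete} \mathbin{;} \mathrm{decode} \mathbin{;} \mathrm{parse}, \qquad r_{t+1} \sim \mathrm{softmax}\big(\mathrm{model}(c \mathbin{+\!\!+} [r_1, \ldots, r_t])\big)$$

where $c$ is the encoded context, $r_1, \ldots, r_t$ are the response tokens generated so far, and $\mathbin{+\!\!+}$ denotes list concatenation.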

Now that we have unravelled all of the scaffolding, what remains is the $\mathrm{model}$ itself. Before we can open it up, we need the mathematical language in which it is written: tensors and their operations.

Tensors

This chapter is a mathematical interlude on computations of tensors. We start by defining tensors, then proceed to defining a set of primitive tensor operations, and conclude with a discussion on how to formally think about and represent computations. We will express every concept in both mathematical notation and executable code, using our domain-specific tensor programming language catform.

Tensor programming is an abstraction for programming massively parallel processors—GPUs, TPUs, and similar hardware that execute thousands of operations simultaneously. The key insight is that tensors, by organizing data into axes, make the available parallelism explicit. There are vastly more functions from tensors to tensors than there are tensor operations. What earns the name is structured parallelism: the axes alone determine which input entries each output entry may depend on. This constrains the space enough that we can describe every computation we need with just six primitive operation types—plus one derived form—all introduced in this chapter. Tensors are therefore a natural mathematical organizing principle for such hardware, and as we will see, they cohere directly with the objects of linear algebra that underpin the mathematics of language models.

Tensors

From the standpoint of a computer, a data object is typically a contiguous array of physical memory. We call the value in a given memory cell an entry. An $n$-axis tensor imbues the set of entries with the structure of an $n$-dimensional array by keeping track of metadata to tell our programs how the entries are logically arranged in memory. For instance, suppose we have $8$ contiguous entries

$$000,\ 001,\ 010,\ 011,\ 100,\ 101,\ 110,\ 111$$

We could treat these entries as comprising, for instance, a $2$-axis tensor

$$\begin{bmatrix} 000 & 001 & 010 & 011 \\ 100 & 101 & 110 & 111 \end{bmatrix}$$

by recording two pieces of metadata:

  • the shape tuple of the array
  • the stride tuple of increments used to traverse each axis

The shape of the above array is $(2, 4)$, while the stride is $(4, 1)$, since incrementing the first axis (moving down a column) skips $4$ addresses, e.g. $000$ is in position $0$ while the $100$ directly below it is in position $4$, while incrementing the second axis (moving along a row) skips $1$.
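The role of the stride tuple can be checked with a few lines of Python. This is a small illustrative sketch; the helper name `offset` is hypothetical and the entries are stored as strings purely to keep their leading zeros visible.

```python
def offset(index: tuple[int, ...], stride: tuple[int, ...]) -> int:
    """Memory position of a logical index: the dot product of index and stride."""
    return sum(i * s for i, s in zip(index, stride))


entries = ["000", "001", "010", "011", "100", "101", "110", "111"]  # flat memory
shape, stride = (2, 4), (4, 1)

top_left = entries[offset((0, 0), stride)]   # "000", at position 0
below_it = entries[offset((1, 0), stride)]   # "100", at position 4
last = entries[offset((1, 3), stride)]       # "111", at position 7
```

The shape bounds the valid indices, while the stride alone determines where each index lands in memory.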

We will denote the mathematical type of a tensor by the datatype of its entries followed by its shape in square brackets; e.g. for our above tensor we denote its type as

$$\mathbb{Z}[2,4]$$

For now, we will use the generic mathematical types $\mathbb{Z}$ and $\mathbb{R}$; later in this chapter we will specify the concrete machine-level numeric types that replace them.

In this chapter’s companion module tensors.py, we represent the tensor type as follows.

type Dim = int | str

@dataclass(frozen=True)
class Tensor:
    dtype: str
    shape: tuple[Dim, ...]

We allow the dimension type Dim to either be an explicit numeric int or a variable str. Allowing for variable dimensions facilitates representing more generic computations, e.g. those that make sense independent of the shape integers. At runtime all string dimensions get resolved to non-negative integers.

In catform programs, we introduce tensors such as above using the literal syntax:

x: int[2,4] = literal([[000, 001, 010, 011],[100, 101, 110, 111]])

All data objects in catform are tensors. This includes plain numbers, or scalars, which we represent as $0$-axis tensors, i.e. those of shape $()$. For instance, a single value $3.14$ has type $\mathbb{R}[]$.

x: real[] = literal(3.14)

In tensors.py, the literal is represented as:

@dataclass(frozen=True)
class Literal:
    value: Value
    dtype: str

Catform also allows introducing random data. The shape comes from the type annotation; the arguments specify the range:

xs: f32[5, 1000] = random(-1.0, 1.0)

In tensors.py, random data is represented as:

@dataclass(frozen=True)
class Random:
    lower: float
    upper: float
    dtype: str
    dims: tuple[Dim, ...]

Tensor Operations

We now introduce catform’s six primitive operations for transforming tensors.

An individual catform operation is given by its operation type and one or more specifiers, denoted in square brackets after the operation type, as follows:

$$\mathrm{op}[\mathrm{spec}_1, \ldots, \mathrm{spec}_m]$$

Operations take one or more arguments and output a single value. In catform, we assign the output of an operation as:

y: d[S] = op[spec_1,...,spec_m](x_1, ..., x_n)

In this equation, y is the output, d[S] its type (with datatype d and shape S), op one of the operation types, spec_1,...,spec_m its set of specifiers (in practice, there are at most 2) and x_1, ..., x_n its arguments. In mathematical notation, writing $d_i[S_i]$ for the type of the argument $x_i$, we write this function type as

$$\mathrm{op}[\mathrm{spec}_1, \ldots, \mathrm{spec}_m] : (d_1[S_1], \ldots, d_n[S_n]) \to d[S]$$

using tuple notation in the case of multiple arguments.

Catform allows simple function abstraction. The function declaration syntax is given by its name, its typed arguments in parentheses, an arrow, its typed return values in parentheses, and a body of value assignment lines in curly brackets:

func(x_1: d_1[S_1], ..., x_n: d_n[S_n]) -> (y_1: d'_1[S'_1], ..., y_m: d'_m[S'_m]) {
    ...
}

The function body exclusively consists of value assignment lines—they are the sole way to represent computations in catform. The function’s outputs are whichever assigned variables in the body match the return names declared after the arrow. If there is no such match, the program doesn’t typecheck.

We will now catalogue all six primitive operation types, followed by the derived form. The Pianola codebase defines all of these in the module tensors.py. Each op section ends with a runnable example from book/book.cat; the comment above the function name is the command to run it.

Views

View operations transform a tensor by exclusively changing the shape and stride metadata without making any changes to the actual underlying data.

Recall the example tensor in the prior section. If we applied the same permutation to its shape and stride tuples, e.g. from to and to respectively, we would simply swap the role of rows and columns, thus giving us its transpose

Such a permutation defines a mathematical function of type

While we could give this function a dedicated name such as transpose, we use the view op type to express this in a far more generalizable manner. All we need is a specifier which denotes exactly how the axes get rearranged. There are many ways one could do this; we choose the convention of a pattern string, as popularized by einops. These are best explained through example. In this transpose case, we write the pattern as "a b -> b a". This gives us the mathematical notation for this function:

If we name the input x and output y, we can express the application of this function to inputs in the assignment line

y: int[4,2] = view["a b -> b a"](x)

Another variant of the view operation comes from merging and splitting axes. For instance, suppose we wished to reconceptualize our tensor as a -axis array cube, with -axis layers

The resultant tensor would have shape and stride . To express this computation, we use the pattern "a (b c) -> a b c" to convey that the original second axis (b c) had the potential to be factored into two new axes b and c. In this case, the type signature is critical: it tells us exactly how to numerically factor the axis.

In catform, we express the application of this function as the line

y: int[2,2,2] = view["a (b c) -> a b c"](x)

The interested reader can compute for themselves what happens to the strides in such merges and splits.
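As a hedged sketch of that exercise, the following Python computes the shape and stride tuples for the two views above. It assumes a contiguous row-major 2×4 input and strides counted in elements rather than bytes; the helper names are illustrative and are not part of Pianola.

```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a contiguous tensor of this shape."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, strides, order):
    """Reorder axes: the same permutation is applied to shape and strides."""
    return tuple(shape[i] for i in order), tuple(strides[i] for i in order)


def split_axis(shape, strides, axis, outer, inner):
    """Factor one axis of size outer * inner into two axes (outer, inner)."""
    assert shape[axis] == outer * inner
    new_shape = shape[:axis] + (outer, inner) + shape[axis + 1:]
    # The outer axis steps over `inner` entries; the inner axis keeps the old step.
    new_strides = (strides[:axis]
                   + (strides[axis] * inner, strides[axis])
                   + strides[axis + 1:])
    return new_shape, new_strides


shape = (2, 4)
strides = contiguous_strides(shape)                        # (4, 1)
t_shape, t_strides = permute(shape, strides, (1, 0))       # "a b -> b a"
s_shape, s_strides = split_axis(shape, strides, 1, 2, 2)   # "a (b c) -> a b c"
print(t_shape, t_strides)   # (4, 2) (1, 4)
print(s_shape, s_strides)   # (2, 2, 2) (4, 2, 1)
```

In both cases only the two metadata tuples change; the underlying buffer is never touched.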

Because views merely re-interpret the logical arrangement of entries without altering the number of entries, and since $1$ is the multiplicative unit, view operations can freely add or remove axes of size $1$ from the shape.

The pattern strings we have seen so far name every axis. But real tensors have many axes, and most operations only touch the last few—the axes before them pass through unchanged. Rather than naming each of these pass-through axes, ... matches any number of leading axes that appear identically on both sides of the arrow. For example, "a b -> b a" transposes a 2-axis tensor, while "... a b -> ... b a" transposes the last two axes of a tensor with any number of axes. The axes matched by ... are exactly those where independent copies of the operation run in parallel. Catform also uses ... in type annotations: f32[..., N, M] denotes a tensor with any number of leading axes followed by N and M.

In tensors.py, we represent the view op as the dataclass:

@dataclass(frozen=True)
class View:
    pattern: str
    axes: dict[str, int]

The pattern is the rearrangement specifier. The axes dictionary provides concrete axis sizes to make the output shape unambiguous. Since views only change metadata, they involve no computation and no parallelism—they are free.

// uv run main.py run book/book.cat transpose
transpose() -> (y: int32[3, 2]) {
  x : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
  y : int32[3, 2] = view["a b -> b a"](x)
}

Maps

Map operations involve element-wise operations. The first variant of the map operation is that of lifting a unary (single argument) function

to a function of shape tensors

We name this operation by using the function as the specifier to the op type map. This operation is defined by applying the function to each entry.

For example, we can lift the squaring function

to tensors, and apply it to an input:

In catform, we write this as the line

y: int[2,2] = map[square](x)

Maps of unary functions interact nicely with function composition $g \circ f$, via the functoriality property:

$$\mathrm{map}[g] \circ \mathrm{map}[f] = \mathrm{map}[g \circ f]$$

In words, this means that applying $f$ to every entry and then applying $g$ to every entry is equal to applying $f$ and then $g$ to every entry at once. This means that we can combine two consecutive unary map operations into a single combined unary map operation. This fact proves useful for optimizing performance.
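The fusion property can be checked on a toy example. This is only a sketch: nested Python lists stand in for tensors, and lift and compose are hypothetical helpers rather than Pianola functions.

```python
def lift(f):
    """Lift a scalar function to nested lists of any depth (a toy map)."""
    def lifted(x):
        if isinstance(x, list):
            return [lifted(item) for item in x]
        return f(x)
    return lifted


def compose(g, f):
    """Return the function x -> g(f(x))."""
    return lambda x: g(f(x))


def square(v):
    return v * v


def negate(v):
    return -v


x = [[1, 2], [3, 4]]
two_passes = lift(negate)(lift(square)(x))     # two maps, one intermediate tensor
one_pass = lift(compose(negate, square))(x)    # a single fused map
print(two_passes)  # [[-1, -4], [-9, -16]]
print(one_pass)    # [[-1, -4], [-9, -16]]
```

The fused version visits each entry once and never materializes the intermediate tensor, which is where the performance benefit comes from.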

The other variant of map operations comes from mapping binary (or even $n$-ary!) operators, i.e. those with two (or more) inputs, such as the classic boolean and arithmetic operators. For instance, the following operation corresponds to adding two $S$-shaped tensors.

In catform, we write this as the line

z: int[S] = map[add](x, y)

In tensors.py, we represent the map op as the dataclass:

@dataclass(frozen=True)
class Map:
    function: str

The sole field is the name of the elementwise function being lifted. Maps are maximally parallel: every entry coordinate can be computed independently, so a tensor with $N$ entries yields $N$ independent computations.

// uv run main.py run book/book.cat square
square() -> (y: int32[2, 2]) {
  x : int32[2, 2] = literal([[1, 2], [3, 4]])
  y : int32[2, 2] = map[mul](x, x)
}

Folds

Fold operations combine all entries along an axis into one, collapsing that axis from size $n$ to size $1$. To denote a fold operation, we need two specifiers. We first need a pattern to indicate which axis is being collapsed, by replacing the corresponding axis name with 1 on the right-hand side. For example, "a b c -> a 1 c" collapses axis b. Second, we must specify the reduction being used. A reduction is a function that takes any number of inputs of a given type, i.e. a list, and combines them into a single output, hence having a signature

Common examples of reductions include

  • sum and prod, for the sum and product of a list of numbers
  • mean, for the arithmetic mean of a list of numbers
  • min and max, for the minimum and maximum of a list of numbers
  • all and any, for checking the truth of all or any of a list of booleans

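Each of the above reductions gives the same result regardless of the order in which entries are combined: sum, product, minimum, maximum, all, and any are associative and commutative, and the mean is a sum divided by the number of entries. The fold is therefore well-defined regardless of how the hardware traverses the axis, up to floating-point rounding for real-valued sums. The choice of reduction also constrains the entry type: mean requires real-valued entries, while all and any require booleans.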

For example, if we wish to sum all of the columns in a tensor, we would use the function

and apply it as so

In catform, we write this as the line

y: int[1,3] = fold["a b -> 1 b", sum](x)

Fold patterns can collapse multiple axes simultaneously, e.g. "a b c -> 1 1 c" folds both a and b—combining entries across two axes is no different from combining across one.

In tensors.py, we represent the fold op as the dataclass:

@dataclass(frozen=True)
class Fold:
    pattern: str
    reduction: str

The pattern identifies which axes collapse; the reduction names the combining function. Folds are parallel across all non-folded axes: in "a b -> 1 b", each column along b is reduced independently. When a tensor has many axes, ... stands in for leading axes the pattern need not name: "... a b -> ... 1 b" folds axis a regardless of how many axes precede it. The corresponding types use ... as well: f32[..., A, B] -> f32[..., 1, B].

// uv run main.py run book/book.cat col_sum
col_sum() -> (y: int32[1, 3]) {
  x : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
  y : int32[1, 3] = fold["a b -> 1 b", sum](x)
}
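For readers who want to check the result by hand, here is a rough pure-Python equivalent of col_sum. Nested lists stand in for tensors, and fold_first_axis is a hypothetical helper that handles only the pattern "a b -> 1 b".

```python
from functools import reduce
import operator


def fold_first_axis(x, combine):
    """Fold pattern "a b -> 1 b": reduce over rows, keeping a size-1 leading axis."""
    columns = zip(*x)   # iterate over the columns along b
    return [[reduce(combine, column) for column in columns]]


x = [[1, 2, 3], [4, 5, 6]]
print(fold_first_axis(x, operator.add))   # [[5, 7, 9]]
print(fold_first_axis(x, max))            # [[4, 5, 6]]
```

Each column is reduced without reference to any other, which is the parallelism described above.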

Tiles

Tile operations involve replicating every entry along an axis, thus expanding that axis from size $1$ to size $n$. Like fold, a tile operation requires a pattern specifier to indicate which axis is being expanded, and we do so via patterns like "1 b -> a b", which indicate that the first axis is being expanded from $1$ to $a$. Unlike fold, no reduction is needed: replication has no degrees of freedom.

For example, if we wish to replicate a tensor across rows, we would use the function

and apply it as so

In catform, we write this as the line

y: int[2,3] = tile["1 b -> a b"](x)

Note that the output type int[2,3] is essential here—the pattern "1 b -> a b" introduces a new axis a, and its size is determined by the type annotation.

The reader may notice that tile and fold are closely related: fold collapses an axis via a reduction, tile expands one via replication. This is no coincidence—tile and fold are dual operations, a relationship we will make precise when we discuss gradients.

In tensors.py, we represent the tile op as the dataclass:

@dataclass(frozen=True)
class Tile:
    pattern: str
    axes: dict[str, int]

Like View, the axes dictionary provides concrete axis sizes to make the output shape unambiguous. Tiles are parallel across all non-tiled axes: in "1 b -> a b", each entry along b is replicated independently. Likewise, "... 1 b -> ... a b" tiles regardless of leading axes, with types f32[..., 1, B] -> f32[..., A, B].

// uv run main.py run book/book.cat replicate
replicate() -> (y: int32[3, 4]) {
  x : int32[1, 4] = literal([[1, 2, 3, 4]])
  y : int32[3, 4] = tile["1 b -> a b"](x)
}
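A rough pure-Python counterpart of replicate, again with nested lists standing in for tensors. tile_first_axis is a hypothetical helper covering only the pattern "1 b -> a b"; the size of the new axis is passed explicitly, playing the role of the type annotation.

```python
def tile_first_axis(x, a):
    """Tile pattern "1 b -> a b": replicate the single row a times."""
    assert len(x) == 1, "the tiled axis must have size 1"
    return [list(x[0]) for _ in range(a)]


x = [[1, 2, 3, 4]]
y = tile_first_axis(x, 3)
print(y)   # [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
```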

Gathers

The four operations above—view, map, fold, and tile—are all structurally determined: their behavior depends entirely on the pattern and the shapes of their inputs. But some computations are data-dependent: the output depends not just on the shape of a tensor but on its values. The canonical example is a table lookup, where the entries of one tensor serve as indices into another. The next two ops handle exactly this.

Gather operations select entries from a tensor using indices from another tensor. Given a data tensor and an index tensor, a gather produces an output whose entries are looked up from the data at the positions specified by the indices.

Like fold and tile, gather uses a pattern specifier to indicate which axis is affected. The marker _ denotes the axis replaced by the index tensor’s shape. For example, the pattern v d -> _ d means: axis v is replaced by the index shape _, while axis d passes through unchanged.

Suppose we have a $4 \times 2$ tensor of data and a $3$-entry tensor of indices

Then the gather selects rows from the data:

The type signature of this operation is

The _ in the output shape corresponds to the index tensor’s size ($3$), while the non-_ axes pass through in their same positions. In catform, this is written as

y: real[3, 2] = gather["v d -> _ d"](data, idx)

Gather implements a lookup table: the indices select entries along an axis of the data.

In tensors.py, we represent the gather op as the dataclass:

@dataclass(frozen=True)
class Gather:
    pattern: str

The pattern identifies which axis is replaced by the index shape. Each output entry requires exactly one lookup, so gathers are parallel across all output coordinates—each index can be followed independently.

// uv run main.py run book/book.cat lookup
lookup() -> (y: int32[2, 3]) {
  data : int32[4, 3] = literal([[10, 20, 30], [40, 50, 60], [70, 80, 90], [100, 110, 120]])
  idx  : int32[2]    = literal([0, 3])
  y    : int32[2, 3] = gather["v d -> _ d"](data, idx)
}
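The lookup can be reproduced in a few lines of plain Python. This is a sketch for the single pattern "v d -> _ d" with a one-axis index tensor; gather_rows is a hypothetical name.

```python
def gather_rows(data, idx):
    """Gather pattern "v d -> _ d": output row i is a copy of data[idx[i]]."""
    return [list(data[i]) for i in idx]


data = [[10, 20, 30], [40, 50, 60], [70, 80, 90], [100, 110, 120]]
idx = [0, 3]
y = gather_rows(data, idx)
print(y)   # [[10, 20, 30], [100, 110, 120]]
```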

Scatters

Scatter operations are the dual of gather: where gather reads from data at specified indices, scatter writes to a target at specified indices. Given a source tensor, an index tensor, and a target template, scatter accumulates the source entries into the target at the positions specified by the indices.

The scatter pattern is the reverse of the gather pattern. Where gather uses v d -> _ d, scatter uses _ d -> v d: the _ axis in the input is mapped back into the named axis in the output.

Using the same index tensor idx, we can scatter a $3 \times 2$ source into a $4 \times 2$ target of zeros:

The type signature of this operation is

In catform, we write this as the line

y: real[4,2] = scatter["_ d -> v d"](source, idx, target)

When multiple indices point to the same position, their contributions are summed—this accumulation behavior is precisely what makes scatter the dual of gather, as we will see when we discuss gradients.

In tensors.py, we represent the scatter op as the dataclass:

@dataclass(frozen=True)
class Scatter:
    pattern: str

The pattern is the reverse of the corresponding gather pattern. Scatters are parallel across the non-_ axes—modern hardware handles index collisions via atomic addition, so even overlapping writes proceed in parallel.

// uv run main.py run book/book.cat accumulate
accumulate() -> (y: int32[4, 2]) {
  vals   : int32[3, 2] = literal([[1, 2], [3, 4], [5, 6]])
  idx    : int32[3]    = literal([0, 2, 0])
  target : int32[4, 2] = literal([[0, 0], [0, 0], [0, 0], [0, 0]])
  y      : int32[4, 2] = scatter["_ d -> v d"](vals, idx, target)
}
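The accumulation in this example can be traced in plain Python. The sketch below handles only the pattern "_ d -> v d" and, like the zeros_like(template) lowering shown later in this chapter, uses the target purely for its shape; scatter_add_rows is a hypothetical name.

```python
def scatter_add_rows(vals, idx, target):
    """Scatter pattern "_ d -> v d": add each source row into row idx[i] of the output."""
    out = [[0] * len(row) for row in target]   # zeros shaped like the target
    for row, i in zip(vals, idx):
        for d, value in enumerate(row):
            out[i][d] += value                 # colliding indices accumulate
    return out


vals = [[1, 2], [3, 4], [5, 6]]
idx = [0, 2, 0]
target = [[0, 0], [0, 0], [0, 0], [0, 0]]
y = scatter_add_rows(vals, idx, target)
print(y)   # [[6, 8], [0, 0], [3, 4], [0, 0]]
```

Rows 0 and 2 of the source both land on output row 0, so their entries are summed; output rows 1 and 3 receive nothing and stay zero.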

Tensor Contractions

The above six operations are sufficient to capture the entire mathematical computation of any modern transformer language model. Despite this, we choose to elevate the derived form—i.e. one that can be expressed as a composite of the primitive operations—of tensor contraction because of its semantic centrality to the mathematics of language models. We first review the mathematics, perhaps shedding new light on it, and then explain how it can be expressed in terms of the six primitive operations.

Let’s review some linear algebra. For the time being, forget about tensors—we will connect back to them soon. For this section, we will deviate from some of our notational conventions to show their correspondence with the conventional mathematics notation. Recall that a vector in is by convention denoted as a column vector, e.g. the following vector

Why was there such an emphasis on these vectors being written as columns rather than rows? What semantic meaning did that distinction carry? Recall that a matrix was depicted as a two dimensional array, e.g. the following matrix

A matrix carried further meaning than merely being a two dimensional array of numbers. In particular, an $n \times m$ matrix, i.e. one with $m$ columns and $n$ rows, corresponded to a linear map of type

$$\mathbb{R}^m \to \mathbb{R}^n$$

A particularly simple type of matrix is the one with a single row, i.e. a $1 \times m$ matrix, which we also call an $m$ row vector or covector, which is a function of type

$$\mathbb{R}^m \to \mathbb{R}$$

This means that if we have an $m$ covector $f$ and an $m$ vector $v$, we have a function of type $\mathbb{R}^m \to \mathbb{R}$ along with an element of its domain $\mathbb{R}^m$, which means we can apply the function to the vector:

$$f(v)$$
In practice, we obtain a covector by transposing a column vector $w$, and denote the result $w^\top$. With this notation, we drop the parentheses and write the application as:

$$w^\top v$$

We compute this value in terms of the entries of the covector and vector using the dot product:

$$w^\top v = \sum_{i=1}^{m} w_i v_i$$
Multiply entry-by-entry, then sum—this is exactly a map followed by a fold. We can write this as a catform program using our six primitive tensor operations. If we view both vector and covector as tensors:

dot_product(w: real[m], v: real[m]) -> (val: real[]) {
    wv  : real[m] = map[mul](w, v)
    val_: real[1] = fold["m -> 1", sum](wv)
    val : real[]  = view["1 -> "](val_)
}

The dot product recurs in the context of applying an $n \times m$ matrix function to an $m$ vector $v$. It in fact can be used to give an interpretation of what an $n \times m$ matrix does, by interpreting it as $n$ parallel covectors, each applied independently to the argument to arrive at an $n$ vector:

To write this in catform, we proceed as above, except that we must first tile the vector to be the same shape as the matrix:

matrix_apply(w: real[n,m], v: real[m]) -> (vect: real[n]) {
    vv   : real[n,m] = tile["m -> n m"](v)
    wvv  : real[n,m] = map[mul](w, vv)
    vect_: real[n,1] = fold["n m -> n 1", sum](wvv)
    vect : real[n]   = view["n 1 -> n"](vect_)
}

We can extend this idea to the matrix multiplication of an $n \times m$ matrix by an $m \times k$ matrix, by viewing, as before, the former as $n$ parallel covectors, and the latter as $k$ parallel vectors. This allows us to think of matrix multiplication as $n \times k$ worth of independent dot products.

To write this in catform, we proceed as above, except that we must now tile both of the matrices to be the same shape:

matrix_multiply(w: real[n,m], v: real[m,k]) -> (mat: real[n,k]) {
    ww   : real[n,m,k] = tile["n m -> n m k"](w)
    vv   : real[n,m,k] = tile["m k -> n m k"](v)
    wwvv : real[n,m,k] = map[mul](ww, vv)
    mat_ : real[n,1,k] = fold["n m k -> n 1 k", sum](wwvv)
    mat  : real[n,k]   = view["n 1 k -> n k"](mat_)
}
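To make the tile, map, fold sequence concrete, here is a deliberately naive Python transcription of matrix_multiply that materializes every intermediate. Nested lists stand in for tensors, and the function name is hypothetical.

```python
def matrix_multiply_tiled(w, v):
    """Matrix product written as tile, tile, map[mul], fold[sum], mirroring the catform."""
    n, m, k = len(w), len(v), len(v[0])
    # tile "n m -> n m k": copy each entry of w along a new trailing axis
    ww = [[[w[i][j] for _ in range(k)] for j in range(m)] for i in range(n)]
    # tile "m k -> n m k": copy the whole of v along a new leading axis
    vv = [[[v[j][l] for l in range(k)] for j in range(m)] for _ in range(n)]
    # map[mul]: entrywise product of the two equally shaped tensors
    wwvv = [[[ww[i][j][l] * vv[i][j][l] for l in range(k)]
             for j in range(m)] for i in range(n)]
    # fold "n m k -> n 1 k" with sum, followed by the view that drops the unit axis
    return [[sum(wwvv[i][j][l] for j in range(m)) for l in range(k)]
            for i in range(n)]


a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(matrix_multiply_tiled(a, b))   # [[58, 64], [139, 154]]
```

The intermediate tensors hold n·m·k entries each, which is why a practical executor would not materialize them; the sketch is only meant to show that the decomposition computes the right values.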

Higher dimensional variants of such computations become harder to visualize, but the operation remains well defined. For instance, we may have an $o \times n$ rectangular grid of parallel covectors and a $k \times j \times i$ rectangular prism grid of parallel vectors, and calculate $o \cdot n \cdot k \cdot j \cdot i$ dot products. Although the visuals become more challenging to imagine, the catform does not: this example is nearly identical to the above:

tensor_contraction(w: real[o,n,m], v: real[m,k,j,i]) -> (tens: real[o,n,k,j,i]) {
    ww    : real[o,n,m,k,j,i] = tile["o n m -> o n m k j i"](w)
    vv    : real[o,n,m,k,j,i] = tile["m k j i -> o n m k j i"](v)
    wwvv  : real[o,n,m,k,j,i] = map[mul](ww, vv)
    tens_ : real[o,n,1,k,j,i] = fold["o n m k j i -> o n 1 k j i", sum](wwvv)
    tens  : real[o,n,k,j,i]   = view["o n 1 k j i -> o n k j i"](tens_)
}

Just as we easily expanded the number of axes in the set of dot product calculations, we could also have expanded the number of axes that we summed across: nothing about the covector and vector pair forces them to have a single axis. We call the summed-over axes contracting axes (often called contracting dimensions elsewhere). The degenerate case has zero contracting axes, so no shared indices are summed over, and the result is the outer product: every entry of one tensor paired with every entry of the other, as in "n, d -> n d". This general operation is so common (it comprises the vast majority of the arithmetic operations of any language model) and carries so much mathematical meaning that we elevate it to a derived form with native status alongside the six primitive operations above. As with many other operations, we specify such contractions with a pattern of the form "as bs, bs cs -> as cs", where as, bs, and cs can each be any list of axes. In catform this looks like:

z: d[as, cs] = contract["as bs, bs cs -> as cs"](x, y)

In tensors.py, we represent the contract op as the dataclass:

@dataclass(frozen=True)
class Contract:
    pattern: str

The pattern encodes both inputs and output in a single string: the contracting axes are those that appear in both inputs but not in the output. Contractions are parallel across all non-contracting axes: in "n m, m k -> n k", the $n \cdot k$ output entries are independent dot products, each computed in parallel. Likewise, "... n d, o d -> ... n o" contracts over d regardless of leading axes. When ... appears on both inputs, as in "... n h e, ... s h e -> ... h n s", it matches the same set of leading axes in each.

The four contraction specializations demonstrated above—dot product, matrix-vector, matrix multiply, and outer product—each reduce to a single contract op:

// uv run main.py run book/book.cat dot
dot() -> (val: int32[]) {
  w   : int32[4] = literal([1, 2, 3, 4])
  v   : int32[4] = literal([5, 6, 7, 8])
  val : int32[]  = contract["m, m ->"](w, v)
}

// uv run main.py run book/book.cat mat_vec
mat_vec() -> (y: int32[2]) {
  w : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
  v : int32[3]    = literal([7, 8, 9])
  y : int32[2]    = contract["n m, m -> n"](w, v)
}

// uv run main.py run book/book.cat mat_mul
mat_mul() -> (y: int32[2, 2]) {
  a : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
  b : int32[3, 2] = literal([[7, 8], [9, 10], [11, 12]])
  y : int32[2, 2] = contract["n m, m k -> n k"](a, b)
}

// uv run main.py run book/book.cat outer
outer() -> (y: int32[3, 4]) {
  u : int32[3]    = literal([1, 2, 3])
  v : int32[4]    = literal([4, 5, 6, 7])
  y : int32[3, 4] = contract["n, m -> n m"](u, v)
}
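The four examples are small enough to verify by hand. The following plain-Python sketch computes the same quantities directly; the function names mirror the catform examples but are otherwise hypothetical, and nested lists stand in for tensors.

```python
def dot(w, v):
    """Contract "m, m ->": multiply entrywise, then sum."""
    return sum(wi * vi for wi, vi in zip(w, v))


def mat_vec(w, v):
    """Contract "n m, m -> n": one dot product per row of w."""
    return [dot(row, v) for row in w]


def mat_mul(a, b):
    """Contract "n m, m k -> n k": one dot product per (row, column) pair."""
    columns = list(zip(*b))
    return [[dot(row, column) for column in columns] for row in a]


def outer(u, v):
    """Contract "n, m -> n m": no contracting axes, so nothing is summed."""
    return [[ui * vj for vj in v] for ui in u]


print(dot([1, 2, 3, 4], [5, 6, 7, 8]))                               # 70
print(mat_vec([[1, 2, 3], [4, 5, 6]], [7, 8, 9]))                    # [50, 122]
print(mat_mul([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]]))  # [[58, 64], [139, 154]]
print(outer([1, 2, 3], [4, 5, 6, 7]))
# [[4, 5, 6, 7], [8, 10, 12, 14], [12, 15, 18, 21]]
```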

Together with the Op wrapper and the Function and Module structures for composition, these dataclasses comprise tensors.py—the complete catform abstract syntax tree (AST).

Computation Graphs

Every catform function we have written—dot_product, matrix_apply, and the op examples like transpose and col_sum—has the same shape: a list of named operations, each consuming some outputs from earlier ones. This structure has a natural visual representation called a string diagram. Each operation is a box; each typed tensor flowing between operations is a wire. Reading top to bottom traces execution order. The dotted boundary marks a function—nesting one inside another is composition.

function(t: T, x: X) -> (z: Z) {
    (u,v): (U,V) = f(t)
    w    : W     = g(x)
    y    : Y     = h(t,u,v,w)
    z    : Z     = i(y)
}

Data enters a catform function in two ways. Arguments are declared in the function signature and supplied by the caller—in the diagram, t and x pierce the top boundary. Introductions—values created by literal(...) or random(...)—originate inside the function with no incoming wire.

Two structural features are free, requiring no operation. A wire can fork: t feeds both f and h—the same tensor is available to multiple consumers without copying. And a box can emit multiple outputs: f produces both u and v. Neither forking nor multi-return requires an op; they are structural features of the language.

Independence is visible. Boxes f and g share no wires; they can execute in either order, or simultaneously. Parallelism is not annotated—it is read off the topology.

Every edge only flows downward—the diagram is a directed acyclic graph. Each wire name appears as an output of exactly one box, a property called static single assignment. The correspondence between a catform function and its string diagram is exact: every line in the code is a box, every named intermediate is a wire, and the top-to-bottom reading order matches the textual order up to reorderings that respect data dependencies.
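These properties can be checked mechanically. The sketch below transcribes the example function into a dictionary of inputs and outputs per box (an ad hoc encoding, not Pianola’s AST), verifies static single assignment, and uses Python’s standard graphlib module to find which boxes are ready to run at each stage.

```python
from graphlib import TopologicalSorter

# Each box maps to (input wires, output wires), transcribed from the example.
ops = {
    "f": (["t"], ["u", "v"]),
    "g": (["x"], ["w"]),
    "h": (["t", "u", "v", "w"], ["y"]),
    "i": (["y"], ["z"]),
}

# Static single assignment: every wire has exactly one producing box.
producer = {}
for name, (_, outputs) in ops.items():
    for wire in outputs:
        assert wire not in producer, f"{wire} assigned twice"
        producer[wire] = name

# A box depends on the producers of its inputs; function arguments have none.
deps = {
    name: {producer[wire] for wire in inputs if wire in producer}
    for name, (inputs, _) in ops.items()
}

sorter = TopologicalSorter(deps)
sorter.prepare()                      # raises CycleError if the graph is not acyclic
first_wave = sorted(sorter.get_ready())
sorter.done(*first_wave)
second_wave = sorted(sorter.get_ready())
print(first_wave)    # ['f', 'g']
print(second_wave)   # ['h']
```

The first wave contains both f and g, which is the independence read off the diagram: neither consumes a wire produced by the other.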

Composition and Flattening

A single catform function is a straight-line program. But a .cat file contains a module—a collection of named functions that can reference each other. Two composition constructs allow structuring a computation across multiple functions:

  • call[f](x_1, ..., x_n) invokes the function f defined elsewhere in the module, passing arguments and receiving its returns.
  • loop[f, count](args) is a macro that expands into count sequential invocations of f. For example, the loop
y : T = loop[f, 3](y, s, w)

expands to

y_1 : T = call[f](y,   s, w.0)
y_2 : T = call[f](y_1, s, w.1)
y_3 : T = call[f](y_2, s, w.2)

The argument whose name matches the output (y) is threaded—SSA renaming chains each iteration’s output to the next iteration’s input (y → y_1 → y_2 → y_3). Arguments of dict type (*) are indexed per iteration via .0, .1, .2. All other arguments are static. Subsequent references to y are rewritten to y_3.
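The expansion rule can be sketched as a small string-rewriting function. This is an illustration of the rule as described here, not Pianola’s implementation; expand_loop and its parameters are hypothetical, and the caller states explicitly which arguments are of dict type.

```python
def expand_loop(output, dtype, func, count, args, indexed=()):
    """Unroll `output : dtype = loop[func, count](args)` into call lines.

    `indexed` names the dict-typed arguments, which receive a `.i` suffix.
    Returns the lines and the name that later references to `output` become.
    """
    lines, current = [], output
    for i in range(count):
        new_name = f"{output}_{i + 1}"
        rendered = []
        for arg in args:
            if arg == output:
                rendered.append(current)        # threaded argument
            elif arg in indexed:
                rendered.append(f"{arg}.{i}")   # per-iteration index
            else:
                rendered.append(arg)            # static argument
        lines.append(f"{new_name} : {dtype} = call[{func}]({', '.join(rendered)})")
        current = new_name
    return lines, current


lines, final = expand_loop("y", "T", "f", 3, ["y", "s", "w"], indexed={"w"})
for line in lines:
    print(line)
# y_1 : T = call[f](y, s, w.0)
# y_2 : T = call[f](y_1, s, w.1)
# y_3 : T = call[f](y_2, s, w.2)
print(final)   # y_3
```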

Neither call nor loop is an op type—they are macros, not computations. They organize the program for human comprehension: call factors repeated patterns into named functions; loop expresses iterated structure without copy-pasting. Both disappear under flattening: every call is inlined, every loop is unrolled, producing a single function called main—a straight-line sequence of the seven op types.

Two details remain before we put these operations to work: the concrete numeric formats that replace real and int, and a machine to run them.

Numerics

Everything above used real and int as mathematical types. A machine has a finite number of bits, so every numeric value it stores must fit in a fixed-width binary word. The consequences for integers are straightforward: an $n$-bit two’s complement integer represents values in $[-2^{n-1}, 2^{n-1} - 1]$. Overflow wraps; there is no approximation within range. In catform, int32 denotes a $32$-bit signed integer.
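Python integers are unbounded, so wrapping has to be simulated. The following sketch reduces an arbitrary integer to its fixed-width two’s complement value; wrap_int is a hypothetical helper.

```python
def wrap_int(value, bits=32):
    """Reduce an integer to its two's complement value in the given width."""
    modulus = 1 << bits
    half = 1 << (bits - 1)
    return (value + half) % modulus - half


print(wrap_int(2**31 - 1))    # 2147483647, the largest int32
print(wrap_int(2**31))        # -2147483648, one past the top wraps to the bottom
print(wrap_int(-2**31 - 1))   # 2147483647, one below the bottom wraps to the top
```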

Real numbers are infinitely precise, so approximation is unavoidable. To represent them, we must choose a numeric representation. The standard format is given by the IEEE 754 floating-point specification. We can derive this representation as follows. Given a nonzero real number $x$, we can first extract its sign $s$, which is $0$ if $x$ is non-negative and $1$ if negative, and rewrite $x$ in terms of a non-negative real $r$:

$$x = (-1)^s \cdot r$$

Then, let $e$ denote the unique integer such that

$$2^e \le r < 2^{e+1}$$

This lets us further deconstruct $r$ as

$$r = 2^e \cdot \frac{r}{2^e}$$

Note that by the choice of $e$, $r / 2^e$ is always in the interval $[1, 2)$. We make the substitution $r / 2^e = 1 + m$, where $m$ is a non-negative real number in the interval $[0, 1)$ and hence can be expressed in binary as $0.b_1 b_2 b_3 \ldots$. This gives us the unique decomposition used by floating point numbers:

$$x = (-1)^s \cdot 2^e \cdot (1 + m)$$
The triple (sign, exponent, mantissa) uniquely characterizes a floating-point value. A floating-point type is determined by choosing how many bits to allocate to each component.

A given floating point representation has a fixed number $N$ of bits, always a power of $2$. It always allocates $1$ bit to the sign, and then assigns $E$ bits for the exponent and $M$ bits for the mantissa. Assigning more bits to the exponent allows the representation to capture a wider range of numbers, much larger or much smaller, while assigning more to the mantissa gives higher precision: more binary digits before rounding occurs. Many numerical formats, specified by $N$, $E$ and $M$, have been used in numerical computing. A given language model might use distinct formats in different parts of its computation. In the next chapter, we will see such choices in a full language model computation.
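As a hedged illustration, the sketch below computes the (sign, exponent, mantissa) triple of a Python float and then reads the corresponding bit fields out of its IEEE 754 binary32 encoding, which allocates 1, 8, and 23 bits to the three components and stores the exponent with a bias of 127. The function names are hypothetical, and zero, subnormals, infinities, and NaN are left out of scope.

```python
import math
import struct


def decompose(x):
    """Return (sign, exponent, mantissa) with x == (-1)**sign * 2**exponent * (1 + mantissa)."""
    if x == 0 or math.isinf(x) or math.isnan(x):
        raise ValueError("zero, infinities, and NaN have no such decomposition")
    sign = 1 if x < 0 else 0
    # frexp gives abs(x) == fraction * 2**exp with 0.5 <= fraction < 1
    fraction, exp = math.frexp(abs(x))
    return sign, exp - 1, fraction * 2 - 1


def f32_fields(x):
    """Split the binary32 encoding of x into raw (sign, exponent, mantissa) bit fields."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    return bits >> 31, (bits >> 23) & 0xFF, bits & 0x7FFFFF


print(decompose(-6.5))    # (1, 2, 0.625)
print(f32_fields(-6.5))   # (1, 129, 5242880)
```

Here 6.5 equals 2² × 1.625, so the exponent is 2 and the mantissa is 0.625. In the stored fields, 129 is 2 plus the bias of 127, and 5242880 is 0.625 × 2²³.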

Catform rewrites—flattening, inlining, reordering—are mathematically equivalent restructurings that never alter the entry-level arithmetic or dtypes.

Tensor Programming

We now have all the ingredients: tensors and their types, six primitive operation types plus contraction, introductions for constants and random data, and concrete numeric dtypes. Before we turn to language models, we put all of these components together in a single computation, with no machine learning in sight. Inspired by the excellent 3Blue1Brown lecture, we compute the $n$-volume of the unit $n$-ball, defined as the set of points in $\mathbb{R}^n$ with distance at most $1$ from the origin. Its volume is a classical quantity that at first grows and then shrinks as $n$ increases; the unit ball becomes vanishingly small in high dimensions:

Name        Volume
segment     2.00
disk        3.14
ball        4.19
4-ball      4.93
5-ball      5.26
10-ball     2.55
100-ball    2.37 * 10^{-40}

The volume peaks at $n = 5$ and then declines. The closed forms for even and odd dimensions are

$$V_{2k} = \frac{\pi^k}{k!}, \qquad V_{2k+1} = \frac{2 \, k! \, (4\pi)^k}{(2k+1)!}$$

Since both formulae have factorials in the denominator that eventually dominate the powers of $\pi$, the volume converges to zero. We can estimate it empirically via Monte Carlo sampling. The function takes the dimension $n$ and sample count $S$ as parameters, then generates random data inside, showcasing introductions. For each of $S$ random points, we sample $n$ coordinates uniformly from the interval $[-1, 1]$.

A point lies inside the ball when

$$x_1^2 + x_2^2 + \cdots + x_n^2 \le 1$$
The volume can be calculated entirely via catform operations:

// uv run main.py run book/book.cat ball_vol 5 10000
ball_vol(n: int32[], S: int32[]) -> (vol: f32[]) {
    x     : f32[n, S] = random(-1.0, 1.0)
    
    // compute the square norm for every sample
    sq    : f32[n, S] = map[mul](x, x)
    norm  : f32[1, S] = fold["n S -> 1 S", sum](sq)
    
    // compute the ratio of points in the ball
    one   : f32[]     = literal(1.0)
    one_t : f32[1, S] = tile["-> 1 S"](one)
    mask  : f32[1, S] = map[le](norm, one_t)
    inside: f32[1, S] = map[f32](mask)
    frac  : f32[1, 1] = fold["1 S -> 1 1", mean](inside)
    ratio : f32[]     = view["1 1 ->"](frac)
    
    // scale the ratio by the volume of the hypercube
    two   : f32[]     = literal(2.0)
    n_f   : f32[]     = map[f32](n)
    side  : f32[]     = map[pow](two, n_f)
    vol   : f32[]     = map[mul](ratio, side)
}

Line by line: random introduces an $n \times S$ tensor of uniform samples. map[mul] squares every coordinate. fold[sum] collapses the $n$ axis, yielding a squared norm for each sample. literal and tile build a matching tensor of ones, and map[le] tests whether each squared norm is at most $1$, i.e. whether the point is inside the ball, producing a boolean. map[f32] casts to float so that fold[mean] can average across all $S$ samples, and view strips the unit axes to a scalar. The final four lines compute $2^n$, the volume of the enclosing hypercube, and multiply by it to convert the fraction into a volume.

With $n = 2$ and large $S$, ball_vol returns approximately 3.14 ($\pi$). The operations that will specify a language model are not specific to language models; they are general-purpose tensor operations. A forward pass through a transformer and a Monte Carlo integration are the same kind of thing: a straight-line composition of maps and folds.
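For comparison outside catform, here is a plain-Python sketch of the same estimate next to an exact value. The exact function uses the Gamma-function form of the ball volume, π^(n/2) / Γ(n/2 + 1), which agrees with the even and odd closed forms; the function names and the fixed seed are choices made for this illustration.

```python
import math
import random


def ball_volume_exact(n):
    """Volume of the unit n-ball via pi^(n/2) / Gamma(n/2 + 1)."""
    return math.pi ** (n / 2) / math.gamma(n / 2 + 1)


def ball_volume_monte_carlo(n, samples, seed=0):
    """Fraction of uniform points of [-1, 1]^n inside the ball, times the cube volume 2^n."""
    rng = random.Random(seed)
    inside = 0
    for _ in range(samples):
        sq_norm = sum(rng.uniform(-1.0, 1.0) ** 2 for _ in range(n))
        if sq_norm <= 1.0:
            inside += 1
    return (inside / samples) * 2.0 ** n


print(round(ball_volume_exact(5), 2))    # 5.26
print(round(ball_volume_exact(10), 2))   # 2.55
estimate = ball_volume_monte_carlo(2, 20000)
print(abs(estimate - math.pi) < 0.1)     # True
```

In high dimensions almost no samples fall inside the ball, so the Monte Carlo estimate becomes unreliable well before the exact value does.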

Naive Execution

Our catform tensor operations and their use in .cat modules are specifications—they say what to compute, not how. To actually run the ops, we need an executor that lowers each op—translates it from an abstract description to a concrete function call in a tensor computing library. The two dominant frameworks for tensor computation at the time of writing are PyTorch and JAX. Pianola’s simplest executor lives in lower/simple.py and dispatches each op to a single call in either one.

The pattern-string specifiers used by view, fold, tile, and contract were directly inspired by einops, a library that provides framework-agnostic tensor operations with the same pattern-string interface. We use einops as the interop layer: since catform’s specifiers and einops’s patterns are the same language, the lowering is almost trivial—the pattern passes straight through. The complete dispatch from _resolve_op in simple.py maps each op type to its concrete implementation:

case View(pattern=p, axes=ax):     return einops.rearrange(x, p, **ax)
case Map(function=f):              return MAP_DISPATCH[f](backend, args)
case Fold(pattern=p, reduction=r): return einops.reduce(x, p, r)
case Tile(pattern=p, axes=ax):     return einops.repeat(x, p, **ax)
case Gather(pattern=p):            return data[indices]
case Scatter(pattern=p):           return scatter(choice, values, indices, template, axis)
case Contract(pattern=p):          return einops.einsum(a, b, p)

view, fold, tile, and contract pass their pattern strings directly to einops. gather lowers to the framework’s native advanced indexing, which both PyTorch and JAX handle identically. scatter is the one op that requires its own per-backend dispatch function, because the accumulation primitives differ: PyTorch provides index_add_ (in-place), while JAX provides .at[].add() (functional). Neither einops nor operator overloading can bridge this gap, so scatter gets a standalone function that cases on the backend:

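def scatter(framework: Framework, values, indices, template, axis):
    match framework:
        case Framework.torch:
            # In-place accumulation into a fresh zero tensor shaped like the template.
            return torch.zeros_like(template).index_add_(axis, indices, values)
        case Framework.jax:
            # Functional update: select `indices` along `axis`, keep earlier axes whole.
            selector = (slice(None),) * axis + (indices,)
            return jnp.zeros_like(template).at[selector].add(values)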

map ops split into three groups. Binary arithmetic and comparison operators (add, mul, sub, div, ge) use Python’s built-in operators, which both PyTorch and JAX overload identically. where is a ternary operator: it takes a boolean mask and two value tensors, returning entries from one or the other. This is how tensor programs handle conditional selection: every entry is computed for both alternatives, and where picks between them elementwise, which avoids control flow. Finally, unary real-valued functions (exp, cos, sin, silu, rsqrt) are the places where the framework’s namespace is needed. The list of these functions is extensible without touching the core op primitives: adding a new scalar function (say, tanh) takes one new slot in a dispatch table. The framework-specific slots are collected in a Backend dataclass:

@dataclass(frozen=True)
class Backend:
    exp: Callable
    cos: Callable
    sin: Callable
    silu: Callable
    rsqrt: Callable
    where: Callable
    ...

A PyTorch backend fills these with torch.exp, torch.cos, etc, while a JAX backend fills them with jnp.exp, jnp.cos, etc. The interpreter does not know or care which one it received—the .cat specification is portable across any backend that fills the same slots. The entire interpreter is four lines:

def run(ops: list[Op], env: dict[str, Any], backend: Backend) -> None:
    for op in ops:
        args = tuple(env[name] for name in op.inputs)
        env[op.output] = _resolve_op(op, backend)(*args)

Walk the ops in order, look up each input by name, call the resolved function, store the result. No graph construction, no tracing, no runtime code generation—just a loop over operations. The same .cat file, the same run() function, two different Backend values: one for PyTorch, one for JAX. The results are identical up to rounding noise.

This is a line-by-line interpreter—each step materializes its result, making every intermediate inspectable.
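The shape of this loop can be tried without either framework. In the sketch below, ToyOp is a hypothetical stand-in for Pianola’s Op in which each op carries its already-resolved function, so no backend or _resolve_op is involved, and plain Python lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyOp:
    inputs: tuple[str, ...]
    output: str
    fn: Callable[..., Any]   # plays the role of the resolved backend function


def run_toy(ops: list[ToyOp], env: dict[str, Any]) -> None:
    for op in ops:
        args = tuple(env[name] for name in op.inputs)
        env[op.output] = op.fn(*args)


env = {"x": [1, 2, 3]}
ops = [
    ToyOp(("x", "x"), "sq", lambda a, b: [p * q for p, q in zip(a, b)]),  # like map[mul]
    ToyOp(("sq",), "total", lambda a: sum(a)),                            # like fold[sum]
]
run_toy(ops, env)
print(env)   # {'x': [1, 2, 3], 'sq': [1, 4, 9], 'total': 14}
```

After the run, env still holds sq alongside the final total, which is what makes every intermediate inspectable.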

We now have the mathematical language to describe, and the computational translation to execute, the full computation of a language model.