Tensors
This chapter is a mathematical interlude on computations of tensors. We start by defining tensors, then proceed to defining a set of primitive tensor operations, and conclude with a discussion on how to formally think about and represent computations. We will express every concept in both mathematical notation and executable code, using our domain-specific tensor programming language catform.
Tensor programming is an abstraction for programming massively parallel processors—GPUs, TPUs, and similar hardware that execute thousands of operations simultaneously. The key insight is that tensors, by organizing data into axes, make the available parallelism explicit. There are vastly more functions from tensors to tensors than there are tensor operations. What earns the name is structured parallelism: the axes alone determine which input entries each output entry may depend on. This constrains the space enough that we can describe every computation we need with just six primitive operation types—plus one derived form—all introduced in this chapter. Tensors are therefore a natural mathematical organizing principle for such hardware, and as we will see, they cohere directly with the objects of linear algebra that underpin the mathematics of language models.
Tensors
From the standpoint of a computer, a data object is typically a contiguous array of physical memory. We call the value in a given memory cell an entry. An $n$-axis tensor imbues the set of entries with the structure of an $n$-dimensional array by keeping track of metadata to tell our programs how the entries are logically arranged in memory. For instance, suppose we have $8$ contiguous entries
$$000,\ 001,\ 010,\ 011,\ 100,\ 101,\ 110,\ 111$$
We could treat these entries as comprising, for instance, a $2$-axis tensor
$$\begin{bmatrix} 000 & 001 & 010 & 011 \\ 100 & 101 & 110 & 111 \end{bmatrix}$$
by recording two pieces of metadata:
- the shape tuple of the array
- the stride tuple of increments used to traverse each axis
The shape of the above array is $(2,4)$, while the stride is $(4,1)$: incrementing the first axis (moving down a column) skips $4$ addresses, e.g. $000$ is in position $0$ while the $100$ directly below it is in position $4$, whereas incrementing the second axis (moving along a row) skips $1$.
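The stride bookkeeping can be sketched in a few lines of plain Python. This is only an illustration, assuming the eight entries are stored row by row with shape (2, 4) and stride (4, 1), as in the int[2,4] literal used in this chapter; the helper names `offset` and `entry` are hypothetical and not part of tensors.py.

```python
# Flat memory holding the eight entries (kept as strings to preserve leading zeros).
memory = ["000", "001", "010", "011", "100", "101", "110", "111"]

shape = (2, 4)   # 2 rows, 4 columns
stride = (4, 1)  # step 4 cells per row increment, 1 cell per column increment


def offset(index, stride):
    """Map a logical index tuple to a position in flat memory."""
    return sum(i * s for i, s in zip(index, stride))


def entry(index):
    """Look up the entry at a logical index using only the stride metadata."""
    return memory[offset(index, stride)]


print(entry((0, 0)))  # 000, stored at position 0
print(entry((1, 0)))  # 100, stored at position 4
print(entry((1, 3)))  # 111, stored at position 7
```

The memory list itself never changes; only the metadata decides which cell a logical index refers to.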
We will denote the mathematical type of a tensor by the datatype of its entries followed by its shape in square brackets; e.g. for our above tensor we denote its type as
$$\mathrm{int}[2,4]$$
For now, we will use the generic mathematical types $\mathrm{int}$ and $\mathrm{real}$; later in this chapter we will specify the concrete machine-level numeric types that replace them.
In this chapter’s companion module tensors.py, we represent the tensor type as follows.
from dataclasses import dataclass

type Dim = int | str

@dataclass(frozen=True)
class Tensor:
    dtype: str
    shape: tuple[Dim, ...]
We allow the dimension type Dim to either be an explicit numeric int or a variable str. Allowing for variable dimensions facilitates representing more generic computations, e.g. those that make sense independent of the shape integers. At runtime all string dimensions get resolved to non-negative integers.
In catform programs, we introduce tensors such as the one above using the literal syntax:
x: int[2,4] = literal([[000, 001, 010, 011],[100, 101, 110, 111]])
All data objects in catform are tensors. This includes plain numbers, or scalars, which we represent as $0$-axis tensors, i.e. those of shape $()$. For instance, a single value $3.14$ has type $\mathrm{real}[]$.
x: real[] = literal(3.14)
In tensors.py, the literal is represented as:
@dataclass(frozen=True)
class Literal:
    value: Value
    dtype: str
Catform also allows introducing random data. The shape comes from the type annotation; the arguments specify the range:
xs: f32[5, 1000] = random(-1.0, 1.0)
In tensors.py, random data is represented as:
@dataclass(frozen=True)
class Random:
    lower: float
    upper: float
    dtype: str
    dims: tuple[Dim, ...]
Tensor Operations
We now introduce catform’s six primitive operations for transforming tensors.
An individual catform operation is given by its operation type and one or more specifiers, denoted in square brackets after the operation type, as follows:
$$\mathrm{op}[\mathrm{spec}_1, \ldots, \mathrm{spec}_m]$$
Operations take one or more arguments and output a single value. In catform, we assign the output of an operation as:
y: d[S] = op[spec_1,...,spec_m](x_1, ..., x_n)
In this assignment, y is the output, d[S] its type (with datatype d and shape S), op one of the operation types, spec_1,...,spec_m its specifiers (in practice, there are at most 2) and x_1, ..., x_n its arguments. In mathematical notation, we write this function type as
$$\mathrm{op}[\mathrm{spec}_1, \ldots, \mathrm{spec}_m] : (d_1[S_1], \ldots, d_n[S_n]) \to d[S]$$
using tuple notation in the case of multiple arguments.
Catform allows simple function abstraction. The function declaration syntax is given by its name, its typed arguments in parentheses, an arrow, its typed return values in parentheses, and a body of value assignment lines in curly brackets:
func(x_1: d_1[S_1], ..., x_n: d_n[S_n]) -> (y_1: d'_1[S'_1], ..., y_m: d'_m[S'_m]) {
...
}
The function body exclusively consists of value assignment lines—they are the sole way to represent computations in catform. The function’s outputs are whichever assigned variables in the body match the return names declared after the arrow. If there is no such match, the program doesn’t typecheck.
We will now catalogue all six primitive operation types, followed by the derived form. The Pianola codebase defines all of these in the module tensors.py. Each op section ends with a runnable example from book/book.cat; the comment above the function name is the command to run it.
Views
View operations transform a tensor by exclusively changing the shape and stride metadata without making any changes to the actual underlying data.
Recall the example tensor in the prior section. If we applied the same permutation to its shape and stride tuples, e.g. from $(2,4)$ to $(4,2)$ and $(4,1)$ to $(1,4)$ respectively, we would simply swap the role of rows and columns, thus giving us its transpose
$$\begin{bmatrix} 000 & 100 \\ 001 & 101 \\ 010 & 110 \\ 011 & 111 \end{bmatrix}$$
Such a permutation defines a mathematical function of type
$$\mathrm{int}[2,4] \to \mathrm{int}[4,2]$$
While we could give this function a dedicated name such as transpose, we use the op type view to express this in a far more generalizable manner. All we need is a specifier which denotes exactly how the axes get rearranged. There are many ways one could do this; we choose the convention of a pattern string, as popularized by einops. These are best explained through example. In this transpose case, we write the pattern as "a b -> b a". This gives us the mathematical notation for this function:
$$\mathrm{view}[\texttt{"a b -> b a"}] : \mathrm{int}[2,4] \to \mathrm{int}[4,2]$$
If we name the input x and output y, we can express the application of this function to inputs in the assignment line
y: int[4,2] = view["a b -> b a"](x)
Another variant of the view operation comes from merging and splitting axes. For instance, suppose we wished to reconceptualize our tensor as a $3$-axis array cube, with $2$-axis layers
$$\begin{bmatrix} 000 & 001 \\ 010 & 011 \end{bmatrix}, \quad \begin{bmatrix} 100 & 101 \\ 110 & 111 \end{bmatrix}$$
The resultant tensor would have shape $(2,2,2)$ and stride $(4,2,1)$. To express this computation, we use the pattern "a (b c) -> a b c" to convey that the original second axis (b c) had the potential to be factored into two new axes b and c. In this case, the type signature is critical: it tells us exactly how to numerically factor the axis.
In catform, we express the application of this function as the line
y: int[2,2,2] = view["a (b c) -> a b c"](x)
The interested reader can compute for themselves what happens to the strides in such merges and splits.
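For readers who would like to check their stride arithmetic, here is a minimal sketch of how a permutation and an axis split act on the (shape, stride) metadata. The helper names are hypothetical, and the starting layout is assumed to be row-major.

```python
def row_major_strides(shape):
    """Strides of a freshly allocated row-major tensor with the given shape."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, stride, order):
    """View by permuting axes: apply the same permutation to shape and stride."""
    return tuple(shape[i] for i in order), tuple(stride[i] for i in order)


def split_axis(shape, stride, axis, outer, inner):
    """View by factoring one axis of size outer * inner into two axes."""
    assert shape[axis] == outer * inner, "factors must multiply to the axis size"
    new_shape = shape[:axis] + (outer, inner) + shape[axis + 1:]
    # The inner axis keeps the old step; the outer axis jumps over a whole inner run.
    new_stride = stride[:axis] + (stride[axis] * inner, stride[axis]) + stride[axis + 1:]
    return new_shape, new_stride


shape, stride = (2, 4), row_major_strides((2, 4))
print(stride)                              # (4, 1)
print(permute(shape, stride, (1, 0)))      # ((4, 2), (1, 4))
print(split_axis(shape, stride, 1, 2, 2))  # ((2, 2, 2), (4, 2, 1))
```

Neither helper touches the entries themselves, which is the sense in which a view is free.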
Because views merely reinterpret the logical arrangement of entries without altering the number of entries, and because $1$ is the multiplicative unit, view operations can freely add or remove axes of size $1$ from the shape.
The pattern strings we have seen so far name every axis. But real tensors have many axes, and most operations only touch the last few—the axes before them pass through unchanged. Rather than naming each of these pass-through axes, ... matches any number of leading axes that appear identically on both sides of the arrow. For example, "a b -> b a" transposes a 2-axis tensor, while "... a b -> ... b a" transposes the last two axes of a tensor with any number of axes. The axes matched by ... are exactly those where independent copies of the operation run in parallel. Catform also uses ... in type annotations: f32[..., N, M] denotes a tensor with any number of leading axes followed by N and M.
In tensors.py, we represent the view op as the dataclass:
@dataclass(frozen=True)
class View:
    pattern: str
    axes: dict[str, int]
The pattern is the rearrangement specifier. The axes dictionary provides concrete axis sizes to make the output shape unambiguous. Since views only change metadata, they involve no computation and no parallelism—they are free.
// uv run main.py run book/book.cat transpose
transpose() -> (y: int32[3, 2]) {
x : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
y : int32[3, 2] = view["a b -> b a"](x)
}
Maps
Map operations apply a function element-wise. The first variant of the map operation is that of lifting a unary (single argument) function
$$f : d \to d'$$
to a function of shape-$S$ tensors
$$\mathrm{map}[f] : d[S] \to d'[S]$$
We name this operation by using the function $f$ as the specifier to the op type map. This operation is defined by applying the function to each entry.
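For example, we can lift the squaring function
$$\mathrm{square} : \mathrm{int} \to \mathrm{int}$$
to tensors, and apply it to an input:
$$\mathrm{map}[\mathrm{square}]\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 1 & 4 \\ 9 & 16 \end{bmatrix}$$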
In catform, we write this as the line
y: int[2,2] = map[square](x)
Maps of unary functions interact nicely with function composition $\circ$, via the functoriality property:
$$\mathrm{map}[g] \circ \mathrm{map}[f] = \mathrm{map}[g \circ f]$$
In words, this means that applying $f$ to every entry and then applying $g$ to every entry is equal to applying $f$ and then $g$ to every entry at once. This means that we can combine two consecutive unary map operations into a single combined unary map operation. This fact proves useful for optimizing performance.
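The fusion of two consecutive maps can be checked on a small example. The sketch below models tensors as nested Python lists; `lift` and `compose` are hypothetical helpers written for this illustration, not catform functions.

```python
def lift(f):
    """Lift a unary function on entries to a function on nested-list tensors."""
    def lifted(t):
        if isinstance(t, list):
            return [lifted(sub) for sub in t]
        return f(t)
    return lifted


def compose(g, f):
    """Ordinary function composition: first f, then g."""
    return lambda v: g(f(v))


def square(v):
    return v * v


def negate(v):
    return -v


x = [[1, 2], [3, 4]]
two_passes = lift(negate)(lift(square)(x))   # two maps, two traversals
one_pass = lift(compose(negate, square))(x)  # one fused map, one traversal

print(two_passes)  # [[-1, -4], [-9, -16]]
print(one_pass)    # [[-1, -4], [-9, -16]]
```

The fused version visits each entry once, which is why this identity matters for performance.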
The other variant of map operations comes from mapping binary (or even $n$-ary!) operators, i.e. those with two (or more) inputs, such as the classic boolean and arithmetic operators. For instance, the following operation corresponds to adding two $S$-shaped tensors.
$$\mathrm{map}[\mathrm{add}] : (\mathrm{int}[S], \mathrm{int}[S]) \to \mathrm{int}[S]$$
In catform, we write this as the line
z: int[S] = map[add](x, y)
In tensors.py, we represent the map op as the dataclass:
@dataclass(frozen=True)
class Map:
    function: str
The sole field is the name of the elementwise function being lifted. Maps are maximally parallel: every entry coordinate can be computed independently, so a tensor with $N$ entries yields $N$ independent computations.
// uv run main.py run book/book.cat square
square() -> (y: int32[2, 2]) {
x : int32[2, 2] = literal([[1, 2], [3, 4]])
y : int32[2, 2] = map[mul](x, x)
}
Folds
Fold operations combine all entries along an axis into one, collapsing that axis from size $n$ to size $1$. To denote a fold operation, we need two specifiers. We first need a pattern to indicate which axis is being collapsed, by replacing the corresponding axis name with 1 on the right-hand side. For example, "a b c -> a 1 c" collapses axis b. Second, we must specify the reduction being used. A reduction is a function that takes any number of inputs of a given type, i.e. a list, and combines them into a single output, hence having a signature
$$\mathrm{list}[d] \to d$$
Common examples of reductions include
- $\sum$, $\prod$ for the sum and product of a list of numbers
- $\mathrm{mean}$ for the arithmetic mean of a list of numbers
- $\min$, $\max$ for the minimum and maximum of a list of numbers
- $\mathrm{all}$, $\mathrm{any}$ for checking the truth of all or any of a list of booleans
All of the above reductions give the same result for every ordering of their inputs. Sum, product, minimum, maximum, all, and any arise from associative and commutative binary operations, while the mean is a sum divided by the number of entries. In each case the fold is well-defined regardless of how the hardware traverses the axis. The choice of reduction also constrains the entry type: mean requires real-valued entries, while all and any require booleans.
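For example, if we wish to sum all of the columns in a $2 \times 3$ tensor, we would use the function
$$\mathrm{fold}[\texttt{"a b -> 1 b"}, \mathrm{sum}] : \mathrm{int}[2,3] \to \mathrm{int}[1,3]$$
and apply it as so
$$\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \mapsto \begin{bmatrix} 5 & 7 & 9 \end{bmatrix}$$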
In catform, we write this as the line
y: int[1,3] = fold["a b -> 1 b", sum](x)
Fold patterns can collapse multiple axes simultaneously, e.g. "a b c -> 1 1 c" folds both a and b—combining entries across two axes is no different from combining across one.
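A fold over one or several axes can be sketched on nested lists as follows. The `fold` and `zip_with` helpers are hypothetical and take a binary combining function, which is one simple way to model reductions such as sum or max; catform's own implementation may differ.

```python
from functools import reduce
import operator


def zip_with(combine, a, b):
    """Combine two equally shaped nested lists entry by entry."""
    if isinstance(a, list):
        return [zip_with(combine, x, y) for x, y in zip(a, b)]
    return combine(a, b)


def fold(t, axis, combine):
    """Collapse `axis` of a nested-list tensor to size 1 with a binary `combine`."""
    if axis == 0:
        # Merge all sub-tensors along the leading axis into a single one.
        return [reduce(lambda a, b: zip_with(combine, a, b), t)]
    return [fold(sub, axis - 1, combine) for sub in t]


x = [[1, 2, 3], [4, 5, 6]]
print(fold(x, 0, operator.add))                        # "a b -> 1 b": [[5, 7, 9]]
print(fold(x, 1, operator.add))                        # "a b -> a 1": [[6], [15]]
print(fold(fold(x, 0, operator.add), 1, operator.add)) # "a b -> 1 1": [[21]]
```

Folding both axes is just two single-axis folds in sequence, and for an order-independent reduction the order of those two folds does not matter.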
In tensors.py, we represent the fold op as the dataclass:
@dataclass(frozen=True)
class Fold:
    pattern: str
    reduction: str
The pattern identifies which axes collapse; the reduction names the combining function. Folds are parallel across all non-folded axes: in "a b -> 1 b", each of the columns indexed by b is reduced independently. When a tensor has many axes, ... stands in for leading axes the pattern need not name: "... a b -> ... 1 b" folds axis a regardless of how many axes precede it. The corresponding types use ... as well: f32[..., A, B] -> f32[..., 1, B].
// uv run main.py run book/book.cat col_sum
col_sum() -> (y: int32[1, 3]) {
x : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
y : int32[1, 3] = fold["a b -> 1 b", sum](x)
}
Tiles
Tile operations involve replicating every entry along an axis, thus expanding that axis from size $1$ to size $n$. Like fold, a tile operation requires a pattern specifier to indicate which axis is being expanded, and we do so via patterns like "1 b -> a b", which indicate that the first axis is being expanded from $1$ to $a$. Unlike fold, no reduction is needed, since replication has no degrees of freedom.
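For example, if we wish to replicate a $1 \times 3$ tensor across $2$ rows, we would use the function
$$\mathrm{tile}[\texttt{"1 b -> a b"}] : \mathrm{int}[1,3] \to \mathrm{int}[2,3]$$
and apply it as so
$$\begin{bmatrix} 1 & 2 & 3 \end{bmatrix} \mapsto \begin{bmatrix} 1 & 2 & 3 \\ 1 & 2 & 3 \end{bmatrix}$$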
In catform, we write this as the line
y: int[2,3] = tile["1 b -> a b"](x)
Note that the output type int[2,3] is essential here—the pattern "1 b -> a b" introduces a new axis a, and its size is determined by the type annotation.
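As an illustration of why the target size must be supplied from outside the pattern, here is a small nested-list sketch of tiling. The `tile` helper and its `count` argument are hypothetical; `count` plays the role that the output type annotation plays in catform.

```python
import copy


def tile(t, axis, count):
    """Expand a size-1 `axis` of a nested-list tensor to size `count` by replication."""
    if axis == 0:
        assert len(t) == 1, "only an axis of size 1 can be tiled"
        # Deep copies keep the replicas independent of one another.
        return [copy.deepcopy(t[0]) for _ in range(count)]
    return [tile(sub, axis - 1, count) for sub in t]


x = [[1, 2, 3]]             # shape (1, 3)
print(tile(x, 0, 2))        # "1 b -> a b" with a = 2: [[1, 2, 3], [1, 2, 3]]
print(tile([[1], [2]], 1, 3))  # "a 1 -> a b" with b = 3: [[1, 1, 1], [2, 2, 2]]
```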
The reader may notice that tile and fold are closely related: fold collapses an axis via a reduction, tile expands one via replication. This is no coincidence—tile and fold are dual operations, a relationship we will make precise when we discuss gradients.
In tensors.py, we represent the tile op as the dataclass:
@dataclass(frozen=True)
class Tile:
    pattern: str
    axes: dict[str, int]
Like View, the axes dictionary provides concrete axis sizes to make the output shape unambiguous. Tiles are parallel across all non-tiled axes: in "1 b -> a b", each of the entries along axis b is replicated independently. Likewise, "... 1 b -> ... a b" tiles regardless of leading axes, with types f32[..., 1, B] -> f32[..., A, B].
// uv run main.py run book/book.cat replicate
replicate() -> (y: int32[3, 4]) {
x : int32[1, 4] = literal([[1, 2, 3, 4]])
y : int32[3, 4] = tile["1 b -> a b"](x)
}
Gathers
The four operations above—view, map, fold, and tile—are all structurally determined: their behavior depends entirely on the pattern and the shapes of their inputs. But some computations are data-dependent: the output depends not just on the shape of a tensor but on its values. The canonical example is a table lookup, where the entries of one tensor serve as indices into another. The next two ops handle exactly this.
Gather operations select entries from a tensor using indices from another tensor. Given a data tensor and an index tensor, a gather produces an output whose entries are looked up from the data at the positions specified by the indices.
Like fold and tile, gather uses a pattern specifier to indicate which axis is affected. The marker _ denotes the axis replaced by the index tensor’s shape. For example, the pattern v d -> _ d means: axis v is replaced by the index shape _, while axis d passes through unchanged.
Suppose we have a $4 \times 2$ tensor of data and a length-$3$ tensor of indices
Then the gather selects rows from the data:
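The type signature of this operation is
$$\mathrm{gather}[\texttt{"v d -> \_ d"}] : (\mathrm{real}[4,2], \mathrm{int}[3]) \to \mathrm{real}[3,2]$$
The _ in the output shape corresponds to the index tensor’s size ($3$), while the non-_ axes pass through in their same positions. In catform, this is written as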
y: real[3, 2] = gather["v d -> _ d"](data, idx)
Gather implements a lookup table: the indices select entries along an axis of the data.
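In plain Python, the row lookup performed by the pattern "v d -> _ d" with a one-axis index tensor can be sketched as below. The `gather_rows` helper is hypothetical, and the values are borrowed from this section's lookup example.

```python
def gather_rows(data, idx):
    """Sketch of gather["v d -> _ d"]: pick row data[i] for every index i."""
    return [list(data[i]) for i in idx]


data = [[10, 20, 30], [40, 50, 60], [70, 80, 90], [100, 110, 120]]
idx = [0, 3]
print(gather_rows(data, idx))        # [[10, 20, 30], [100, 110, 120]]
print(gather_rows(data, [2, 2, 0]))  # [[70, 80, 90], [70, 80, 90], [10, 20, 30]]
```

The second call shows that an index may repeat: the output simply has as many rows as there are indices.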
In tensors.py, we represent the gather op as the dataclass:
@dataclass(frozen=True)
class Gather:
    pattern: str
The pattern identifies which axis is replaced by the index shape. Each output entry requires exactly one lookup, so gathers are parallel across all output coordinates—each index can be followed independently.
// uv run main.py run book/book.cat lookup
lookup() -> (y: int32[2, 3]) {
data : int32[4, 3] = literal([[10, 20, 30], [40, 50, 60], [70, 80, 90], [100, 110, 120]])
idx : int32[2] = literal([0, 3])
y : int32[2, 3] = gather["v d -> _ d"](data, idx)
}
Scatters
Scatter operations are the dual of gather: where gather reads from data at specified indices, scatter writes to a target at specified indices. Given a source tensor, an index tensor, and a target template, scatter accumulates the source entries into the target at the positions specified by the indices.
The scatter pattern is the reverse of the gather pattern. Where gather uses v d -> _ d, scatter uses _ d -> v d: the _ axis in the input is mapped back into the named axis in the output.
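Using the same index tensor, we can scatter a $3 \times 2$ source into a $4 \times 2$ target of zeros:
The type signature of this operation is
$$\mathrm{scatter}[\texttt{"\_ d -> v d"}] : (\mathrm{real}[3,2], \mathrm{int}[3], \mathrm{real}[4,2]) \to \mathrm{real}[4,2]$$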
In catform, we write this as the line
y: real[4,2] = scatter["_ d -> v d"](source, idx, target)
When multiple indices point to the same position, their contributions are summed—this accumulation behavior is precisely what makes scatter the dual of gather, as we will see when we discuss gradients.
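The accumulation behavior can be seen in a short sequential sketch. The `scatter_rows` helper is hypothetical and processes the source rows one at a time, whereas real hardware would perform the additions in parallel; the values are borrowed from this section's accumulate example.

```python
def scatter_rows(source, idx, target):
    """Sketch of scatter["_ d -> v d"]: add source row k into target row idx[k]."""
    out = [list(row) for row in target]  # leave the target template untouched
    for row, i in zip(source, idx):
        for d, value in enumerate(row):
            out[i][d] += value           # colliding indices accumulate by addition
    return out


vals = [[1, 2], [3, 4], [5, 6]]
idx = [0, 2, 0]
target = [[0, 0] for _ in range(4)]
print(scatter_rows(vals, idx, target))  # [[6, 8], [0, 0], [3, 4], [0, 0]]
```

Rows 0 and 2 of the source both land in target row 0, so that row holds their sum.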
In tensors.py, we represent the scatter op as the dataclass:
@dataclass(frozen=True)
class Scatter:
    pattern: str
The pattern is the reverse of the corresponding gather pattern. Scatters are parallel across the non-_ axes—modern hardware handles index collisions via atomic addition, so even overlapping writes proceed in parallel.
// uv run main.py run book/book.cat accumulate
accumulate() -> (y: int32[4, 2]) {
vals : int32[3, 2] = literal([[1, 2], [3, 4], [5, 6]])
idx : int32[3] = literal([0, 2, 0])
target : int32[4, 2] = literal([[0, 0], [0, 0], [0, 0], [0, 0]])
y : int32[4, 2] = scatter["_ d -> v d"](vals, idx, target)
}
Tensor Contractions
The above six operations are sufficient to capture the entire mathematical computation of any modern transformer language model. Despite this, we choose to elevate the derived form—i.e. one that can be expressed as a composite of the primitive operations—of tensor contraction because of its semantic centrality to the mathematics of language models. We first review the mathematics, perhaps shedding new light on it, and then explain how it can be expressed in terms of the six primitive operations.
Let’s review some linear algebra. For the time being, forget about tensors; we will connect back to them soon. For this section, we will deviate from some of our notational conventions to show their correspondence with the conventional mathematics notation. Recall that a vector in $\mathbb{R}^m$ is by convention denoted as a column vector, e.g. the following vector
$$v = \begin{bmatrix} v_1 \\ \vdots \\ v_m \end{bmatrix}$$
Why was there such an emphasis on these vectors being written as columns rather than rows? What semantic meaning did that distinction carry? Recall that a matrix was depicted as a two dimensional array, e.g. the following matrix
$$W = \begin{bmatrix} w_{11} & \cdots & w_{1m} \\ \vdots & \ddots & \vdots \\ w_{n1} & \cdots & w_{nm} \end{bmatrix}$$
A matrix carried further meaning than merely being a two dimensional array of numbers. In particular, an $n \times m$ matrix, i.e. one with $m$ columns and $n$ rows, corresponded to a linear map of type
$$\mathbb{R}^m \to \mathbb{R}^n$$
A particularly simple type of matrix is the one with a single row, i.e. a $1 \times m$ matrix, which we also call an $m$ row vector or covector, which is a function of type
$$\mathbb{R}^m \to \mathbb{R}$$
This means that if we have an $m$ covector $u$ and an $m$ vector $v$, we have a function of type $\mathbb{R}^m \to \mathbb{R}$ along with an element of its domain $\mathbb{R}^m$, which means we can apply the function to the vector:
$$u(v)$$
In practice, we obtain a covector by transposing a column vector $w$, and denote the result $w^\top$. With this notation, we drop the parentheses and write the application as:
$$w^\top v$$
We compute this value in terms of the entries of the covector and vector using the dot product:
$$w^\top v = \sum_{i=1}^{m} w_i v_i$$
Multiply entry-by-entry, then sum—this is exactly a map followed by a fold. We can write this as a catform program using our six primitive tensor operations. If we view both vector and covector as tensors:
dot_product(w: real[m], v: real[m]) -> (val: real[]) {
wv : real[m] = map[mul](w, v)
val_: real[1] = fold["m -> 1", sum](wv)
val : real[] = view["1 -> "](val_)
}
The dot product recurs in the context of applying an $n \times m$ matrix function $W$ to an $m$ vector $v$. It in fact can be used to give an interpretation of what an $n \times m$ matrix does, by interpreting it as $n$ parallel covectors $w_1^\top, \ldots, w_n^\top$ (its rows), each applied independently to the argument to arrive at an $n$ vector:
$$W v = \begin{bmatrix} w_1^\top v \\ \vdots \\ w_n^\top v \end{bmatrix}$$
To write this in catform, we proceed as above, except that we must first tile the vector to be the same shape as the matrix:
matrix_apply(w: real[n,m], v: real[m]) -> (vect: real[n]) {
vv : real[n,m] = tile["m -> n m"](v)
wvv : real[n,m] = map[mul](w, vv)
vect_: real[n,1] = fold["n m -> n 1", sum](wvv)
vect : real[n] = view["n 1 -> n"](vect_)
}
We can extend this idea to the matrix multiplication of an $n \times m$ matrix by an $m \times k$ matrix by viewing, as before, the former as $n$ parallel covectors, and the latter as $k$ parallel vectors. This allows us to think of matrix multiplication as $n \cdot k$ independent dot products.
To write this in catform, we proceed as above, except that we must now tile both of the matrices to be the same shape:
matrix_multiply(w: real[n,m], v: real[m,k]) -> (mat: real[n,k]) {
ww : real[n,m,k] = tile["n m -> n m k"](w)
vv : real[n,m,k] = tile["m k -> n m k"](v)
wwvv : real[n,m,k] = map[mul](ww, vv)
mat_ : real[n,1,k] = fold["n m k -> n 1 k", sum](wwvv)
mat : real[n,k] = view["n 1 k -> n k"](mat_)
}
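The same tile, map, fold, view recipe can be traced in plain Python on nested lists. This is only a sketch mirroring the steps of matrix_multiply above; the function name and the explicit loops are illustrative, and the intermediate [n, m, k] tensors are materialized purely for clarity.

```python
def matmul_tile_map_fold(w, v):
    """Multiply an n x m matrix by an m x k matrix via tile, map, fold, and view."""
    n, m, k = len(w), len(v), len(v[0])
    assert all(len(row) == m for row in w), "inner dimensions must agree"
    # tile: replicate both inputs to the common shape [n, m, k]
    ww = [[[w[i][j] for _ in range(k)] for j in range(m)] for i in range(n)]
    vv = [[[v[j][c] for c in range(k)] for j in range(m)] for _ in range(n)]
    # map[mul]: multiply entry by entry
    prod = [[[ww[i][j][c] * vv[i][j][c] for c in range(k)]
             for j in range(m)] for i in range(n)]
    # fold[sum] over the middle axis, then view away the resulting size-1 axis
    return [[sum(prod[i][j][c] for j in range(m)) for c in range(k)]
            for i in range(n)]


a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(matmul_tile_map_fold(a, b))  # [[58, 64], [139, 154]]
```

Each of the n * k output entries is produced by its own independent dot product over the m axis.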
Higher dimensional variants of such computations become harder to visualize, but the operation remains well defined. For instance, we may have an $o \times n$ rectangular grid of parallel covectors and a $k \times j \times i$ rectangular prism grid of parallel vectors, and calculate $o \cdot n \cdot k \cdot j \cdot i$ dot products. Although the visuals become more challenging to imagine, the catform does not; this example is nearly identical to the above:
tensor_contraction(w: real[o,n,m], v: real[m,k,j,i]) -> (tens: real[o,n,k,j,i]) {
ww : real[o,n,m,k,j,i] = tile["o n m -> o n m k j i"](w)
vv : real[o,n,m,k,j,i] = tile["m k j i -> o n m k j i"](v)
wwvv : real[o,n,m,k,j,i] = map[mul](ww, vv)
tens_ : real[o,n,1,k,j,i] = fold["o n m k j i -> o n 1 k j i", sum](wwvv)
tens : real[o,n,k,j,i] = view["o n 1 k j i -> o n k j i"](tens_)
}
Just as we easily expanded the number of axes of the set of dot product calculations, we could also have expanded the number of axes that we summed across: nothing about the covector and vector pair forces them to be single-axis. We call the summed-over axes contracting axes (often called contracting dimensions elsewhere). The degenerate case has zero contracting axes, so no shared indices are summed over. The result is the outer product: every entry of one tensor paired with every entry of the other, as in "n, d -> n d". This general operation is so common (it in fact comprises the vast majority of arithmetic operations of any language model) and carries so much mathematical semantics (which we have to restrain ourselves from explicating further) that we elevate it to a derived form that we give native status to in addition to the six primitive operations above. As with many other operations, we specify such contractions with a pattern as follows:
$$\mathrm{contract}[\texttt{"as bs, bs cs -> as cs"}] : (d[\mathit{as}, \mathit{bs}], d[\mathit{bs}, \mathit{cs}]) \to d[\mathit{as}, \mathit{cs}]$$
where $\mathit{as}$, $\mathit{bs}$, and $\mathit{cs}$ can each be any list of axes. In catform this looks like:
z: d[as, cs] = contract["as bs, bs cs -> as cs"](x, y)
In tensors.py, we represent the contract op as the dataclass:
@dataclass(frozen=True)
class Contract:
    pattern: str
The pattern encodes both inputs and output in a single string: the contracting axes are those that appear in both inputs but not in the output. Contractions are parallel across all non-contracting axes: in "n m, m k -> n k", the $n \cdot k$ output entries are independent dot products, each computed in parallel. Likewise, "... n d, o d -> ... n o" contracts over d regardless of leading axes. When ... appears on both inputs, as in "... n h e, ... s h e -> ... h n s", it matches the same set of leading axes in each.
The four contraction specializations demonstrated above—dot product, matrix-vector, matrix multiply, and outer product—each reduce to a single contract op:
// uv run main.py run book/book.cat dot
dot() -> (val: int32[]) {
w : int32[4] = literal([1, 2, 3, 4])
v : int32[4] = literal([5, 6, 7, 8])
val : int32[] = contract["m, m ->"](w, v)
}
// uv run main.py run book/book.cat mat_vec
mat_vec() -> (y: int32[2]) {
w : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
v : int32[3] = literal([7, 8, 9])
y : int32[2] = contract["n m, m -> n"](w, v)
}
// uv run main.py run book/book.cat mat_mul
mat_mul() -> (y: int32[2, 2]) {
a : int32[2, 3] = literal([[1, 2, 3], [4, 5, 6]])
b : int32[3, 2] = literal([[7, 8], [9, 10], [11, 12]])
y : int32[2, 2] = contract["n m, m k -> n k"](a, b)
}
// uv run main.py run book/book.cat outer
outer() -> (y: int32[3, 4]) {
u : int32[3] = literal([1, 2, 3])
v : int32[4] = literal([4, 5, 6, 7])
y : int32[3, 4] = contract["n, m -> n m"](u, v)
}
Together with the Op wrapper and the Function and Module structures for composition, these dataclasses comprise tensors.py—the complete catform abstract syntax tree (AST).
Computation Graphs
Every catform function we have written—dot_product, matrix_apply, and the op examples like transpose and col_sum—has the same shape: a list of named operations, each consuming some outputs from earlier ones. This structure has a natural visual representation called a string diagram. Each operation is a box; each typed tensor flowing between operations is a wire. Reading top to bottom traces execution order. The dotted boundary marks a function—nesting one inside another is composition.
function(t: T, x: X) -> (z: Z) {
(u,v): (U,V) = f(t)
w : W = g(x)
y : Y = h(t,u,v,w)
z : Z = i(y)
}
Data enters a catform function in two ways. Arguments are declared in the function signature and supplied by the caller—in the diagram, t and x pierce the top boundary. Introductions—values created by literal(...) or random(...)—originate inside the function with no incoming wire.
Two structural features are free, requiring no operation. A wire can fork: t feeds both f and h—the same tensor is available to multiple consumers without copying. And a box can emit multiple outputs: f produces both u and v. Neither forking nor multi-return requires an op; they are structural features of the language.
Independence is visible. Boxes f and g share no wires; they can execute in either order, or simultaneously. Parallelism is not annotated—it is read off the topology.
Every edge only flows downward—the diagram is a directed acyclic graph. Each wire name appears as an output of exactly one box, a property called static single assignment. The correspondence between a catform function and its string diagram is exact: every line in the code is a box, every named intermediate is a wire, and the top-to-bottom reading order matches the textual order up to reorderings that respect data dependencies.
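To make the idea of reading parallelism off the topology concrete, the sketch below encodes the example function's boxes as a dependency table and groups them into waves that could run simultaneously. The `schedule` helper and the table layout are hypothetical illustrations, not part of the Pianola codebase.

```python
# Each box maps to (input wires, output wires), mirroring the example's assignment lines.
boxes = {
    "f": (["t"], ["u", "v"]),
    "g": (["x"], ["w"]),
    "h": (["t", "u", "v", "w"], ["y"]),
    "i": (["y"], ["z"]),
}
arguments = ["t", "x"]


def schedule(boxes, arguments):
    """Group boxes into waves; boxes in the same wave do not depend on each other."""
    available = set(arguments)
    remaining = dict(boxes)
    waves = []
    while remaining:
        ready = sorted(name for name, (ins, _) in remaining.items()
                       if all(wire in available for wire in ins))
        if not ready:
            raise ValueError("cycle or missing input: not a straight-line program")
        for name in ready:
            available.update(remaining.pop(name)[1])
        waves.append(ready)
    return waves


print(schedule(boxes, arguments))  # [['f', 'g'], ['h'], ['i']]
```

Boxes f and g land in the same wave because neither consumes a wire the other produces.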
Composition and Flattening
A single catform function is a straight-line program. But a .cat file contains a module—a collection of named functions that can reference each other. Two composition constructs allow structuring a computation across multiple functions:
- call[f](x_1, ..., x_n) invokes the function f defined elsewhere in the module, passing arguments and receiving its returns.
- loop[f, count](args) is a macro that expands into count sequential invocations of f.
For example, the loop
y : T = loop[f, 3](y, s, w)
expands to
y_1 : T = call[f](y, s, w.0)
y_2 : T = call[f](y_1, s, w.1)
y_3 : T = call[f](y_2, s, w.2)
The argument whose name matches the output (y) is threaded—SSA renaming chains each iteration’s output to the next iteration’s input (y → y_1 → y_2 → y_3). Arguments of dict type (*) are indexed per iteration via .0, .1, .2. All other arguments are static. Subsequent references to y are rewritten to y_3.
Neither call nor loop is an op type—they are macros, not computations. They organize the program for human comprehension: call factors repeated patterns into named functions; loop expresses iterated structure without copy-pasting. Both disappear under flattening: every call is inlined, every loop is unrolled, producing a single function called main—a straight-line sequence of the seven op types.
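The loop expansion described above can be mimicked with a small text-level sketch. The `unroll` helper is hypothetical and works on strings only; the set of per-iteration indexed arguments is passed in explicitly, which is an assumption about how such arguments would be identified.

```python
def unroll(output, dtype, fn, count, args, indexed=()):
    """Expand `output : dtype = loop[fn, count](args)` into `count` call lines."""
    lines = []
    previous = output
    for step in range(count):
        current = f"{output}_{step + 1}"
        call_args = []
        for arg in args:
            if arg == output:
                call_args.append(previous)          # threaded argument
            elif arg in indexed:
                call_args.append(f"{arg}.{step}")   # indexed per iteration
            else:
                call_args.append(arg)               # static argument
        lines.append(f"{current} : {dtype} = call[{fn}]({', '.join(call_args)})")
        previous = current
    return lines


for line in unroll("y", "T", "f", 3, ["y", "s", "w"], indexed={"w"}):
    print(line)
# y_1 : T = call[f](y, s, w.0)
# y_2 : T = call[f](y_1, s, w.1)
# y_3 : T = call[f](y_2, s, w.2)
```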
Two details remain before we put these operations to work: the concrete numeric formats that replace real and int, and a machine to run them.
Numerics
Everything above used real and int as mathematical types. A machine has a finite number of bits, so every numeric value it stores must fit in a fixed-width binary word. The consequences for integers are straightforward: a $b$-bit two’s complement integer represents values in $[-2^{b-1}, 2^{b-1} - 1]$. Overflow wraps; there is no approximation within range. In catform, int32 denotes a $32$-bit signed integer.
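Since Python integers are unbounded, wrap-around has to be modelled explicitly. The hypothetical helper below maps any integer into the range of a fixed-width two's complement word, which is one way to picture what overflow does to an int32 value.

```python
def wrap_int(value, bits=32):
    """Wrap an unbounded integer into the `bits`-bit two's complement range."""
    modulus = 1 << bits
    half = 1 << (bits - 1)
    return (value + half) % modulus - half


print(wrap_int(2**31 - 1))   # 2147483647, the largest int32 value
print(wrap_int(2**31))       # -2147483648, one past the top wraps to the bottom
print(wrap_int(130, bits=8)) # -126
```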
Real numbers can require infinitely many digits, so approximation is unavoidable. To represent them, we must choose a numeric representation. The standard format is given by the IEEE 754 floating-point specification. We can derive this representation as follows. Given a nonzero real number $x$, we can first extract its sign $s$ ($0$ if non-negative and $1$ if negative) and rewrite it in terms of a non-negative real $|x|$:
$$x = (-1)^s \cdot |x|$$
Then, let $e$ denote the unique integer such that
$$2^e \le |x| < 2^{e+1}$$
This lets us further deconstruct $|x|$ as
$$|x| = 2^e \cdot \frac{|x|}{2^e}$$
Note that by the choice of $e$, $|x|/2^e$ is always in the interval $[1, 2)$. We make the substitution $|x|/2^e = 1 + m$, where $m$ is a non-negative real number in the interval $[0, 1)$ and hence can be expressed in binary as $0.b_1 b_2 b_3 \ldots$. This gives us the unique decomposition used by floating point numbers:
$$x = (-1)^s \cdot 2^e \cdot (1 + m)$$
The triple (sign, exponent, mantissa) uniquely characterizes a floating-point value. A floating-point type is determined by choosing how many bits to allocate to each component.
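The triple can be recovered for any finite nonzero Python float with the standard library. In this sketch the names `decompose` and `recompose` are hypothetical, and the convention assumed is value = (-1)^sign * 2^exponent * (1 + mantissa) with the mantissa in [0, 1); the exponent is the mathematical one, not the biased field stored in hardware.

```python
import math


def decompose(x):
    """Split a finite nonzero float into (sign, exponent, mantissa)."""
    if x == 0 or math.isinf(x) or math.isnan(x):
        raise ValueError("only finite nonzero values have this decomposition")
    sign = 1 if x < 0 else 0
    # frexp returns (f, p) with abs(x) == f * 2**p and 0.5 <= f < 1.
    fraction, power = math.frexp(abs(x))
    exponent = power - 1         # rescale so the leading factor lies in [1, 2)
    mantissa = fraction * 2 - 1  # the part after the implicit leading 1
    return sign, exponent, mantissa


def recompose(sign, exponent, mantissa):
    """Rebuild the float from its three components."""
    return (-1) ** sign * math.ldexp(1 + mantissa, exponent)


print(decompose(6.5))    # (0, 2, 0.625) since 6.5 = 2^2 * 1.625
print(decompose(-0.75))  # (1, -1, 0.5) since -0.75 = -(2^-1 * 1.5)
```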
A given floating point representation has a fixed number $N$ of bits, always a power of $2$. It always allocates $1$ bit to the sign, and then assigns $E$ bits to the exponent and $M$ bits to the mantissa, so that $N = 1 + E + M$. Assigning more bits to the exponent allows the representation to capture a wider range of numbers (much larger or much smaller), while assigning more to the mantissa gives higher precision (more binary digits before rounding occurs). Many numerical formats, specified by $N$, $E$, and $M$, have been used in numerical computing. A given language model might use distinct formats in different parts of its computation. In the next chapter, we will see such choices in a full language model computation.
Catform rewrites—flattening, inlining, reordering—are mathematically equivalent restructurings that never alter the entry-level arithmetic or dtypes.
Tensor Programming
We now have all the ingredients: tensors and their types, six primitive operation types plus contraction, introductions for constants and random data, and concrete numeric dtypes. Before we turn to language models, we put all of these components together in a single computation, with no machine learning in sight. Inspired by the excellent 3Blue1Brown lecture, we compute the $n$-volume of the unit $n$-ball, defined as the set of points in $\mathbb{R}^n$ with distance at most $1$ from the origin. Its volume is a classical quantity that at first grows and then shrinks as $n$ increases: the unit ball becomes vanishingly small in high dimensions.
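| $n$ | Name | Closed form | Volume |
|---|---|---|---|
| $1$ | segment | $2$ | 2.00 |
| $2$ | disk | $\pi$ | 3.14 |
| $3$ | ball | $\frac{4}{3}\pi$ | 4.19 |
| $4$ | $4$-ball | $\frac{1}{2}\pi^2$ | 4.93 |
| $5$ | $5$-ball | $\frac{8}{15}\pi^2$ | 5.26 |
| $10$ | $10$-ball | $\frac{1}{120}\pi^5$ | 2.55 |
| $100$ | $100$-ball | $\frac{1}{50!}\pi^{50}$ | $2.37 \times 10^{-40}$ |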
The volume peaks at $n = 5$ and then declines. The closed forms for even and odd dimensions are
$$V_{2k} = \frac{\pi^k}{k!}, \qquad V_{2k+1} = \frac{2\,(4\pi)^k\,k!}{(2k+1)!}$$
Since both formulae have factorials in the denominator that eventually dominate the powers of $\pi$, the volume converges to zero. We can estimate it empirically via Monte Carlo sampling. The function takes the dimension $n$ and sample count $S$ as parameters, then generates random data inside, showcasing introductions. For each of $S$ random points, we sample $n$ coordinates uniformly from the interval $[-1, 1]$.
A point $x = (x_1, \ldots, x_n)$ lies inside the ball when
$$\sum_{i=1}^{n} x_i^2 \le 1$$
The volume can be calculated entirely via catform operations:
// uv run main.py run book/book.cat ball_vol 5 10000
ball_vol(n: int32[], S: int32[]) -> (vol: f32[]) {
    x     : f32[n, S] = random(-1.0, 1.0)
    // compute the squared norm for every sample
    sq    : f32[n, S] = map[mul](x, x)
    norm  : f32[1, S] = fold["n S -> 1 S", sum](sq)
    // compute the ratio of points in the ball
    one   : f32[] = literal(1.0)
    one_t : f32[1, S] = tile["-> 1 S"](one)
    mask  : f32[1, S] = map[le](norm, one_t)
    inside: f32[1, S] = map[f32](mask)
    frac  : f32[1, 1] = fold["1 S -> 1 1", mean](inside)
    ratio : f32[] = view["1 1 ->"](frac)
    // scale the ratio by the volume of the hypercube
    two   : f32[] = literal(2.0)
    n_f   : f32[] = map[f32](n)
    side  : f32[] = map[pow](two, n_f)
    vol   : f32[] = map[mul](ratio, side)
}
Line by line: random introduces an $n \times S$ tensor of uniform samples. map[mul] squares every coordinate. fold[sum] collapses the $n$ axis, yielding a squared norm for each sample. literal and tile build a matching tensor of ones, and map[le] tests whether each squared norm is at most $1$ (that is, whether the point is inside the ball), producing a boolean. map[f32] casts to float so that fold[mean] can average across all $S$ samples, and view strips the unit axes to a scalar. The final four lines compute $2^n$, the volume of the hypercube, and multiply to convert the fraction into a volume.
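For readers who want to cross-check the estimate outside catform, the same steps can be mirrored in plain Python. This is a minimal sketch using only the standard library, with nested lists standing in for tensors; the name `ball_vol_reference` and the `seed` parameter are assumptions made for the illustration and are not part of catform.

```python
import random

def ball_vol_reference(n: int, S: int, seed: int = 0) -> float:
    """Monte Carlo estimate of the unit n-ball volume (plain-Python sketch)."""
    rng = random.Random(seed)
    # x: n rows of S uniform samples in [-1, 1]
    x = [[rng.uniform(-1.0, 1.0) for _ in range(S)] for _ in range(n)]
    # map[mul]: square every coordinate
    sq = [[v * v for v in row] for row in x]
    # fold[sum] over the n axis: one squared norm per sample
    norm = [sum(sq[i][s] for i in range(n)) for s in range(S)]
    # map[le] then map[f32]: 1.0 if the point is inside the ball, else 0.0
    inside = [1.0 if v <= 1.0 else 0.0 for v in norm]
    # fold[mean] over the S axis: fraction of samples inside the ball
    ratio = sum(inside) / S
    # scale the fraction by the hypercube volume 2^n
    return ratio * 2.0 ** n

print(ball_vol_reference(2, 20_000))
```

Each list comprehension corresponds to one map or fold, so the sketch materializes the same intermediates as the catform program, only without the parallelism.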
With $n = 2$ and large $S$, ball_vol returns approximately 3.14 ($\pi$). The operations that will specify a language model are not specific to language models; they are general-purpose tensor operations. A forward pass through a transformer and a Monte Carlo integration are the same kind of thing: a straight-line composition of maps and folds.
Naive Execution
Our catform tensor operations and their use in .cat modules are specifications—they say what to compute, not how. To actually run the ops, we need an executor that lowers each op—translates it from an abstract description to a concrete function call in a tensor computing library. The two dominant frameworks for tensor computation at the time of writing are PyTorch and JAX. Pianola’s simplest executor lives in lower/simple.py and dispatches each op to a single call in either one.
The pattern-string specifiers used by view, fold, tile, and contract were directly inspired by einops, a library that provides framework-agnostic tensor operations with the same pattern-string interface. We use einops as the interop layer: since catform’s specifiers and einops’s patterns are the same language, the lowering is almost trivial—the pattern passes straight through. The complete dispatch from _resolve_op in simple.py maps each op type to its concrete implementation:
case View(pattern=p, axes=ax): return einops.rearrange(x, p, **ax)
case Map(function=f): return MAP_DISPATCH[f](backend, args)
case Fold(pattern=p, reduction=r): return einops.reduce(x, p, r)
case Tile(pattern=p, axes=ax): return einops.repeat(x, p, **ax)
case Gather(pattern=p): return data[indices]
case Scatter(pattern=p): return scatter(choice, values, indices, template, axis)
case Contract(pattern=p): return einops.einsum(a, b, p)
view, fold, tile, and contract pass their pattern strings directly to einops. gather lowers to the framework’s native advanced indexing, which both PyTorch and JAX handle identically. scatter is the one op that requires its own per-backend dispatch function, because the accumulation primitives differ: PyTorch provides index_add_ (in-place), while JAX provides .at[].add() (functional). Neither einops nor operator overloading can bridge this gap, so scatter gets a standalone function that cases on the backend:
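def scatter(framework: Framework, values, indices, template, axis):
    match framework:
        case Framework.torch:
            # in-place accumulation into a fresh zero tensor
            return torch.zeros_like(template).index_add_(axis, indices, values)
        case Framework.jax:
            # functional update: take every position on the axes before `axis`
            idx = (slice(None),) * axis + (indices,)
            return jnp.zeros_like(template).at[idx].add(values)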
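Both branches implement the same semantics: start from zeros and accumulate each value at its target index, so repeated indices add up rather than overwrite. A minimal pure-Python sketch of the one-axis case makes this concrete; the helper `scatter_add_1d` is hypothetical and not part of either framework.

```python
def scatter_add_1d(values, indices, size):
    """Accumulate values[i] into out[indices[i]], starting from zeros."""
    out = [0.0] * size
    for v, i in zip(values, indices):
        out[i] += v  # repeated indices accumulate rather than overwrite
    return out

# Index 0 appears twice, so it receives 1.0 + 3.0.
print(scatter_add_1d([1.0, 2.0, 3.0], [0, 2, 0], size=4))  # [4.0, 0.0, 2.0, 0.0]
```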
map ops split into three groups. Binary arithmetic and comparison operators (add, mul, sub, div, ge) use Python’s built-in operators, which both PyTorch and JAX overload identically. where is a ternary operator: it takes a boolean mask and two value tensors, returning entries from one or the other. This is how tensor programs handle conditional selection: every entry is computed for both alternatives, and where picks between them elementwise, which avoids control flow. Finally, unary real-valued functions (exp, cos, sin, silu, rsqrt) have no Python operator, so they, like where, must come from the framework’s namespace. The list of these functions is extensible without touching the core op primitives: adding a new scalar function (say, tanh) takes one new slot in a dispatch table. The framework-specific slots are collected in a Backend dataclass:
@dataclass(frozen=True)
class Backend:
    exp: Callable
    cos: Callable
    sin: Callable
    silu: Callable
    rsqrt: Callable
    where: Callable
    ...
A PyTorch backend fills these with torch.exp, torch.cos, etc., while a JAX backend fills them with jnp.exp, jnp.cos, etc. The interpreter does not know or care which one it received: the .cat specification is portable across any backend that fills the same slots. The entire interpreter is four lines:
def run(ops: list[Op], env: dict[str, Any], backend: Backend) -> None:
    for op in ops:
        args = tuple(env[name] for name in op.inputs)
        env[op.output] = _resolve_op(op, backend)(*args)
Walk the ops in order, look up each input by name, call the resolved function, store the result. No graph construction, no tracing, no runtime code generation: just a loop over operations. The same .cat file and the same run() function work with two different Backend values, one for PyTorch and one for JAX. Given the same inputs, the results agree up to floating-point rounding.
This is a line-by-line interpreter—each step materializes its result, making every intermediate inspectable.
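To see the loop in action without installing either framework, here is a toy version that runs on Python scalars. It is a sketch under stated assumptions: `ToyOp`, `ToyBackend`, `BINARY`, and `resolve` are hypothetical stand-ins for the real Op, Backend, MAP_DISPATCH, and _resolve_op, and the backend slots are filled from the standard `math` module rather than from PyTorch or JAX.

```python
import math
import operator
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class ToyBackend:
    # Hypothetical scalar stand-in for the Backend slots.
    exp: Callable
    rsqrt: Callable
    where: Callable

@dataclass(frozen=True)
class ToyOp:
    # Hypothetical op record: a function name, input names, one output name.
    function: str
    inputs: tuple[str, ...]
    output: str

# Binary operators come from Python itself; no backend slot is needed.
BINARY = {"add": operator.add, "mul": operator.mul, "ge": operator.ge}

def resolve(op: ToyOp, backend: ToyBackend) -> Callable:
    if op.function in BINARY:
        return BINARY[op.function]
    return getattr(backend, op.function)  # exp, rsqrt, where

def run(ops: list[ToyOp], env: dict[str, Any], backend: ToyBackend) -> None:
    for op in ops:
        args = tuple(env[name] for name in op.inputs)
        env[op.output] = resolve(op, backend)(*args)

scalar_backend = ToyBackend(
    exp=math.exp,
    rsqrt=lambda v: 1.0 / math.sqrt(v),
    where=lambda mask, a, b: a if mask else b,
)

# y = where(a >= b, rsqrt(a), exp(b)); both alternatives are computed first.
ops = [
    ToyOp("ge", ("a", "b"), "mask"),
    ToyOp("rsqrt", ("a",), "r"),
    ToyOp("exp", ("b",), "e"),
    ToyOp("where", ("mask", "r", "e"), "y"),
]
env = {"a": 4.0, "b": 1.0}
run(ops, env, scalar_backend)
print(env["y"], sorted(env))
```

After the run, `env` holds every intermediate (`mask`, `r`, `e`) alongside the final `y`, which is the inspectability described above. Swapping in a different `ToyBackend` changes only where the unary functions and `where` come from, not the loop.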
We now have the mathematical language to describe, and the computational translation to execute, the full computation of a language model.