Symbolic Autodiff Hits a Wall
Mar 28th 2026
I concluded the previous article satisfied with Sheaf's performance improvements. It ended with GPT-2 training at 114ms per step, thanks to a fused VMFB dispatch containing the forward pass, backward pass, gradient clipping, and the Adam update.
That fusion works because the autodiff engine can produce correct gradients for the entire GPT forward pass. Unfortunately, the original symbolic AD engine, which I discussed in article 3, worked beautifully for XOR but hit a wall on GPT-2.
The limits of symbolic differentiation
Sheaf's autodiff engine, grad_simplified, is a symbolic differentiator. Given
a CompiledExpr tree and a variable name, it returns a new CompiledExpr tree
representing the derivative. This makes the implementation straightforward:
walk the tree, apply the matching derivative rule at each node, and voilà:
d/dx (* x y) = (+ (* 1 y) (* x 0)) -> y
d/dx (sin x) = (cos x)
This is, of course, slow, but IREE optimizes the emitted MLIR, so I saw no red flags when writing the autodiff. Besides, I was somewhat relieved that implementing it wasn't as difficult as I had feared.
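To make the tree-rewriting approach concrete, here is a minimal Python sketch of a symbolic differentiator with a tiny bottom-up simplifier. It is an illustration under simplifying assumptions, not Sheaf's grad_simplified: expressions are plain nested tuples in prefix form (mirroring the s-expressions above) rather than CompiledExpr, only +, * and sin are supported, and the names deriv and simplify are hypothetical.

```python
# Expressions are nested tuples in prefix form, e.g. ("*", "x", "y").
# Leaves are variable names (str) or numeric constants.

def deriv(e, var):
    """Return the symbolic derivative of expression e with respect to var."""
    if isinstance(e, (int, float)):
        return 0
    if isinstance(e, str):
        return 1 if e == var else 0
    op, *args = e
    if op == "+":
        a, b = args
        return ("+", deriv(a, var), deriv(b, var))
    if op == "*":
        a, b = args
        # Product rule: the subtrees a and b are copied into both terms.
        return ("+", ("*", deriv(a, var), b), ("*", a, deriv(b, var)))
    if op == "sin":
        (a,) = args
        # Chain rule: d/dx sin(a) = cos(a) * a'
        return ("*", ("cos", a), deriv(a, var))
    raise ValueError(f"unsupported op: {op}")


def simplify(e):
    """Bottom-up cleanup of trivial cases: x*0, 0*x, x*1, 1*x, x+0, 0+x."""
    if not isinstance(e, tuple):
        return e
    op, *args = e
    args = [simplify(a) for a in args]
    if op == "*":
        a, b = args
        if a == 0 or b == 0:
            return 0
        if a == 1:
            return b
        if b == 1:
            return a
    if op == "+":
        a, b = args
        if a == 0:
            return b
        if b == 0:
            return a
    return (op, *args)


print(simplify(deriv(("*", "x", "y"), "x")))  # y
print(simplify(deriv(("sin", "x"), "x")))     # ('cos', 'x')
```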
For the small XOR MLP network, this worked perfectly. The gradient expression is small, the simplifier cleans up trivial cases, and the result compiles to correct StableHLO.
But GPT-2's forward pass is much larger, and the symbolic differentiator quickly
ran out of memory. After inlining and unrolling the 6 transformer blocks, the
CompiledExpr tree contains roughly 3000 nodes. The problem is that CompiledExpr
is a tree, so each node owns its children.
When the product rule produces:
d/dx (f * g) = f' * g + f * g'
it clones the entire subtrees f and g into both terms. For GPT-2, with
58 parameter leaves across 6 layers of nesting, every product-rule application
duplicates everything below the current node, and the gradient expression grows
exponentially. The process consumed 165 GB of RAM and nearly crashed my machine
before I could kill it.
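The swell is easy to reproduce in a toy setting. The sketch below is hypothetical Python, not Sheaf code: it builds x^(2^k) by repeated squaring, so each level reuses the previous sub-expression, and then differentiates it with a product rule that, like a tree, re-derives every occurrence of a child separately. count_unique counts distinct node objects, i.e. the size the expression would have if sharing were preserved.

```python
# Toy illustration of expression swell (not Sheaf code). Python shares the
# sub-expressions by reference, but deriv_tree treats the input as a tree and
# differentiates every occurrence of a child separately.

def deriv_tree(e, var):
    """Naive product-rule derivative: each child occurrence is re-derived."""
    if isinstance(e, str):
        return 1 if e == var else 0
    op, a, b = e
    assert op == "*", "only products are needed for this demo"
    return ("+", ("*", deriv_tree(a, var), b), ("*", a, deriv_tree(b, var)))


def count_unique(e, seen=None):
    """Count distinct tuple objects reachable from e (the shared-graph size)."""
    if seen is None:
        seen = set()
    if not isinstance(e, tuple) or id(e) in seen:
        return 0
    seen.add(id(e))
    return 1 + sum(count_unique(child, seen) for child in e[1:])


def power_of_two(k):
    """Build x^(2^k) by repeated squaring; each level references the previous one twice."""
    e = "x"
    for _ in range(k):
        e = ("*", e, e)
    return e


for k in (4, 8, 12):
    fwd = power_of_two(k)
    grad = deriv_tree(fwd, "x")
    print(f"k={k:2d}  forward nodes={count_unique(fwd):3d}  gradient nodes={count_unique(grad)}")
```

The forward graph gains one node per level, while the naive derivative creates 3*(2^k - 1) new nodes, so its size roughly doubles with every level: each shared child is differentiated once per reference, which is the tree-cloning behavior described above.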
The limitation here is architectural, and it even has a name: expression swell. It's the reason production AD systems avoid operating on raw syntax trees.
One might ask: why not just JIT-compile the gradient expression too? The answer is that the tree never reaches Sheaf's codegen. Constructing the gradient expression in memory is the bottleneck: the clones and heap allocations happen at the IR level, before any StableHLO is emitted. By the time codegen would see the gradient, the process has already run out of memory.
One way to solve this would be to emit StableHLO directly during the chain rule
traversal, building a "tape" of MLIR operations instead of a CompiledExpr
tree. This is essentially what PyTorch does. But this approach couples the AD
logic with code generation, which creates a monolith. I prefer modular
architectures for long-term maintainability, so Sheaf keeps these separate: AD
is a transformation from IR to IR, and codegen then translates IR to MLIR.
How other frameworks handle this
PyTorch and JAX both avoid expression swell by operating on graphs with sharing, not trees.
PyTorch uses a runtime tape. During the forward pass, each operation records
itself and its backward function. loss.backward() walks the tape in reverse,
computing concrete gradient tensors. Each intermediate value exists once in
memory, so there is no duplication like in a naive symbolic autodiff.
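A runtime tape can be sketched in a few dozen lines of Python. This is a toy scalar version written for illustration, not PyTorch's autograd or its API: every operation records its output on a list together with the local derivatives with respect to its inputs, and backward() replays that list in reverse, accumulating concrete numbers. The example (a + b) * ((a + b) - c) reuses the intermediate a + b, which exists once on the tape and simply receives two gradient contributions.

```python
class Tape:
    """Records every value in execution order."""

    def __init__(self):
        self.nodes = []

    def var(self, value):
        """Create an input (leaf) value on this tape."""
        return Var(self, value, ())


class Var:
    def __init__(self, tape, value, parents):
        self.tape = tape
        self.value = value
        self.parents = parents  # ((input Var, d(self)/d(input)), ...)
        self.grad = 0.0
        tape.nodes.append(self)  # each value is recorded exactly once

    def __add__(self, other):
        return Var(self.tape, self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __sub__(self, other):
        return Var(self.tape, self.value - other.value, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        return Var(self.tape, self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def backward(self):
        """Walk the tape in reverse creation order, accumulating concrete gradients."""
        self.grad = 1.0
        for node in reversed(self.tape.nodes):
            for parent, local in node.parents:
                parent.grad += node.grad * local


tape = Tape()
a, b, c = tape.var(2.0), tape.var(3.0), tape.var(1.0)
t0 = a + b  # recorded once, used twice below
t1 = t0 - c
t2 = t0 * t1
t2.backward()
print(a.grad, b.grad, c.grad)  # 9.0 9.0 -5.0
```

Reverse creation order is a valid processing order because an operation's inputs are always created before its output.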
JAX uses a compiled IR called jaxpr where each variable is defined exactly
once. jax.grad then transforms a forward jaxpr into a backward jaxpr.
Variables are names, not trees, so referencing a variable doesn't copy anything.
What's relevant here is that in both systems, sub-expressions are shared by
reference. When the chain rule needs f in both f' * g and f * g', it
references the same node or register in both places. The gradient then stays
proportional to the size of the forward pass: O(n) instead of O(2^n).
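Sharing can be illustrated with a small toy model. In this hypothetical Python sketch (an illustration, not how PyTorch or JAX are implemented), the derivative is memoized on node identity, so a sub-expression referenced twice is differentiated once, and its derivative is referenced twice instead of being copied. On x^(2^k) built by repeated squaring, the derivative then has 4k - 1 distinct nodes, linear in k.

```python
def power_of_two(k):
    """Build x^(2^k) by repeated squaring; each level references the previous one twice."""
    e = "x"
    for _ in range(k):
        e = ("*", e, e)
    return e


def deriv_shared(e, var, memo=None):
    """Product-rule derivative that differentiates each distinct node only once.

    memo maps id(node) -> derivative, so a child referenced twice yields one
    derivative node that is referenced twice, rather than two copies.
    """
    if memo is None:
        memo = {}
    if isinstance(e, str):
        return 1 if e == var else 0
    if id(e) not in memo:
        op, a, b = e
        assert op == "*", "only products are needed for this demo"
        da, db = deriv_shared(a, var, memo), deriv_shared(b, var, memo)
        memo[id(e)] = ("+", ("*", da, b), ("*", a, db))
    return memo[id(e)]


def count_unique(e, seen=None):
    """Count distinct tuple objects reachable from e (the shared-graph size)."""
    if seen is None:
        seen = set()
    if not isinstance(e, tuple) or id(e) in seen:
        return 0
    seen.add(id(e))
    return 1 + sum(count_unique(child, seen) for child in e[1:])


def evaluate(e, env, memo=None):
    """Evaluate a shared expression, computing each distinct node once."""
    if memo is None:
        memo = {}
    if isinstance(e, (int, float)):
        return e
    if isinstance(e, str):
        return env[e]
    if id(e) not in memo:
        op, a, b = e
        x, y = evaluate(a, env, memo), evaluate(b, env, memo)
        memo[id(e)] = x + y if op == "+" else x * y
    return memo[id(e)]


for k in (4, 8, 12):
    fwd = power_of_two(k)
    grad = deriv_shared(fwd, "x")
    print(f"k={k:2d}  forward nodes={count_unique(fwd)}  gradient nodes={count_unique(grad)}")
```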
Sheaf needed the same property for its autodiff, but I couldn't simply copy PyTorch's tape (Sheaf has no runtime dispatch) or adopt jaxpr (Sheaf has its own IR).
The question was how to get sharing in a tree-structured CompiledExpr.
ANF: naming every sub-expression
I have read a lot about compilers since I started writing Sheaf. The main course I follow is Cornell's CS6120, and one of the topics it covers is Administrative Normal Form (ANF), a standard technique from compiler theory that Kris Micinski summarizes nicely in this presentation.
ANF flattens a nested expression into a sequence of let bindings where every
right-hand side is a simple operation.
Before ANF (tree):
(* (+ a b) (- (+ a b) c))
;;        *
;;      /   \
;;     +     -
;;    / \   / \
;;   a   b +   c   <- (+ a b) duplicated
;;        / \
;;       a   b
After ANF (flat Let chain):
(let [t0 (+ a b)
      t1 (- t0 c)
      t2 (* t0 t1)]
  t2)
;; t0 appears twice, but it's just a name
ANF corresponds closely to SSA form, and therefore to MLIR registers. Each binding
computes one operation and names the result. When the chain rule needs t0 in
two places, it references the same symbol, like JAX and PyTorch do.
I went back to the gptel buffer and soon enough, I had a prototype to_anf
transform function. It walks the tree and lifts every non-trivial sub-expression
into a let binding with a fresh __anf_N name. Trivial expressions (symbols,
literals, GetTupleElement) stay inline. For GPT-2, the tree of ~3000 nodes
becomes a flat sequence of 556 forward bindings, one per operation.
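A to_anf pass can be sketched in Python as follows. This is a hypothetical illustration on nested tuples, not Sheaf's implementation on CompiledExpr: only non-tuple leaves count as trivial (GetTupleElement is not modeled), and names follow the __anf_N scheme. The names dictionary is an added assumption. It reuses an existing binding for a structurally identical sub-expression (a simple form of common-subexpression elimination), which is what lets the repeated (+ a b) in the earlier example collapse into a single t0. Plain ANF conversion alone would bind it twice.

```python
from itertools import count


def to_anf(expr):
    """Flatten a nested prefix-tuple expression into (bindings, result).

    Each binding is (name, op, args) where every arg is trivial: a symbol,
    a literal, or a name bound earlier in the list.
    """
    bindings = []
    names = {}       # (op, trivial args) -> existing name: a simple CSE (assumption)
    fresh = count()  # counter for fresh __anf_N names

    def lift(e):
        if not isinstance(e, tuple):
            return e  # trivial: symbols and literals stay inline
        op, *args = e
        atoms = tuple(lift(a) for a in args)  # operands first, so they are bound earlier
        key = (op, atoms)
        if key not in names:
            names[key] = f"__anf_{next(fresh)}"
            bindings.append((names[key], op, atoms))
        return names[key]

    return bindings, lift(expr)


bindings, result = to_anf(("*", ("+", "a", "b"), ("-", ("+", "a", "b"), "c")))
for name, op, args in bindings:
    print(f"{name} = ({op} {' '.join(map(str, args))})")
print("result:", result)
# __anf_0 = (+ a b)
# __anf_1 = (- __anf_0 c)
# __anf_2 = (* __anf_0 __anf_1)
# result: __anf_2
```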
Reverse-mode AD on ANF
With the expression in ANF, reverse-mode AD is now a linear walk. The algorithm walks the 556 bindings in reverse order, accumulating contributions for each variable:
;; Forward:
;; t0 = a + b, t1 = t0 - c, t2 = t0 * t1
;;
;; Reverse:
;; dt2 = 1.0 (seed)
;; dt1 += dt2 * t0 (from t2 = t0 * t1)
;; dt0 += dt2 * t1 (from t2 = t0 * t1)
;; dt0 += dt1 * 1 (from t1 = t0 - c)
;; dc += dt1 * -1 (from t1 = t0 - c)
;; da += dt0 (from t0 = a + b)
;; db += dt0 (from t0 = a + b)
All gradients are computed in one pass, with O(n) expression size. For GPT-2,
reverse_grad produces 1639 backward bindings from the 556 forward ones.
More importantly, the gradient that previously exhausted 165 GB of RAM now fits in a few megabytes of IR: ANF solved the memory problem.
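The reverse walk can be sketched as an IR-to-IR transform in Python. Everything here is an assumption for illustration, not Sheaf's reverse_grad. Forward bindings are (name, op, (x, y)) triples supporting only +, - and *. Backward bindings get hypothetical __bwd_N names. No simplification is done, so trivial products like (* 1.0 t1) are left in. Because every use of a name appears after its definition, visiting bindings in reverse guarantees that a name's adjoint is complete before it is propagated to that binding's operands. The run helper is a small evaluator used only to check the result.

```python
from itertools import count


def reverse_grad(bindings, output, wrt):
    """Emit backward ANF bindings for d(output)/d(v), for each input v in wrt.

    bindings: forward list of (name, op, (x, y)) with op in {"+", "-", "*"}.
    Returns (backward_bindings, {v: atom holding the gradient of v}).
    """
    fresh = count()
    backward = []
    contribs = {output: [1.0]}  # name -> atoms whose sum is its adjoint; seed = 1.0

    def emit(op, args):
        name = f"__bwd_{next(fresh)}"
        backward.append((name, op, args))
        return name

    def adjoint(name):
        """Sum the accumulated contributions into a single atom (0.0 if unused)."""
        parts = contribs.get(name, [])
        if not parts:
            return 0.0
        acc = parts[0]
        for part in parts[1:]:
            acc = emit("+", (acc, part))
        return acc

    for name, op, (x, y) in reversed(bindings):
        g = adjoint(name)  # complete: all uses of `name` were visited already
        if op == "+":
            local = [(x, g), (y, g)]
        elif op == "-":
            local = [(x, g), (y, emit("*", (g, -1.0)))]
        elif op == "*":
            local = [(x, emit("*", (g, y))), (y, emit("*", (g, x)))]
        else:
            raise ValueError(f"unsupported op: {op}")
        for arg, contrib in local:
            if isinstance(arg, str):  # literals have no adjoint
                contribs.setdefault(arg, []).append(contrib)

    return backward, {v: adjoint(v) for v in wrt}


def run(bindings, env):
    """Evaluate ANF bindings in order, returning env extended with every name."""
    env = dict(env)
    ops = {"+": lambda a, b: a + b, "-": lambda a, b: a - b, "*": lambda a, b: a * b}

    def val(atom):
        return env[atom] if isinstance(atom, str) else atom

    for name, op, (x, y) in bindings:
        env[name] = ops[op](val(x), val(y))
    return env


# t0 = a + b, t1 = t0 - c, t2 = t0 * t1
forward = [("t0", "+", ("a", "b")), ("t1", "-", ("t0", "c")), ("t2", "*", ("t0", "t1"))]
backward, grads = reverse_grad(forward, "t2", ["a", "b", "c"])
env = run(forward + backward, {"a": 2.0, "b": 3.0, "c": 1.0})
print({v: env[g] if isinstance(g, str) else g for v, g in grads.items()})
# {'a': 9.0, 'b': 9.0, 'c': -5.0}
```

For a = 2, b = 3, c = 1 this prints the analytic gradients of (a + b) * ((a + b) - c): 2(a + b) - c = 9 for both a and b, and -(a + b) = -5 for c.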
The symbolic engine is staying
I had hoped to get rid of the symbolic code entirely, but the JIT can't compile every VAG call, since some are effectful.
So the symbolic engine still serves as the interpreter fallback. When the JIT
can't compile a VAG call, the interpreter runs grad_simplified on the inlined
expression. For small functions, this works fine. For anything larger, the JIT
with ANF-based reverse AD is the only practical option, not because it's faster,
but because the symbolic engine literally cannot produce a gradient expression
that fits in memory.
What's next
The ANF-based reverse AD handles GPT-2. The next frontier is data-dependent
control flow: conditional branching, dynamic shapes, variable-length sequences.
reverse_grad currently does not handle If or While, a limitation reminiscent of
JAX, where jit-compiled code must use jax.lax.cond instead of a Python if on traced values.
Sheaf's where already has a VJP, and most ML code can be written without
branching. But supporting If would extend the AD to a wider class of programs
without requiring users to restructure their code.
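Part of why where is the easy case is that it is pure data flow: both inputs are already computed, and its VJP just routes the incoming cotangent to whichever input was selected at each position. Here is a hypothetical pure-Python sketch over lists, using a branch-free ReLU as the example (Sheaf's actual where VJP operates on tensors and may differ):

```python
def where(cond, x, y):
    """Elementwise select: x[i] where cond[i] is true, else y[i]."""
    return [xi if c else yi for c, xi, yi in zip(cond, x, y)]


def where_vjp(cond, g):
    """Pull back the cotangent g of where(cond, x, y) to (dx, dy).

    The boolean condition is not differentiable, so it gets no cotangent.
    """
    dx = [gi if c else 0.0 for c, gi in zip(cond, g)]
    dy = [0.0 if c else gi for c, gi in zip(cond, g)]
    return dx, dy


# relu(x) written without branching: where(x > 0, x, 0)
x = [-1.0, 2.0, 0.5]
cond = [xi > 0 for xi in x]
out = where(cond, x, [0.0] * len(x))
dx, _ = where_vjp(cond, [1.0, 1.0, 1.0])
print(out, dx)  # [0.0, 2.0, 0.5] [0.0, 1.0, 1.0]
```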
Additionally, I still routinely hit bugs and panics in the autodiff. Sheaf's development has become much more complex as the project has grown, and fixing bugs requires a lot of thinking and a lot of testing; my OpenRouter tokens can only help so much. The benchmark and regression test suites have evolved to reflect this.