Ditching Enzyme: A Case for Symbolic Autodiff in Sheaf V2
Feb 15th 2026
The Rust compiler for Sheaf is progressing well. The parser handles S-expressions correctly, and the first StableHLO emission works. Pattern matching in Rust proved to be the right choice: exhaustiveness checking catches missing cases at compile time.
Before going further, I need to validate the autodiff strategy. The V2 architecture targets IREE, which means I need gradients in StableHLO/Linalg, not through JAX. The obvious candidate was Enzyme-MLIR.
The Enzyme Detour
Enzyme is a compile-time autodiff engine that operates on LLVM IR and MLIR. It checks every box for Sheaf: MLIR-native, designed for ahead-of-time compilation, and supposedly lightweight.
I spent time compiling it from source, fixing CMake issues to generate a shared library, and eventually got an 80MB libEnzymeMLIR-23.dylib.
The problem became clear when trying to integrate it with IREE: IREE doesn't support dynamic plugin loading. To use Enzyme, IREE would need to be rebuilt with Enzyme statically linked, then included in Sheaf. Embedding such a fast-moving project inside Sheaf would turn it into a very large and tightly coupled monolith, with long build times and high maintenance cost. For a project aiming to minimize "bloat" and stick to "do one thing and do it well", this is an anti-pattern.
I briefly considered alternatives: jax.grad, torch.autograd and the TinyGrad autodiff all bring Python back into the compilation pipeline, negating the entire point of V2.
The Minimal Path
Facing these limitations, the question became: do I need such a large code base to perform differentiation? What does Sheaf actually need?
Enzyme is designed to differentiate arbitrary C++ code with pointers, aliasing, and side effects. Sheaf has none of that. The code is purely functional, immutable, and operates on a small set of mathematical operations: matrix multiplication, element-wise arithmetic, activations, and reductions.
Besides, I don't need an autodiff engine that optimizes: I need one that generates mathematically correct gradients. IREE will optimize afterward through kernel fusion and vectorization.
This suggested a different approach: symbolic differentiation on the AST, before StableHLO emission. Since Sheaf code is data (S-expressions), I can transform it directly using standard calculus rules.
This also gives me an advantage compared to Enzyme: operating straight on the semantic structure of the program (the AST), not on the compiled MLIR. Enzyme works on LLVM IR, where it must reconstruct intent from low-level operations. In Sheaf, the intent is explicit in the S-expression tree. A multiplication is (* a b), not a sequence of pointer dereferences and SIMD instructions.
Having an internal, standalone autograd also solves one of the main architectural issues of V2: differentiation in the interpreter, which would otherwise have depended on Enzyme.
Proof of Concept in Python
To validate this, I built a minimal autodiff in Python. Start with a simple expression:
x = Var("x")
y = Var("y")
expr = Add(Mul(x, y), x) # f(x,y) = x*y + x
Apply the chain rule symbolically:
df_dx = grad(expr, "x") # => y + 1
df_dy = grad(expr, "y") # => x
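A minimal sketch of what such a symbolic grad can look like (hypothetical class and function names; the real POC differs in detail). Each AST node is differentiated by the standard calculus rules, with the product rule handling Mul:

```python
# Minimal symbolic autodiff sketch (illustrative, not the actual Sheaf POC).
from dataclasses import dataclass

class Expr: ...

@dataclass
class Var(Expr):
    name: str

@dataclass
class Const(Expr):
    value: float

@dataclass
class Add(Expr):
    left: Expr
    right: Expr

@dataclass
class Mul(Expr):
    left: Expr
    right: Expr

def grad(e: Expr, wrt: str) -> Expr:
    """Differentiate e symbolically with respect to the variable named wrt."""
    if isinstance(e, Var):
        return Const(1.0 if e.name == wrt else 0.0)
    if isinstance(e, Const):
        return Const(0.0)
    if isinstance(e, Add):  # sum rule: (u + v)' = u' + v'
        return Add(grad(e.left, wrt), grad(e.right, wrt))
    if isinstance(e, Mul):  # product rule: (u * v)' = u'v + uv'
        return Add(Mul(grad(e.left, wrt), e.right),
                   Mul(e.left, grad(e.right, wrt)))
    raise TypeError(f"unsupported node: {e!r}")

def evaluate(e: Expr, env: dict) -> float:
    """Evaluate an expression tree against a variable environment."""
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Const):
        return e.value
    if isinstance(e, Add):
        return evaluate(e.left, env) + evaluate(e.right, env)
    if isinstance(e, Mul):
        return evaluate(e.left, env) * evaluate(e.right, env)
    raise TypeError(f"unsupported node: {e!r}")
```

The output of grad is not simplified (it contains terms like x*0), but it is mathematically correct, and IREE's optimizer folds the dead terms away after emission.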
Generate StableHLO:
func.func @forward(%x: tensor<f32>, %y: tensor<f32>) -> tensor<f32> {
%0 = stablehlo.multiply %x, %y : tensor<f32>
%1 = stablehlo.add %0, %x : tensor<f32>
return %1 : tensor<f32>
}
func.func @grad_x(%x: tensor<f32>, %y: tensor<f32>) -> tensor<f32> {
%c1 = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.add %y, %c1 : tensor<f32>
return %0 : tensor<f32>
}
Compile with IREE, execute, verify: df/dx(2,3) = 4, df/dy(2,3) = 2. Correct.
Scaling to a Real World Example
Validating the engine requires more than differentiating a simple expression, so I moved to the XOR MLP (2->8->1 architecture, matching the existing Sheaf "Hello world" example). This requires:
- Matrix multiplication with proper transpose for gradients
- Bias broadcast ([8] -> [4×8]) and reduction ([4×8] -> [8]) for gradient accumulation
- Activations (ReLU, Sigmoid) with their derivatives
- MSE loss with mean reduction
- Value-and-grad returning both loss and all parameter gradients
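The matmul and bias rules above are the ones worth getting right: the weight gradient needs a transposed operand, and the bias gradient needs a reduction over the batch dimension. A numpy reference sketch for a single dense layer (y = relu(x @ W + b) into an MSE loss; this is an assumption-laden illustration, not the Sheaf emitter itself):

```python
# Reference gradients for one dense layer, checked against finite differences.
import numpy as np

def forward(x, W, b):
    z = x @ W + b              # bias broadcast: [B,H] = [B,H] + [H]
    return np.maximum(z, 0.0), z

def mse(pred, target):
    return ((pred - target) ** 2).mean()

def backward(x, W, b, target):
    a, z = forward(x, W, b)
    d_a = 2.0 * (a - target) / a.size   # d(MSE)/d(pred), mean reduction
    d_z = d_a * (z > 0)                 # ReLU derivative: pass-through where z > 0
    dW = x.T @ d_z                      # transpose for the weight gradient
    db = d_z.sum(axis=0)                # reduce [B,H] -> [H] for the bias
    return dW, db
```

In StableHLO terms, the broadcast becomes stablehlo.broadcast_in_dim in the forward pass and a stablehlo.reduce over the batch axis in the backward pass.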
The implementation grew to about 400 lines of Python: AST nodes, shape inference, gradient rules, and StableHLO emission with broadcast handling through stablehlo.broadcast_in_dim.
For the XOR MLP, this generates a single train_step function that computes:
(loss, dL/dW1, dL/db1, dL/dW2, dL/db2) = train_step(x, W1, b1, W2, b2, y)
This compiles to an 18KB VMFB, but more importantly, it outputs the correct shapes and values.
The Full Training Loop
The next step is to compile the entire training loop into the VMFB, not just a hard-coded single pass.
Instead of calling train_step 500 times from Python, this requires generating MLIR with scf.for that contains forward, backward, and stochastic gradient descent updates:
func.func @train(%x, %y, %W1_init, %b1_init, %W2_init, %b2_init, %lr, %epochs) {
%final = scf.for %epoch = 0 to %epochs {
%loss, %grads = call @train_step(...)
%params_new = sgd_update(%params, %grads, %lr)
scf.yield %params_new
}
return %final
}
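For reference, what the scf.for loop computes is equivalent to the following numpy sketch (illustrative only; the real loop runs entirely inside the VMFB with no Python in sight):

```python
# Numpy reference for the compiled training loop: forward, backward, SGD update.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train(x, y, W1, b1, W2, b2, lr, epochs):
    for _ in range(epochs):
        # forward: 2 -> 8 (ReLU) -> 1 (sigmoid)
        z1 = x @ W1 + b1
        a1 = np.maximum(z1, 0.0)
        pred = sigmoid(a1 @ W2 + b2)
        # backward: MSE loss, chain rule through sigmoid and ReLU
        d_pred = 2.0 * (pred - y) / y.size
        d_z2 = d_pred * pred * (1.0 - pred)
        dW2 = a1.T @ d_z2
        db2 = d_z2.sum(axis=0)
        d_z1 = (d_z2 @ W2.T) * (z1 > 0)
        dW1 = x.T @ d_z1
        db1 = d_z1.sum(axis=0)
        # SGD update, carried to the next iteration like scf.yield
        W1 -= lr * dW1; b1 -= lr * db1
        W2 -= lr * dW2; b2 -= lr * db2
    return W1, b1, W2, b2
```

The parameters play the role of scf.for's loop-carried values: each iteration yields the updated weights to the next.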
The interface is not yet user-friendly, but this can already be called "on-device training":
~/standalone_train$ ll -h
total 1720
-rwxr-xr-x 1 damien staff 836K Feb 9 01:33 iree-run-module*
-rw-r--r-- 1 damien staff 18K Feb 10 18:41 training_loop.vmfb
~/standalone_train$ ./iree-run-module \
--module=training_loop.vmfb \
--device=local-task \
--function=train \
--input="4x2xf32=0,0,0,1,1,0,1,1" \
--input="4x1xf32=0,1,1,0" \
--input="2x8xf32=0.05723,-0.00872,0.04017,0.09566,-0.01383,0.06476,-0.03770,0.04296,0.06901,-0.09234,0.03320,0.07782,-0.04273,0.00332,0.00073,0.00774" \
--input="8xf32=0,0,0,0,0,0,0,0" \
--input="8x1xf32=-0.13197,0.06499,0.00506,-0.04559,0.09581,-0.05592,0.04567,0.02028" \
--input="1xf32=0" \
--input="f32=0.7" \
--input="500"
EXEC @train
result[0]: hal.buffer_view
2x8xf32=[0.0277724 0.0014524 0.0427165 0.0880606 0.00132029 0.0561985 -0.0111438 0.0482478][0.0442278 -0.0827338 0.0352603 0.0713262 -0.0288227 -0.00330342 0.0230232 0.012129]
result[1]: hal.buffer_view
8xf32=0.000628438 -0.000299154 -0.000169139 -5.71229E-05 -0.000200321 -0.000927506 -0.00394388 -0.000171253
result[2]: hal.buffer_view
8x1xf32=[-0.157072][0.206765][0.0120207][-0.0359833][0.224716][0.187995][0.695675][0.0252195]
result[3]: hal.buffer_view
1xf32=-0.00244113
result[4]: hal.buffer_view
f32=0.997798
I am happy with this result. The compiled training loop is 18KB, the entire standalone package (runtime + VMFB) fits in 850KB, and the network output converges correctly (final prediction ~= 0.998).
Of course, this is still a proof of concept, but it validates that a custom differentiation engine is feasible.
Performance
Since XLA is highly optimized and probably uses all kinds of performance tricks, I didn't expect much from my naive autograd.
I ran a 500-epoch benchmark on the MLP:
- Sheaf V1 (Python loop + JAX): 1.84 seconds for 500 epochs.
- Sheaf V2 (symbolic autodiff + compiled loop): 0.036 seconds for 500 epochs.
In this specific setup, V2 runs 51× faster than the V1 configuration. This isn't that surprising: V2 compiles the entire training loop ahead of time, while V1 pays Python-to-JAX dispatch latency on each of its 500 calls. Still, it demonstrates how much overhead AOT compilation eliminates.
Conclusion
This POC validates the approach for simple networks. Having a standalone differentiation engine will allow me to keep the SDK size minimal with no LLVM dependency in the distributed package, and no 80MB autodiff plugin. However, several questions remain open.
- The current implementation handles standard operations (matmul, activations, element-wise ops) on static shapes. Control flow dependent on runtime values, non-differentiable operations, and custom gradient definitions are not yet supported.
- I use reverse-mode autodiff, which emits explicit gradient functions for each parameter. Gradients are computed directly from the symbolic AST and emitted as StableHLO, with no gradient checkpointing or recomputation strategies for memory efficiency yet. This works at POC scale, but larger models will require checkpointing.
- Testing on XOR validates correctness but does not demonstrate how the approach scales to deeper architectures. BareGPT (the full Sheaf generative transformer example) will be the real test, and will also let me exercise sharding with all-reduce, which is my next POC.
The Rust implementation will need to address these constraints as it extends to the full Sheaf feature set.