Ditching Enzyme: A Case for Symbolic Autodiff in Sheaf V2

Feb 15th 2026

The Rust compiler for Sheaf is progressing well. The parser handles S-expressions correctly, and the first StableHLO emission works. Pattern matching in Rust proved to be the right choice: it doesn't prevent panics, but exhaustiveness checking catches missing cases at compile time.

Before going further, I need to validate the autodiff strategy. The V2 architecture targets IREE, which means I need gradients in StableHLO/Linalg, not through JAX. As I mentioned in the first blog post, the obvious candidate was Enzyme-MLIR, because it provided a convenient black-box for the autodiff.

The Enzyme Detour

Enzyme is a compile-time autodiff engine that operates on LLVM IR and MLIR. On paper, it checks every box for Sheaf: MLIR-native, designed for ahead-of-time compilation, and supposedly lightweight.

I spent quite some time compiling it from source, fixing CMake issues to generate a shared library on macOS, and eventually got a whopping "lightweight" 80MB libEnzymeMLIR-23.dylib.

The problem became clear when trying to integrate it with IREE: IREE doesn't support dynamic plugin loading. To use Enzyme, IREE would need to be rebuilt with Enzyme statically linked, then included in Sheaf. Embedding such a fast-moving project inside Sheaf would turn it into a very large and tightly coupled monolith, with long build times and high maintenance cost.

For a project that aims to minimize bloat and to do one thing well, this is an anti-pattern.

I briefly considered alternatives: jax.grad, torch.autograd, and the TinyGrad autodiff, but they all bring Python back into the compilation pipeline, negating the entire point of V2. Python shall be no more.

The Minimal Path

Facing these limitations, the question became: do I need such a large code base to perform differentiation? What does Sheaf actually need?

Enzyme is designed to differentiate arbitrary C++ code with pointers, aliasing, and side effects. Sheaf has none of that. The code is purely functional, immutable, and operates on a small set of mathematical operations: matrix multiplication, element-wise arithmetic, activations, and reductions.

Besides, I don't need an optimizing autodiff engine: I need one that generates mathematically correct gradients. IREE will optimize the MLIR afterward through DCE, kernel fusion, and vectorization.

This suggested a different approach: perform a simple (albeit very slow) symbolic differentiation on the AST before StableHLO emission, then let IREE optimize the mess. Since Sheaf code is data (S-expressions), I can transform it directly using standard calculus rules, which makes for a fairly straightforward implementation.

This also gives me an advantage compared to Enzyme: operating straight on the semantic structure of the program (the AST), not on the compiled MLIR. Enzyme works on LLVM IR, where it must reconstruct intent from low-level operations. In Sheaf, the intent is explicit in the S-expression tree. A multiplication is (* a b), not a sequence of pointer dereferences and SIMD instructions, which means I don't have to be an LLVM or XLA engineer to understand and debug it.

Having an internal, standalone autograd also solves one of the main architectural issues of V2: differentiation in the interpreter, which would otherwise have depended on Enzyme.

Proof of Concept in Python

To validate this, I built a minimal autodiff in Python, which reminded me of my old Andrew Ng course. Start with a simple expression, a product:

x = Var("x")
y = Var("y")
expr = Add(Mul(x, y), x)  # f(x,y) = x*y + x

Then differentiate it:

df_dx = grad(expr, "x")  # => y + 1
df_dy = grad(expr, "y")  # => x
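To make the mechanism concrete, here is a minimal sketch of what such AST classes and a grad function could look like. This is a hypothetical reconstruction of the POC's surface API, not the actual source; the Const node and evaluate helper are my own additions for testing:

```python
from dataclasses import dataclass

# Hypothetical reconstruction of the POC's AST nodes (not the actual source).
@dataclass
class Var:
    name: str

@dataclass
class Const:
    value: float

@dataclass
class Add:
    lhs: object
    rhs: object

@dataclass
class Mul:
    lhs: object
    rhs: object

def grad(expr, wrt):
    """Symbolic derivative of `expr` with respect to the variable `wrt`."""
    if isinstance(expr, Const):
        return Const(0.0)
    if isinstance(expr, Var):
        return Const(1.0) if expr.name == wrt else Const(0.0)
    if isinstance(expr, Add):        # sum rule: (f + g)' = f' + g'
        return Add(grad(expr.lhs, wrt), grad(expr.rhs, wrt))
    if isinstance(expr, Mul):        # product rule: (f * g)' = f'g + fg'
        return Add(Mul(grad(expr.lhs, wrt), expr.rhs),
                   Mul(expr.lhs, grad(expr.rhs, wrt)))
    raise TypeError(f"cannot differentiate {expr!r}")

def evaluate(expr, env):
    """Evaluate an expression tree under a variable assignment."""
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Add):
        return evaluate(expr.lhs, env) + evaluate(expr.rhs, env)
    if isinstance(expr, Mul):
        return evaluate(expr.lhs, env) * evaluate(expr.rhs, env)

# f(x, y) = x*y + x
x, y = Var("x"), Var("y")
expr = Add(Mul(x, y), x)
env = {"x": 2.0, "y": 3.0}
```

The derivative trees come out unsimplified (y + 1 appears as (1*y + x*0) + 1), which is exactly the "mess" that IREE's DCE and constant folding clean up later.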

Then grep the StableHLO specification to generate something that IREE can run:

func.func @forward(%x: tensor<f32>, %y: tensor<f32>) -> tensor<f32> {
  %0 = stablehlo.multiply %x, %y : tensor<f32>
  %1 = stablehlo.add %0, %x : tensor<f32>
  return %1 : tensor<f32>
}

func.func @grad_x(%x: tensor<f32>, %y: tensor<f32>) -> tensor<f32> {
  %c1 = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.add %y, %c1 : tensor<f32>
  return %0 : tensor<f32>
}

Compile with IREE, execute, verify: df/dx(2,3) = 4, df/dy(2,3) = 2, and voilà! My first StableHLO grad, very far from a generic implementation, but already gratifying.
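A cheap way to cross-check symbolic gradients like these is a central finite difference. A quick sketch (the helper and epsilon are my own, not part of the POC):

```python
def finite_diff(f, args, i, eps=1e-5):
    """Central finite difference of f with respect to argument i."""
    hi = list(args); hi[i] += eps
    lo = list(args); lo[i] -= eps
    return (f(*hi) - f(*lo)) / (2 * eps)

f = lambda x, y: x * y + x            # the forward function from the example
dfdx = finite_diff(f, [2.0, 3.0], 0)  # should match the symbolic y + 1 = 4
dfdy = finite_diff(f, [2.0, 3.0], 1)  # should match the symbolic x = 2
```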

Scaling to a Real World Example

Validating the engine requires more than differentiating a simple expression, so I moved on to the simple XOR MLP Sheaf example to see whether I could differentiate it with my rudimentary autograd. This requires differentiating a lot more operations:

  • Matrix multiplication with proper transpose for gradients
  • Bias broadcast ([8] -> [4×8]) and reduction ([4×8] -> [8]) for gradient accumulation
  • Activations (ReLU, Sigmoid) with their derivatives
  • MSE loss with mean reduction
  • Value-and-grad returning both loss and all parameter gradients
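The first two rules boil down to a few lines of linear algebra. A NumPy sketch of the standard reverse-mode rules (the math the emitter encodes, not Sheaf's actual code), including a finite-difference spot check on one weight:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 2))    # batch of 4 inputs
W = rng.normal(size=(2, 8))    # weight matrix
b = rng.normal(size=(8,))      # bias, broadcast over the batch

# Forward: Y = X @ W + b  (b broadcast [8] -> [4x8])
Y = X @ W + b
dY = rng.normal(size=Y.shape)  # upstream gradient dL/dY

# Reverse-mode rules: matmul needs transposes, broadcast needs a reduction
dX = dY @ W.T                  # dL/dX = dL/dY @ W^T
dW = X.T @ dY                  # dL/dW = X^T @ dL/dY
db = dY.sum(axis=0)            # broadcast in forward => sum in backward

# Spot check dW[0, 0] against a finite difference of L = sum(Y * dY)
eps = 1e-6
Wp = W.copy(); Wp[0, 0] += eps
num = ((X @ Wp + b) * dY).sum() - ((X @ W + b) * dY).sum()
assert abs(num / eps - dW[0, 0]) < 1e-4
```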

The implementation grew to about 400 lines of Python: AST nodes, shape inference, gradient rules, and StableHLO emission with broadcast handling through stablehlo.broadcast_in_dim. Claude helped write the boilerplate, but I'm okay relying on AI assistance to delegate the execution (not the thinking).

For the XOR MLP, this generates a single train_step function that computes:

(loss, dL/dW1, dL/db1, dL/dW2, dL/db2) = train_step(x, W1, b1, W2, b2, y)

This compiles to an 18KB VMFB, but more importantly, it outputs the correct shapes and values.

The Full Training Loop

The next step is to compile the entire training loop into the VMFB, not just a hard-coded single pass. Instead of calling train_step 500 times from Python, this means generating MLIR with an scf.for loop whose body contains the forward pass, the backward pass, and the stochastic gradient descent updates (Adam will wait):

func.func @train(%x, %y, %W1_init, %b1_init, %W2_init, %b2_init, %lr, %epochs) {
  %final = scf.for %epoch = %c0 to %epochs step %c1
      iter_args(%params = %params_init) {
    %loss, %grads = call @train_step(%x, %y, %params)
    %params_new = sgd_update(%params, %grads, %lr)
    scf.yield %params_new
  }
  return %final
}
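For reference, what the compiled loop computes is equivalent to this NumPy sketch: forward pass, the symbolic gradients, and an SGD update, iterated epoch by epoch. This is illustrative, not the generated code; the learning rate and epoch count mirror the run below, but the initialization is my own:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(x, W1, b1, W2, b2):
    h = np.maximum(x @ W1 + b1, 0.0)      # hidden layer, ReLU
    return h, sigmoid(h @ W2 + b2)        # output layer, Sigmoid

def train(x, y, W1, b1, W2, b2, lr=0.7, epochs=500):
    n = x.shape[0]
    for _ in range(epochs):
        # Forward pass
        z1 = x @ W1 + b1
        h = np.maximum(z1, 0.0)
        p = sigmoid(h @ W2 + b2)
        # Backward pass for MSE loss mean((p - y)^2)
        dz2 = (2.0 * (p - y) / n) * p * (1.0 - p)   # sigmoid derivative
        dW2 = h.T @ dz2
        db2 = dz2.sum(axis=0)                        # reduce [4x1] -> [1]
        dz1 = (dz2 @ W2.T) * (z1 > 0)                # ReLU derivative
        dW1 = x.T @ dz1
        db1 = dz1.sum(axis=0)                        # reduce [4x8] -> [8]
        # SGD update (the sgd_update step in the loop body)
        W1 -= lr * dW1; b1 -= lr * db1
        W2 -= lr * dW2; b2 -= lr * db2
    return W1, b1, W2, b2

x = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = np.array([[0.], [1.], [1.], [0.]])
rng = np.random.default_rng(0)
W1 = 0.1 * rng.standard_normal((2, 8)); b1 = np.zeros(8)
W2 = 0.1 * rng.standard_normal((8, 1)); b2 = np.zeros(1)

loss_before = float(np.mean((forward(x, W1, b1, W2, b2)[1] - y) ** 2))
W1, b1, W2, b2 = train(x, y, W1, b1, W2, b2)
loss_after = float(np.mean((forward(x, W1, b1, W2, b2)[1] - y) ** 2))
```

The point of the MLIR version is that all of this, including the loop itself, lives inside the VMFB, so the host never dispatches per-epoch.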

The interface is not yet user-friendly, but I can proudly watch my own standalone autograd perform the differentiation:

~/standalone_train$ ll -h
total 1720
-rwxr-xr-x  1 damien  staff   836K Feb  9 01:33 iree-run-module*
-rw-r--r--  1 damien  staff    18K Feb 10 18:41 training_loop.vmfb
~/standalone_train$ ./iree-run-module \
  --module=training_loop.vmfb \
  --device=local-task \
  --function=train \
  --input="4x2xf32=0,0,0,1,1,0,1,1" \
  --input="4x1xf32=0,1,1,0" \
  --input="2x8xf32=0.05723,-0.00872,0.04017,0.09566,-0.01383,0.06476,-0.03770,0.04296,0.06901,-0.09234,0.03320,0.07782,-0.04273,0.00332,0.00073,0.00774" \
  --input="8xf32=0,0,0,0,0,0,0,0" \
  --input="8x1xf32=-0.13197,0.06499,0.00506,-0.04559,0.09581,-0.05592,0.04567,0.02028" \
  --input="1xf32=0" \
  --input="f32=0.7" \
  --input="500"
EXEC @train
result[0]: hal.buffer_view
2x8xf32=[0.0277724 0.0014524 0.0427165 0.0880606 0.00132029 0.0561985 -0.0111438 0.0482478][0.0442278 -0.0827338 0.0352603 0.0713262 -0.0288227 -0.00330342 0.0230232 0.012129]
result[1]: hal.buffer_view
8xf32=0.000628438 -0.000299154 -0.000169139 -5.71229E-05 -0.000200321 -0.000927506 -0.00394388 -0.000171253
result[2]: hal.buffer_view
8x1xf32=[-0.157072][0.206765][0.0120207][-0.0359833][0.224716][0.187995][0.695675][0.0252195]
result[3]: hal.buffer_view
1xf32=-0.00244113
result[4]: hal.buffer_view
f32=0.997798

I am happy with this result. The compiled training loop is 18KB, the entire standalone package (runtime + VMFB) fits in 850KB, and the network output converges correctly (final prediction ~= 0.998).

Of course, this is still a proof of concept, but it validates that a custom differentiation engine is feasible.

Performance

Since XLA is highly optimized and probably uses all kinds of performance tricks, I didn't expect much from my naive autograd.

I ran a 500-epoch benchmark on the MLP:

  • Sheaf V1 (Python loop + JAX): 1.84 seconds for 500 epochs.
  • Sheaf V2 (symbolic autodiff + compiled loop): 0.036 seconds for 500 epochs.

Wait, 51x faster? There is simply no way my naive autograd should even be close to JAX. However, this isn't that surprising: V2 compiles the entire training loop ahead of time, while V1 pays a huge dispatch latency by calling JAX from Python 500 times. Once I add a JIT profiler to monitor the cache misses and the overhead, I will know whether this actually scales. Right now, it already demonstrates the overhead eliminated by AOT compilation.

Conclusion

This POC validates the approach for simple models. Having a standalone differentiation engine will allow me to keep the SDK size minimal with no LLVM dependency in the distributed package, and no 80MB autodiff plugin.

However, several questions remain open.

  • The current implementation handles standard operations (matmul, activations, element-wise ops) on static shapes. Control flow dependent on runtime values, non-differentiable operations, and custom gradient definitions are not yet supported.

  • I use reverse-mode autodiff, which emits explicit gradient functions for each parameter. The implementation does not yet include gradient checkpointing or recomputation strategies for memory efficiency. Gradients are computed directly from the symbolic AST and emitted as StableHLO. This works for POC scale, but larger models will require checkpointing.

  • Testing on XOR validates correctness but does not demonstrate how the approach scales to deeper architectures. BareGPT (the full Sheaf generative transformer example) will be the real test, and will also allow me to test sharding with all-reduce, which is my next POC.

The Rust implementation will need to address these constraints as it extends to the full Sheaf feature set.