Catching Up to PyTorch's Performance
Mar 11th 2026
As Sheaf V2 moves closer to release, it's time to test it on larger models than BareGPT, my toy generative transformer.
Last week, I ported Karpathy's NanoGPT to Sheaf and benchmarked it against PyTorch. The first results were humbling:
| Configuration | Wall time | tk/s | vs PyTorch |
|---|---|---|---|
| Interpreter only (no IREE) | ~90s | ~1.1 | 108× slower |
| Sheaf build (AOT) | 7.64s | ~13 | 9.2× slower |
| PyTorch (CPU, reference) | 0.83s | ~120 | baseline |
I had been confident that IREE and LLVM would handle the heavy lifting, and had mostly focused on the language and the framework itself. These numbers were a wake-up call. They triggered a shift from design work to performance tuning, which is still ongoing as I write this. Spoiler: Sheaf's performance is now a lot closer to PyTorch.
New tooling
To start, I added two features to the CLI:
- --blame, inspired by systemd-analyze blame, which profiles Sheaf execution and shows where time is spent
- SHEAF_JIT_PROFILE=1, which instruments the IREE dispatch path: buffer marshalling, GPU call, and output transfer
Some issues were immediately visible:
- Too few functions were being compiled, because side-effects and higher-order functions blocked AOT compilation
- Interpreter performance on a real model (hundreds of operations per forward pass) was so poor that the dual interpreted + AOT mode had to be replaced with a transparent JIT compiler, reusing the AOT logic
How Sheaf's compiler works
Before diving into the optimizations, a brief overview of Sheaf's compilation pipeline.
1. Parser. Sheaf source is transformed into an AST of S-expressions. Unlike a "pure" Lisp such as Scheme, which is its own AST, Sheaf has syntactic sugar (threading macros, destructuring, keyword args), so we need a parser to normalize everything into a uniform tree:
(defn gpt-forward [x params config] ...)
↓
SheafValue::List([Symbol("defn"), Symbol("gpt-forward"), ...])
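The parser stage can be illustrated with a minimal sketch: tokenize an S-expression and build a nested value tree. This is a hypothetical simplification, not Sheaf's actual parser (which also handles brackets, threading macros, destructuring, and keyword arguments).

```rust
// Minimal sketch of the parser stage: tokenize an S-expression and
// build a nested SheafValue tree.
#[derive(Debug, PartialEq)]
pub enum SheafValue {
    Symbol(String),
    List(Vec<SheafValue>),
}

pub fn parse(src: &str) -> SheafValue {
    // Pad parentheses with spaces so a whitespace split yields tokens.
    let spaced = src.replace('(', " ( ").replace(')', " ) ");
    let mut tokens = spaced.split_whitespace().peekable();
    parse_tokens(&mut tokens)
}

fn parse_tokens<'a, I: Iterator<Item = &'a str>>(
    tokens: &mut std::iter::Peekable<I>,
) -> SheafValue {
    match tokens.next().expect("unexpected end of input") {
        "(" => {
            let mut items = Vec::new();
            while *tokens.peek().expect("unclosed list") != ")" {
                items.push(parse_tokens(tokens));
            }
            tokens.next(); // consume the closing ")"
            SheafValue::List(items)
        }
        tok => SheafValue::Symbol(tok.to_string()),
    }
}
```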
2. Compiler. The AST is then lowered into CompiledExpr, an internal IR. This stage expands macros, recognizes special forms (defn, let, if, fn), resolves symbols, and registers functions in the compiler registry. The result is a tree of typed enum variants:
enum CompiledExpr {
    Integer(i64),
    Float(f64),
    Symbol(String),
    FunctionCall(String, Vec<CompiledExpr>),
    Let { bindings, body },
    Lambda { params, body },
    If { cond, then, else_ },
    GetTupleElement { param, indices },
    Vector(Vec<CompiledExpr>),
    Keyword(String),
    ...
}
3. Transforms. Several passes rewrite the CompiledExpr tree before codegen: dictionary accesses become positional tuple lookups, known function calls are inlined, reduce loops are unrolled into chains of Let bindings (required for autodiff), and constant subexpressions are evaluated at compile time.
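One of these passes, constant folding, can be sketched over a trimmed-down expression enum. This is a hypothetical illustration with a single `Add` node; the real passes also handle inlining and loop unrolling.

```rust
// Sketch of a transform pass: constant folding over a simplified
// expression tree. Subtrees whose operands reduce to constants are
// evaluated at compile time.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Integer(i64),
    Symbol(String),
    Add(Box<Expr>, Box<Expr>),
}

pub fn fold(e: Expr) -> Expr {
    match e {
        Expr::Add(a, b) => match (fold(*a), fold(*b)) {
            // Both sides are constants: evaluate now.
            (Expr::Integer(x), Expr::Integer(y)) => Expr::Integer(x + y),
            // Otherwise keep the (partially folded) node.
            (x, y) => Expr::Add(Box::new(x), Box::new(y)),
        },
        other => other,
    }
}
```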
4. Codegen. Each CompiledExpr node maps to one or more StableHLO operations. The codegen also computes output shapes for each tensor operation, threading shape information through the entire graph.
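Shape threading can be sketched for a single operation. Assuming a rank-2 matmul (the real codegen handles arbitrary ranks and broadcasting), the output shape is derived from the input shapes and can be checked before any StableHLO is emitted:

```rust
// Sketch of codegen-time shape inference: compute the output shape of
// a rank-2 matmul so downstream ops can be emitted with static shapes.
pub fn matmul_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    if a.len() != 2 || b.len() != 2 {
        return Err("only rank-2 tensors in this sketch".into());
    }
    if a[1] != b[0] {
        // A shape mismatch is a compile-time error, not a runtime crash.
        return Err(format!("inner dims differ: {} vs {}", a[1], b[0]));
    }
    Ok(vec![a[0], b[1]])
}
```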
5. Compilation. Finally, iree-compile lowers the StableHLO module into a VMFB file for a specific backend (CPU, Metal, or CUDA).
The marshalling wall
With the profiling tools in place, the first bottleneck became obvious: marshalling overhead. Every time the interpreter calls into IREE, it must convert each tensor argument into an IREE buffer view. For NanoGPT's forward pass, that means 27 parameter tensors, converted and uploaded on every single call.
For 1000 tokens of autoregressive generation, that's 27,000 buffer allocations and roughly 8GB of cumulative memory copies. The GPU was spending more time waiting for data than computing.
The fix was a fingerprint-based buffer view cache. Hashing tensor contents on every call proved far too expensive, so each tensor argument is instead fingerprinted by Arc pointer identity: the same Arc means the same data, which means the previously created buffer view can be reused. Because Sheaf's values are immutable, this is O(1) with no false positives.
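The pointer-identity cache can be sketched as follows. The `u64` handle stands in for an IREE buffer view, and `Vec<f32>` for the tensor storage; both names are illustrative, not Sheaf's actual types.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Sketch of the fingerprint cache: buffer views are keyed by the Arc's
// pointer. Since values are immutable, a pointer match implies identical
// data, so the cached view can be reused without hashing any contents.
pub struct BufferCache {
    views: HashMap<usize, u64>, // Arc pointer -> cached buffer-view handle
    next_handle: u64,
}

impl BufferCache {
    pub fn new() -> Self {
        BufferCache { views: HashMap::new(), next_handle: 0 }
    }

    /// Returns (handle, was_cache_hit).
    pub fn get_or_create(&mut self, tensor: &Arc<Vec<f32>>) -> (u64, bool) {
        let key = Arc::as_ptr(tensor) as usize;
        if let Some(&h) = self.views.get(&key) {
            return (h, true); // O(1) hit: no upload, no copy
        }
        let h = self.next_handle; // stands in for creating a buffer view
        self.next_handle += 1;
        self.views.insert(key, h);
        (h, false)
    }
}
```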
The cache brought inference from 11.3s down to 7.7s for 1000 tokens (a 32% improvement), but a gap remained.
Deep copies in the interpreter
Profiling the remaining gap revealed a surprising cost: the interpreter was deep-cloning entire tensors on every function call and environment lookup. A GPT forward pass with 27 parameter tensors meant copying ~8MB of weights repeatedly, in what should have been zero-cost reference passing.
The fix was straightforward: wrap tensor data in Arc<ArrayD<f32>>. Function calls, let bindings, and dictionary lookups now bump a reference count instead of copying data. The impact was dramatic:
| Version | 1000 tokens | vs JAX V1 |
|---|---|---|
| Before optimizations | 11.3s | 1.7× slower |
| + buffer cache | 7.7s | 1.1× slower |
| + Arc tensors | 4.85s | 1.4× faster |
| JAX V1 (Python + JAX) | 6.8s | baseline |
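The Arc change can be sketched with a stripped-down `Value` enum. `Vec<f32>` stands in for the real `ArrayD<f32>` storage; the point is that passing a tensor to a function clones the Arc, not the data.

```rust
use std::sync::Arc;

// Sketch of the Arc change: a Value holds its tensor behind an Arc, so
// function calls, let bindings, and lookups bump a reference count
// instead of deep-copying megabytes of weights.
#[derive(Clone)]
pub enum Value {
    Float(f32),
    Tensor(Arc<Vec<f32>>),
}

pub fn pass_to_function(v: &Value) -> Value {
    // What a call-site binding does now: a cheap clone, no tensor copy.
    v.clone()
}
```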
For BareGPT, my small Transformer, Sheaf V2 now beats the Sheaf V1 Python+JAX implementation, in a 3MB binary vs a 500MB runtime.
The next step was to scale up to GPT-2 124M and benchmark against PyTorch directly.
JIT compilation
The AOT approach (sheaf build) required a manual compilation step and couldn't handle higher-order functions or closures that capture runtime state. The JIT compiler replaced it with transparent on-first-call compilation.
When a pure function is called for the first time, the JIT traces its arguments to infer concrete types (no more defparams!) and automatically runs the full compiler pipeline. If the function's source changes, the hash changes, and the JIT recompiles. The iree-compile toolchain is auto-downloaded on first use (Zig-style), so there's nothing to install.
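The compile-on-first-call logic can be sketched as a cache keyed by a hash of the function's source. This is a hypothetical outline: the `String` artifact stands in for a VMFB module, and `compile_count` exists only to make the behavior observable.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

// Sketch of the JIT entry point: compile on first call, keyed by a hash
// of the source, so edited source recompiles and unchanged source
// reuses the cached artifact.
pub struct Jit {
    compiled: HashMap<u64, String>,
    pub compile_count: usize,
}

impl Jit {
    pub fn new() -> Self {
        Jit { compiled: HashMap::new(), compile_count: 0 }
    }

    pub fn call(&mut self, source: &str) -> &String {
        let mut h = DefaultHasher::new();
        source.hash(&mut h);
        let key = h.finish();
        if !self.compiled.contains_key(&key) {
            self.compile_count += 1; // the full compiler pipeline runs here
            self.compiled.insert(key, format!("vmfb:{key}"));
        }
        &self.compiled[&key]
    }
}
```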
f32 everywhere
A less obvious optimization: the interpreter was using f64 internally (Rust and ndarray's default float type), while IREE and the GPU work exclusively in f32. Every tensor crossing the interpreter-IREE boundary was being cast, element by element.
Switching Value::Float from f64 to f32 eliminated this conversion entirely. The impact on inference was a 31% speedup (1080ms to 746ms for 100 tokens), and training dropped from 1.26s to ~670ms per step.
DeviceBuffer and the fast path
The final optimization addressed what happens between consecutive IREE calls. When adam-step produces updated parameters on the GPU, those tensors should flow directly into the next train-step call without ever touching host memory.
DeviceBuffer wraps an IREE buffer view that stays on the GPU. When all inputs to an IREE call are already DeviceBuffers, a fast path skips the buffer cache entirely: no mutex lock, no fingerprint check, just a pointer retain and push. For multi-step training, this means step 1 pays the full marshalling cost (all tensors are new), but steps 2+ run with near-zero buffer overhead.
IREE dispatch profile (3 calls, 1461.5ms total):
flatten: 0.0ms ( 0.0%)
buffers: 416.7ms (28.5%) [hits: 908, misses: 454]
call: 1044.7ms (71.5%)
output: 0.0ms ( 0.0%)
All 454 misses occur on step 1. Steps 2 and 3 are 100% cache hits.
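The fast-path check itself is simple: if every argument already lives on the device, collect the existing handles and skip the cache entirely. A sketch, with hypothetical `Buf` and handle types standing in for Sheaf's real value and IREE buffer-view types:

```rust
// Sketch of the DeviceBuffer fast path: when no argument needs an
// upload, return the device handles directly, bypassing the mutex-
// protected buffer cache and its fingerprint checks.
#[derive(Clone)]
pub enum Buf {
    Host(Vec<f32>), // must be marshalled through the buffer cache
    Device(u64),    // stand-in for a buffer view kept on the GPU
}

/// Some(handles) when every input is already on the device, else None.
pub fn try_fast_path(args: &[Buf]) -> Option<Vec<u64>> {
    args.iter()
        .map(|a| match a {
            Buf::Device(h) => Some(*h),
            Buf::Host(_) => None, // one host tensor forces the slow path
        })
        .collect()
}
```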
Where things stand
Here is the full journey on GPT-2 124M training so far (batch=1, block=64, Metal GPU):
| Version | Per step (warm) | vs PyTorch MPS |
|---|---|---|
| Interpreter only | ~11s | 86× |
| + JIT | 1.26s | 10× |
| + f32 interpreter | ~670ms | 5× |
| + DeviceBuffer fast path | 533ms | 4.2× |
| PyTorch (eager, MPS) | 128ms | baseline |
From 11 seconds to 533 milliseconds. The remaining 4.2× gap is now dominated by the GPU kernel itself (71.5% of dispatch time): IREE's Metal backend versus Apple's MPS, two different compilers targeting the same hardware. That part is outside Sheaf's control.
On inference, the picture is different. For 100 tokens of autoregressive generation, 94% of the time is spent in the GPU call. Buffer overhead is negligible. The gap with PyTorch (746ms vs 390ms) is entirely in kernel quality.
The optimizations described here will continue in a follow-up article, where I'll cover the autodiff pipeline and training step fusion.