Catching Up with PyTorch's Performance
Mar 11th 2026
As Sheaf V2 moves closer to release, it's time to test it on larger models than BareGPT, my toy generative transformer.
Last week, I ported Karpathy's NanoGPT to Sheaf and benchmarked it against PyTorch. The first results were humbling:
| Configuration | Wall time | Tokens/s | vs PyTorch |
|---|---|---|---|
| Interpreter only (no IREE) | ~90s | ~1.1 | 108× slower |
| Sheaf (JIT) | 7.64s | ~13 | 9.2× slower |
| PyTorch (CPU, reference) | 0.83s | ~120 | baseline |
Until now, I had been confident that IREE and LLVM would handle the heavy lifting, and had mostly focused on getting the Sheaf language right. These numbers were a wake-up call. They triggered a shift from design work to performance tuning, which is still ongoing as I write this. Spoiler: Sheaf's performance is now much closer to PyTorch's.
New tooling
To start, I added two features to the CLI that I had wanted for quite some time but had never taken the time to implement:
- `--blame`, inspired by `systemd-analyze blame`, which profiles Sheaf execution and shows where time is spent
- `--jit-profile`, which instruments the IREE dispatch path: buffer marshalling, GPU call, and output transfer
Something became immediately visible: too few functions were being compiled, because side effects and higher-order functions blocked AOT compilation. Worse, interpreter performance on those higher-order functions was so poor that I would clearly have to JIT far more Sheaf code, which is painful for differentiation.
How Sheaf's compiler works
Before diving into the optimizations, a brief overview of Sheaf's compilation pipeline.
1. Parser. Sheaf source is transformed into an AST of S-expressions. Unlike a "pure" Lisp such as Scheme, which is its own AST, Sheaf has syntactic sugar (threading macros, destructuring, keyword args), so we need a parser to normalize everything into a uniform tree:
(defn gpt-forward [x params config] ...)
↓
SheafValue::List([Symbol("defn"), Symbol("gpt-forward"), ...])
2. Compiler. The AST is then lowered into CompiledExpr, an internal IR.
This stage expands macros, recognizes special forms (defn, let, if, fn),
resolves symbols, and registers functions in the compiler registry. The result
is a tree of typed enum variants:
enum CompiledExpr {
    Integer(i64),
    Float(f64),
    Symbol(String),
    FunctionCall(String, Vec<CompiledExpr>),
    Let { bindings, body },
    Lambda { params, body },
    If { cond, then, else_ },
    GetTupleElement { param, indices },
    Vector(Vec<CompiledExpr>),
    Keyword(String),
    ...
}
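To make the IR concrete, here is roughly what `(let [x 1] (+ x 2))` could look like after lowering. This uses a hypothetical four-variant subset of the enum above; the real IR carries more variants and detail.

```rust
// Illustrative subset of a CompiledExpr-style IR, not Sheaf's actual types.
#[derive(Debug, PartialEq)]
enum CompiledExpr {
    Integer(i64),
    Symbol(String),
    FunctionCall(String, Vec<CompiledExpr>),
    Let { bindings: Vec<(String, CompiledExpr)>, body: Box<CompiledExpr> },
}

// A plausible lowering of `(let [x 1] (+ x 2))`:
fn example_lowering() -> CompiledExpr {
    CompiledExpr::Let {
        bindings: vec![("x".to_string(), CompiledExpr::Integer(1))],
        body: Box::new(CompiledExpr::FunctionCall(
            "+".to_string(),
            vec![
                CompiledExpr::Symbol("x".to_string()),
                CompiledExpr::Integer(2),
            ],
        )),
    }
}
```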
3. Transforms. Several passes rewrite the CompiledExpr tree before
codegen: dictionary accesses become positional tuple lookups, known function
calls are inlined, reduce loops are unrolled into chains of Let bindings
(required for autodiff), and constant subexpressions are evaluated at compile
time.
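As a sketch of what one such pass looks like, here is a minimal constant-folding transform over an illustrative two-variant IR. The types and supported operators are invented for the example; Sheaf's actual pass operates on the full `CompiledExpr` tree.

```rust
// Minimal constant folding in the spirit of the transform stage.
// `Expr` is an illustrative stand-in, not Sheaf's real IR.
#[derive(Debug, PartialEq)]
enum Expr {
    Integer(i64),
    Call(String, Vec<Expr>),
}

fn fold_constants(e: Expr) -> Expr {
    match e {
        Expr::Call(op, args) => {
            // Fold children first (bottom-up).
            let args: Vec<Expr> = args.into_iter().map(fold_constants).collect();
            // If every argument is now a literal, evaluate at compile time.
            if let Some(ints) = args
                .iter()
                .map(|a| match a {
                    Expr::Integer(i) => Some(*i),
                    _ => None,
                })
                .collect::<Option<Vec<i64>>>()
            {
                match op.as_str() {
                    "+" => return Expr::Integer(ints.iter().sum()),
                    "*" => return Expr::Integer(ints.iter().product()),
                    _ => {} // unknown op: leave the call in place
                }
            }
            Expr::Call(op, args)
        }
        other => other,
    }
}
```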
4. Codegen. Each CompiledExpr node maps to one or more StableHLO
operations. The codegen also computes output shapes for each tensor operation,
threading shape information through the entire graph.
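The shape threading amounts to a small inference function per operation. A sketch for a 2-D matmul, with an illustrative signature (Sheaf's actual codegen handles more ranks and dtypes):

```rust
// Given the static shapes of two inputs, compute the matmul output shape,
// so downstream StableHLO ops can be emitted with concrete dimensions.
// Illustrative helper, not Sheaf's real API.
fn matmul_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    match (a, b) {
        // [m, k] x [k, n] -> [m, n], provided the inner dims agree.
        ([m, k1], [k2, n]) if k1 == k2 => Ok(vec![*m, *n]),
        _ => Err(format!("incompatible shapes {a:?} x {b:?}")),
    }
}
```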
5. Compilation. Finally, iree-compile lowers the StableHLO module into a
VMFB file for a specific backend (CPU, Metal, Vulkan or CUDA).
The marshalling wall
With the profiling tools in place, the first bottleneck became obvious: marshalling overhead. Every time the interpreter calls into IREE, it must convert each tensor argument into an IREE buffer view. For NanoGPT's forward pass, that means 27 parameter tensors, converted and uploaded on every single call.
For 1000 tokens of autoregressive generation, that's 27,000 buffer allocations and roughly 8GB of cumulative memory copies. The GPU was spending more time waiting for data than computing.
The fix was a fingerprint-based buffer view cache. Hashing the arguments on every function call proved far too expensive, so each tensor argument is now fingerprinted by Arc pointer identity: the same Arc means the same data, which means the previously created buffer view can be reused. This is O(1) with no false positives, since Sheaf's values are immutable.
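The mechanism can be sketched as follows. `Tensor` and `BufferView` are stand-ins here (`Vec<f32>` for `ndarray`'s `ArrayD<f32>`, a plain id for IREE's buffer view handle), and the real cache differs in detail, but the pointer-identity key is the core idea:

```rust
use std::collections::HashMap;
use std::sync::Arc;

type Tensor = Arc<Vec<f32>>;

#[derive(Clone, Debug, PartialEq)]
struct BufferView { id: usize } // stand-in for an IREE buffer view handle

struct BufferCache {
    // Key: the address of the Arc's heap allocation.
    entries: HashMap<usize, (Tensor, BufferView)>,
    uploads: usize,
}

impl BufferCache {
    fn new() -> Self {
        BufferCache { entries: HashMap::new(), uploads: 0 }
    }

    fn get_or_upload(&mut self, t: &Tensor) -> BufferView {
        let key = Arc::as_ptr(t) as usize;
        if let Some((_, view)) = self.entries.get(&key) {
            return view.clone(); // hit: reuse the existing buffer view
        }
        // Miss: marshal the tensor into a (pretend) device buffer.
        self.uploads += 1;
        let view = BufferView { id: self.uploads };
        // Storing a clone of the Arc pins the allocation, so its address
        // cannot be recycled while the cache entry is alive.
        self.entries.insert(key, (Arc::clone(t), view.clone()));
        view
    }
}
```

Immutability is what makes this sound: an address can only ever map to one tensor value, so a hit can never return a stale buffer.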
The cache brought inference from 11.3s down to 7.7s for 1000 tokens (a 32% improvement), but a gap remained.
Deep copies in the interpreter
Profiling the remaining gap revealed a surprising cost: the interpreter was deep-cloning entire tensors on every function call and environment lookup. A GPT forward pass with 27 parameter tensors meant copying ~8MB of weights repeatedly, in what should have been zero-cost reference passing.
The new implementation wraps tensor data in Arc<ArrayD<f32>>. Function calls,
let bindings, and dictionary lookups now bump a reference count instead of
copying data. The impact is dramatic:
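The key property is that cloning an `Arc` only bumps a reference count; the payload is never copied. A minimal illustration, with `Vec<f32>` again standing in for `ArrayD<f32>`:

```rust
use std::sync::Arc;

/// What passing a tensor into a Sheaf function call now costs:
/// one pointer copy and one atomic increment, regardless of tensor size.
fn pass_by_reference(t: &Arc<Vec<f32>>) -> Arc<Vec<f32>> {
    Arc::clone(t) // O(1); does not touch the underlying array
}
```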
| Version | 1000 tokens | vs V1 (JAX) |
|---|---|---|
| Before optimizations | 11.3s | 1.7× slower |
| + buffer cache | 7.7s | 1.1× slower |
| + Arc tensors | 4.85s | 1.4× faster |
| JAX V1 (Python + JAX) | 6.8s | baseline |
For BareGPT, my small Transformer, Sheaf V2 now beats the Sheaf V1 Python+JAX implementation, in a 3MB binary vs a 500MB runtime.
The next step is to scale up to GPT-2 124M and benchmark against PyTorch directly.
JIT compilation
The AOT approach (sheaf build) required a manual compilation step and couldn't
handle higher-order functions or closures that capture runtime state. The JIT
compiler replaced it with transparent on-first-call compilation.
When a pure function is called for the first time, the JIT traces its arguments
to infer concrete types (no more defparams!) and automatically runs the full
compiler pipeline. If the function's source changes, the hash changes, and the
JIT recompiles. The iree-compile toolchain is auto-downloaded on first use
(Zig-style), so there's nothing to install.
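The cache-and-recompile behavior can be sketched like this, assuming deterministic compilation. `Jit` and its fields are illustrative; `compile` here just fabricates an artifact name where the real pipeline would run lowering, transforms, codegen, and `iree-compile`:

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

// Sketch of on-first-call JIT caching keyed by a hash of the function source.
struct Jit {
    cache: HashMap<u64, String>, // source hash -> compiled artifact
    compiles: usize,
}

impl Jit {
    fn new() -> Self {
        Jit { cache: HashMap::new(), compiles: 0 }
    }

    fn source_hash(src: &str) -> u64 {
        let mut h = DefaultHasher::new();
        src.hash(&mut h);
        h.finish()
    }

    fn get_or_compile(&mut self, src: &str) -> &String {
        let key = Self::source_hash(src);
        if !self.cache.contains_key(&key) {
            // First call, or the source changed: run the full pipeline.
            self.compiles += 1;
            self.cache.insert(key, format!("vmfb:{key:x}"));
        }
        &self.cache[&key]
    }
}
```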
f32 everywhere
A less obvious optimization: the interpreter was using f64 internally (Rust and ndarray's default float type), while IREE and the GPU work exclusively in f32. Every tensor crossing the interpreter-IREE boundary was being cast, element by element.
Switching Value::Float from f64 to f32 eliminated this conversion entirely.
The impact on inference was a 31% speedup (1080ms to 746ms for 100 tokens), and
training dropped from 1.26s to ~670ms per step.
DeviceBuffer and the fast path
The final optimization addressed what happens between consecutive IREE calls.
When adam-step produces updated parameters on the GPU, those tensors should
flow directly into the next train-step call without ever touching host memory.
DeviceBuffer wraps an IREE buffer view that stays on the GPU. When all inputs
to an IREE call are already DeviceBuffers, a fast path skips the buffer cache
entirely: no mutex lock, no fingerprint check, just a pointer retain and push.
For multi-step training, this means step 1 pays the full marshalling cost (all
tensors are new), but steps 2+ run with near-zero buffer overhead.
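The fast-path decision itself is cheap. A sketch with illustrative value variants (Sheaf's actual value type is richer than this):

```rust
// Stand-ins for Sheaf's runtime values: either host-resident tensor data
// that must be marshalled, or a handle to a buffer already on the GPU.
enum Value {
    HostTensor(Vec<f32>),
    DeviceBuffer(usize),
}

/// Take the fast path only when every argument is already on the device:
/// then the call can skip the mutex, the fingerprint lookup, and the
/// marshalling, and just retain-and-push each existing buffer view.
fn take_fast_path(args: &[Value]) -> bool {
    args.iter().all(|v| matches!(v, Value::DeviceBuffer(_)))
}
```

A single host-resident argument is enough to fall back to the cached path, which is why step 1 of training pays the full cost and later steps do not.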
IREE dispatch profile (3 calls, 1461.5ms total):
flatten: 0.0ms ( 0.0%)
buffers: 416.7ms (28.5%) [hits: 908, misses: 454]
call: 1044.7ms (71.5%)
output: 0.0ms ( 0.0%)
All 454 misses occur on step 1. Steps 2 and 3 are 100% cache hits.
Where things stand
Here is the full journey on GPT-2 124M training so far (batch=1, block=64, Metal GPU):
| Version | Per step (warm) | vs PyTorch MPS |
|---|---|---|
| Interpreter only | ~11s | 86× |
| + JIT | 1.26s | 10× |
| + f32 interpreter | ~670ms | 5× |
| + DeviceBuffer fast path | 533ms | 4.2× |
| PyTorch (eager, MPS) | 128ms | baseline |
From 11 seconds to 533 milliseconds. The remaining 4.2× gap is now dominated by the GPU kernel itself (71.5% of dispatch time). That comes down to IREE's Metal backend versus Apple's MPS, a different compiler targeting the same hardware, and is outside Sheaf's control.
For inference, the picture is different. For 100 tokens of autoregressive generation, 94% of the time is spent in the GPU call. Buffer overhead is negligible. The gap with PyTorch (746ms vs 390ms) is entirely in kernel quality.
The optimization work described here continues in a follow-up article, where I'll cover the autodiff pipeline and training-step fusion.