Starting with Sheaf
The Functional Mindset
Sheaf is a functional language rather than an imperative one. An imperative language like Python instructs the computer to change its internal state step by step. In Sheaf, a program is instead defined by how data flows through nested functions to produce a result.
Functional programming avoids side effects, which makes code deterministic: a function called with the same inputs always returns the same result. This property maps naturally to computational graphs and machine learning workloads.
A first example
A functional language can appear confusing at first, because even a classic val = 2 ; print val example does not translate directly. Since "everything is a function", we would write a function (defn val [] 2) and then call it: (val). Alternatively, both steps can be done in a single expression: ((defn val [] 2))
In Sheaf, as in all Lisp dialects, the fundamental unit of computation is the expression. There is no structural difference between a mathematical operator, a control flow structure, or a complex neural layer: they all follow the same notation.
For this reason, a global "variable" is effectively a nullary function (a function with no arguments) that returns a constant.
To define a function, we use the defn form, followed by the function name and its parameters enclosed in square brackets [] rather than parentheses (). This convention is borrowed from Clojure and visually distinguishes the parameter list from function calls.
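For readers coming from Python, the following sketch shows the idea of a "variable" as a nullary function. It is only a conceptual analogue written in Python and does not describe how Sheaf compiles defn:

```python
# Python analogue of (defn val [] 2): a "global variable" written as a
# function with no arguments that always returns the same constant.
def val():
    return 2


# Calling the function yields the constant, like (+ 2 (val)) in Sheaf.
result = val() + 2

# Using the function object itself as a number fails, like (+ 2 val).
try:
    val + 2
    error_message = ""
except TypeError as err:
    error_message = str(err)

print(result)         # 4
print(error_message)  # unsupported operand type(s) for +: 'function' and 'int'
```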
Let's see what this example would look like in Sheaf's REPL:
$ sheaf
Welcome to Sheaf Console v1.0.0
Type :help or :h for help, :quit or :q to exit
sheaf> ;; "Define a function `val` which takes no args and returns 2":
sheaf> (defn val [] 2)
⇒ <function DefnForm.compile.<locals>.generated_func at 0x10acf8ae0>
sheaf> (val) ; Call that function
⇒ 2
Now let's try adding 2 to it. Sheaf uses prefix notation, which we'll cover later. For now, we'll just write (+ 2 val) instead of (2 + val).
sheaf> (+ 2 val) ; This actually won't work: `val` is a function, not a scalar
error: add requires ndarray or scalar arguments, got <class 'function'> at position 1.
--> <repl>:1
|
1 | (+ 2 val)
| ^
|
= note: Check the number of arguments you're passing to the function.
sheaf> (+ 2 (val)) ; We use parens to call `val` as a function
⇒ Tensor i32[] = 4 ; Now it works!
sheaf>
sheaf> (defn val [] 4) ; Let's try to redefine our function to 4
Error:
Function 'val' is already defined in user code. Redefinition is not allowed to prevent shadowing bugs.
The inability to change the value of val may come as a surprise.
In Sheaf, every value is immutable by default: once a function is defined, it is fixed and cannot be redefined or modified. This is required both for predictability and for performance in the JIT engine.
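Here is a small Python sketch of a namespace that refuses redefinition, to make the rule concrete. The Namespace class is hypothetical. It only mimics the behavior seen in the REPL and says nothing about how Sheaf enforces immutability internally:

```python
# Hypothetical registry that, like the Sheaf REPL, refuses to redefine
# a name once it has been bound.
class Namespace:
    def __init__(self):
        self._defs = {}

    def define(self, name, fn):
        if name in self._defs:
            raise NameError(
                f"Function '{name}' is already defined. "
                "Redefinition is not allowed."
            )
        self._defs[name] = fn

    def call(self, name, *args):
        return self._defs[name](*args)


ns = Namespace()
ns.define("val", lambda: 2)
first = ns.call("val")

try:
    ns.define("val", lambda: 4)  # rejected; the original binding survives
    redefined = True
except NameError:
    redefined = False

print(first, redefined, ns.call("val"))  # 2 False 2
```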
Expressions and Lambdas
Remember that in Lisp, everything is an expression.
When an expression like (+ 1 2) is evaluated, the interpreter does not simply perform an addition. It evaluates a form in which the symbol + refers to an addition function and 1 and 2 are its arguments, much like np.add(1, 2) in NumPy.
Conceptually, a Sheaf program is not a sequence of instructions, but a single, nested tree of expressions.
While (+ 1 2) is evaluated immediately, the same computation can be wrapped in a lambda with (fn [] (+ 1 2)).
When defined with fn, the body is not executed right away. Instead, fn returns an anonymous function that can be called later, for example from within the enclosing function.
This is used when a small transformation is needed just once, without giving the function a permanent name.
Let's see this in action:
sheaf> (* 5 10)
⇒ Tensor i32[] = 50
sheaf> ;; Now let's use a variable
sheaf> (* x 10) ;; Won't work: `x` is not defined in the current scope
error: Symbol not found (line 1): 'x'
--> <repl>:1
|
1 | (* x 10)
| ^
|
= note: Check for typos in function or variable names.
sheaf> ;; Instead, define it as a lambda and pass it the value `5`
sheaf> ((fn [x] (* x 10)) 5)
⇒ Tensor i32[] = 50
sheaf> ;; Notice the double (( )) to call the lambda immediately
sheaf> ;; Without them, this defines the function but does not apply it
sheaf> (fn [x] (* x 10) 5)
⇒ <function LambdaForm.compile.<locals>.anonymous_func at 0x10e4949a0>
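Python has the same pattern: wrapping a lambda in parentheses lets you apply it immediately. Here is a minimal comparison, written in Python rather than Sheaf:

```python
# Python counterpart of ((fn [x] (* x 10)) 5): build a lambda and apply it at once.
fifty = (lambda x: x * 10)(5)
print(fifty)  # 50

# Without the call, only the function object is produced, as with
# (fn [x] (* x 10)) in the REPL.
times_ten = lambda x: x * 10
print(callable(times_ten))  # True
```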
Local bindings
Defining a function to return a scalar value isn't very helpful. Usually, Sheaf functions apply transformations to tensors and pass them to other functions.
Sheaf doesn't assign variables; it binds local symbols with the let form. These bindings exist only within the body of the let and disappear once it has been evaluated.
Let's try binding values:
sheaf> ;; Binding a tensor [1 2 3] to x, then returning x
sheaf> (let [x [1 2 3]]
... x)
⇒ Tensor f32[3] = [1. 2. 3.]
sheaf> ;; Automatic broadcasting: adding 2 to the tensor
sheaf> (let [x [1 2 3]]
... (+ x 2))
⇒ Tensor f32[3] = [3. 4. 5.]
Because of immutability, x is not modified in the second example either. Instead, a new tensor is created and returned. Once the let expression has been evaluated, x is no longer accessible.
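Python has no let form, but a function body gives a comparable local scope, and tuples are immutable. The sketch below approximates the second example. Plain tuples stand in for tensors, so there is no automatic broadcasting; the addition is written out element by element:

```python
def let_example():
    # x is bound only inside this function's scope, like a let binding.
    x = (1.0, 2.0, 3.0)
    # "Adding 2" builds a new tuple; x itself is never modified.
    shifted = tuple(v + 2 for v in x)
    return x, shifted


original, result = let_example()
print(original)  # (1.0, 2.0, 3.0)
print(result)    # (3.0, 4.0, 5.0)
```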
Prefix Notation (Polish Notation)
Another confusing aspect of our example may be the position of the operator: (+ x 2) instead of (x + 2).
Sheaf, like Clojure, Scheme and other Lisps, uses prefix notation, where the operator always precedes its operands.
This approach serves two critical purposes:
- Variadic operations: Functions can take any number of arguments without repeating the operator. Instead of 1 + 2 + 3 + 4, one simply writes (+ 1 2 3 4).
- Consistency: Since the first element of every parenthesized form is the function being called, mathematical operators, built-in functions, and user-defined functions all follow the exact same structure: (function argument1 argument2 ...). This uniformity is what allows the language to treat code as data.
A simple rule of thumb: the first symbol after an opening parenthesis always names what is being applied, whether a function or a special form such as defn or let.
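To show why a uniform (function argument ...) shape makes code easy to treat as data, here is a toy evaluator for nested prefix expressions encoded as Python tuples. It is a hypothetical sketch of the idea that supports only + and *. It is not Sheaf's actual interpreter:

```python
import operator
from functools import reduce

# Hypothetical mini-evaluator: an expression is either a number or a tuple
# whose first element names the function, mirroring (function arg1 arg2 ...).
FUNCTIONS = {
    "+": operator.add,
    "*": operator.mul,
}


def evaluate(expr):
    if not isinstance(expr, tuple):
        return expr  # a literal evaluates to itself
    name, *args = expr
    values = [evaluate(arg) for arg in args]  # evaluate operands first
    # Variadic: fold the binary operator over any number of operands.
    return reduce(FUNCTIONS[name], values)


print(evaluate(("+", 1, 2, 3, 4)))         # 10
print(evaluate(("+", 1, ("*", 2, 3), 4)))  # 11
```

The operator is always the first element, so the evaluator needs no precedence rules. Variadic calls come for free by folding the operator over the operands.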
Conditionals: if and where
Most programming languages have an if statement for branching logic. Sheaf has two distinct forms for conditional evaluation, and understanding when to use each is important.
The if form works like a traditional conditional: it evaluates a condition, then returns either the "then" branch or the "else" branch:
sheaf> (if (> 5 3) :yes :no)
⇒ :yes
sheaf> (if false "not returned" "returned")
⇒ 'returned'
This works well for decisions based on configuration values or scalars. However, inside JIT-compiled functions, the condition of an if cannot depend on a tensor value. JAX needs to know the control flow at compile time, while tensor values are only known at runtime.
sheaf> (defn broken [x]
... (if (> x 0) x 0)) ;; This will fail in JIT context
For element-wise conditionals on tensors, where is used instead. It evaluates both branches and selects values based on a boolean mask:
sheaf> (where (> 5 3) 100 0)
⇒ Tensor i32[] = 100
sheaf> ;; Element-wise on tensors:
sheaf> (where (> [1 5 3 8] 4) [1 5 3 8] 0)
⇒ Tensor i32[4] = [0 5 0 8]
The last example replaces all values less than or equal to 4 with zero, keeping the others. This is the functional equivalent of NumPy's np.where.
Let's see another example. A common pattern is implementing ReLU (Rectified Linear Unit), which returns x if positive, else 0:
sheaf> (defn my-relu [x]
... (where (> x 0) x 0))
sheaf> (my-relu [-2 -1 0 1 2])
⇒ Tensor i32[5] = [0 0 0 1 2]
In summary: use if for static branching (configs, flags), and where for tensor operations.
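The semantics of where can be sketched in plain Python over lists, which stand in for tensors here. The where and my_relu helpers below are hypothetical. The helper only broadcasts scalars and mirrors the behavior of np.where conceptually. Both branch values exist before the selection happens, unlike with if:

```python
# Minimal sketch of an element-wise where: scalars are broadcast to the
# length of the mask, then each position picks from one branch or the other.
def where(mask, on_true, on_false):
    n = len(mask)
    a = on_true if isinstance(on_true, list) else [on_true] * n
    b = on_false if isinstance(on_false, list) else [on_false] * n
    return [x if m else y for m, x, y in zip(mask, a, b)]


def my_relu(xs):
    return where([x > 0 for x in xs], xs, 0)


xs = [1, 5, 3, 8]
print(where([x > 4 for x in xs], xs, 0))  # [0, 5, 0, 8]
print(my_relu([-2, -1, 0, 1, 2]))         # [0, 0, 0, 1, 2]
```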
Recursion
When reading Sheaf code, the lack of imperative loops (for and while loops) may come as another surprise.
In Sheaf, iteration is handled through recursion rather than imperative loops. Since data is immutable, a process is not described by updating a counter, but by repeatedly applying a transformation to a value.
Let's take a look at a typical loop in Python:
# Imperative loop
>>> total = 0
>>> for i in range(10):
...     total += i
...
>>> print(total)
45
# Vectorized alternative with NumPy
>>> import numpy as np
>>> np.sum(range(10))
np.int64(45)
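For comparison, the same sum can be written in Python in a functional style, using recursion instead of a mutable accumulator. This is an illustrative sketch only:

```python
# Recursive, mutation-free version of the same sum:
# no counter is updated; each call returns a new value.
def total(xs):
    if not xs:
        return 0
    return xs[0] + total(xs[1:])


print(total(list(range(10))))  # 45
```

Python does not optimize tail calls, so this version reaches the recursion limit on long inputs. In practice, higher-order functions such as reduce are usually preferred over hand-written recursion.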
As a functional language, Sheaf generally favors functions over imperative loops. Such iterations are typically handled by higher-order abstractions like reduce:
sheaf> (reduce + 0 (range 10))
⇒ Tensor i32[] = 45
reduce takes a function (here, +), a starting value (0), and a sequence to which the function is applied. For instance, (reduce + 0 [1 2 3 4]) computes 0 + 1 + 2 + 3 + 4 = 10.
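Python's standard library provides the same building block as functools.reduce. Its argument order differs: the function comes first, then the sequence, then the optional initial value:

```python
import operator
from functools import reduce

# Python equivalent of (reduce + 0 [1 2 3 4]): the function is applied
# left to right, starting from the initial value 0.
ten = reduce(operator.add, [1, 2, 3, 4], 0)
forty_five = reduce(operator.add, range(10), 0)  # like (reduce + 0 (range 10))

print(ten, forty_five)  # 10 45
```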
Other common higher-order functions are map, vmap, and scan.
Unlike reduce which returns a single combined value, map applies a transformation to every element of an array and returns another array.
For instance, let's square all the integers from 1 to 10. We will generate the sequence with (range 1 11), then use map with a lambda function to square each element.
sheaf> (range 1 11)
⇒ Tensor i32[10] = [ 1 2 3 4 5 6 7 8 9 10]
sheaf> (map (fn [i] (* i i)) (range 1 11))
⇒ '[Tensor i32[] = 1, Tensor i32[] = 4, Tensor i32[] = 9, Tensor i32[] = 16, Tensor i32[] = 25, Tensor i32[] = 36, Tensor i32[] = 49, Tensor i32[] = 64, Tensor i32[] = 81, Tensor i32[] = 100]
Note: Sheaf is "tensor-first" by design and only returns arrays, even for scalar values. This ensures any operation is compilable into an XLA kernel running on a GPU or TPU.
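Python's built-in map follows the same pattern: you pass a function and a sequence. Here the results are plain integers rather than tensors:

```python
# Python counterpart of (map (fn [i] (* i i)) (range 1 11)).
squares = list(map(lambda i: i * i, range(1, 11)))
print(squares)  # [1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
```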
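vmap is similar to map, with one difference. Instead of applying the function element by element and returning a list, it vectorizes the function over the array (here along axis 0) and returns a single tensor. Note that (vmap f 0) itself returns a function, hence the double parentheses: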
sheaf> ((vmap (fn [i] (* i i)) 0) (range 1 11))
⇒ Tensor i32[10] = [ 1 4 9 16 25 36 49 64 81 100]
Finally, scan is a specialized iteration primitive that carries a state across a sequence (like reduce) but also returns all intermediate results.
In differentiable programming, scan is the preferred way to implement recurrent structures or deep loops, as it allows the XLA compiler to optimize the backpropagation through time more efficiently than raw recursion.
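What scan computes can be spelled out with a short Python function. It is modeled on the calling convention of JAX's jax.lax.scan, where the step function takes (carry, x) and returns (new_carry, y). This sketches the semantics under that assumption; Sheaf's own scan may take its arguments differently. The explicit loop is only there to show what the primitive computes:

```python
# Pure-Python sketch of scan semantics, modeled on jax.lax.scan:
# f takes (carry, x) and returns (new_carry, y); every y is collected.
def scan(f, init, xs):
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def running_sum(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # emit the intermediate total at every step


final, partials = scan(running_sum, 0, [1, 2, 3, 4])
print(final, partials)  # 10 [1, 3, 6, 10]
```

Like reduce, the final carry is the combined result. Unlike reduce, the intermediate values are kept as well.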
Working with Parameters
In neural networks, parameters (weights and biases) are typically stored in nested dictionaries. Sheaf provides several tools to work with these structures elegantly.
The get function retrieves a value from a dictionary by key:
sheaf> (get {:learning-rate 0.001 :epochs 10} :learning-rate)
⇒ 0.001
For deeply nested structures, get-in navigates through multiple levels using a path vector:
sheaf> (get-in {:model {:layer1 {:W [[1 2] [3 4]] :b [0.1 0.2]}}} [:model :layer1 :b])
⇒ Tensor f32[2] = [0.1 0.2]
A default value can be provided if the path does not exist:
sheaf> (get-in {:a 1} [:missing :path] 42)
⇒ 42
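The lookup logic behind get-in can be sketched in Python. The get_in helper below is hypothetical. It walks the key path through nested dictionaries and falls back to a default as soon as a key is missing:

```python
# Hypothetical Python helper mirroring get-in: walk a path of keys through
# nested dictionaries, returning a default when any key is missing.
def get_in(data, path, default=None):
    current = data
    for key in path:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current


params = {"model": {"layer1": {"W": [[1, 2], [3, 4]], "b": [0.1, 0.2]}}}
bias = get_in(params, ["model", "layer1", "b"])
fallback = get_in({"a": 1}, ["missing", "path"], 42)

print(bias)      # [0.1, 0.2]
print(fallback)  # 42
```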
When writing neural network layers, repeatedly calling get becomes tedious. The with-params form unpacks dictionary keys directly into the local scope, so that :W becomes available as W and :b as b. This is called parameter destructuring:
sheaf> (let [params {:W [[1 2] [3 4]] :b [0.5 0.5]}
... x [1.0 1.0]]
... (with-params [params]
... (+ (@ x W) b)))
⇒ Tensor f32[2] = [4.5 6.5]
Without with-params, the equivalent code would require explicit get calls:
;; More verbose alternative
(let [params {:W [[1 2] [3 4]] :b [0.5 0.5]}
x [1.0 1.0]]
(+ (@ x (get params :W)) (get params :b)))
For nested parameter structures, a key can be specified to destructure a sub-dictionary:
sheaf> (let [model {:layer1 {:W [[1 0] [0 1]] :b [0 0]}
... :layer2 {:W [[2 2] [2 2]] :b [1 1]}}
... x [1.0 2.0]]
... (with-params [model :layer1]
... (+ (@ x W) b)))
⇒ Tensor f32[2] = [1. 2.]
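A rough Python analogue of the with-params examples unpacks the needed keys into local names before computing x @ W + b. The dense function is hypothetical. The vector-matrix product is written out with plain lists so the sketch has no dependencies:

```python
# Rough Python analogue of (with-params [params] (+ (@ x W) b)).
def dense(params, x):
    W, b = params["W"], params["b"]  # "destructure" the parameter dict
    # Vector-matrix product x @ W, written out for plain lists.
    xW = [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]
    return [v + bias for v, bias in zip(xW, b)]


params = {"W": [[1, 2], [3, 4]], "b": [0.5, 0.5]}
print(dense(params, [1.0, 1.0]))  # [4.5, 6.5]

# Nested case: select the sub-dictionary first, like (with-params [model :layer1] ...).
model = {"layer1": {"W": [[1, 0], [0, 1]], "b": [0, 0]},
         "layer2": {"W": [[2, 2], [2, 2]], "b": [1, 1]}}
print(dense(model["layer1"], [1.0, 2.0]))  # [1.0, 2.0]
```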
Since all values are immutable, dictionaries are updated by creating new ones. The assoc form adds or updates keys:
sheaf> (assoc {:a 1} :b 2 :c 3)
⇒ {:a 1, :b 2, :c 3}
The merge function combines multiple dictionaries (later values override earlier ones):
sheaf> (merge {:lr 0.001 :epochs 10} {:lr 0.0001})
⇒ {:lr 0.0001, :epochs 10}
And dissoc removes keys:
sheaf> (dissoc {:a 1 :b 2 :c 3} [:a :c])
⇒ {:b 2}
These operations are essential for managing optimizer state, where parameters, momentum, and velocity are all stored in nested dictionaries that must be updated at each training step.
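The same non-destructive style can be sketched in Python. Dictionary unpacking and comprehensions build new dictionaries instead of modifying existing ones. The comments note which Sheaf form each line corresponds to:

```python
# Non-destructive dictionary updates in Python, mirroring assoc, merge and dissoc:
# each operation builds a new dict and leaves its inputs untouched.
base = {"a": 1}
assoc_result = {**base, "b": 2, "c": 3}  # like (assoc {:a 1} :b 2 :c 3)

config = {"lr": 0.001, "epochs": 10}
merged = {**config, **{"lr": 0.0001}}    # like merge: later values win

full = {"a": 1, "b": 2, "c": 3}
dissoc_result = {k: v for k, v in full.items() if k not in {"a", "c"}}  # like dissoc

print(assoc_result)   # {'a': 1, 'b': 2, 'c': 3}
print(merged)         # {'lr': 0.0001, 'epochs': 10}
print(dissoc_result)  # {'b': 2}
print(base)           # {'a': 1}
```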
The threading operators
In older Lisps such as Scheme, Common Lisp, or Emacs Lisp, deeply nested function calls can become a "parentheses nightmare." One of Clojure's great innovations is its pair of threading operators, -> and as->.
These operators "thread" a value through a sequence of functions, making the code read like a pipeline from top to bottom.
-> (Thread-first) automatically injects the result of the previous expression as the first argument of the next function call. It is ideal for linear transformations.
as-> (Thread-as) binds the result to a specific symbol (a variable, in imperative terms), allowing it to be placed at any position in the following expressions. It is typically used when the value is not the first argument, or when the input must be reused.
Let's compare them to the traditional nesting:
sheaf> ;; Traditional nesting - dense, but hard to read
sheaf> ((fn [x] (* x x)) (* (+ 10 5) 2))
⇒ Tensor i32[] = 900
sheaf> ;; Threaded nesting
sheaf> (as-> 10 x ;; Send 10 to x
... (+ x 5) ;; Add 5 to it
... (* x 2) ;; Multiply it by two
... ((fn [n] (* n n)) x)) ;; Square it with a lambda
⇒ Tensor i32[] = 900
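The two threading styles can be imitated in Python with small helper functions. Both thread_first and thread_as are hypothetical names used for illustration. The first inserts the running value as the first argument of each step. The second relies on lambdas to place the value anywhere:

```python
import operator


def thread_first(value, *steps):
    """Like ->: each step is f or (f, *args), called as f(value, *args)."""
    for step in steps:
        if callable(step):
            value = step(value)
        else:
            f, *args = step
            value = f(value, *args)
    return value


def thread_as(value, *steps):
    """Like as->: each step is a one-argument function of the named value."""
    for step in steps:
        value = step(value)
    return value


piped = thread_first(10, (operator.add, 5), (operator.mul, 2), lambda n: n * n)
named = thread_as(10, lambda x: x + 5, lambda x: x * 2, lambda x: (lambda n: n * n)(x))
print(piped, named)  # 900 900
```

In Python, the lambda is what lets the value appear in any argument position. That is the role the bound symbol plays in as->.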
In neural networks such as BareGPT in the examples folder, both threading operators are used to describe the flow of a tensor through layers. It makes the forward function look like an actual architecture diagram rather than a "soup" of parentheses:
(defn transformer-block [x layer-p config]
  ;; as-> binds the initial 'x' to the name 'h'
  ;; This allows us to reuse 'h' multiple times within the block,
  ;; specifically for the residual connections
  (as-> x h
    ;; We now switch to '->' because each function takes the output of the
    ;; previous one as its first argument, somewhat like a shell pipe
    (-> h
        (local-layer-norm (get layer-p :ln1))
        (multi-head-attention layer-p config)
        (first)   ;; Get the attention output, ignore weights
        (+ h))    ;; Residual 1: we explicitly reference 'h' named above
    ;; as-> automatically re-binds 'h' to the result of the previous block
    (-> h
        (local-layer-norm (get layer-p :ln2))
        (mlp (get layer-p :mlp))
        (+ h))))  ;; Residual 2
Some quick exercises
Let's try some quick exercises to become more familiar with Sheaf.
Exercise 1:
- What is the result of (1 + 2) * 3?
- Create an array of integers ranging from 1 to 5
- Get its mean value
- Get its shape
Hint: The Sheaf REPL has autocompletion. Press [Tab] twice to see all the commands.
Exercise 2:
- Write a square function using the defn form, and call it with various numbers
- Write a lambda that returns its argument cubed (^3), using fn
- How could we make this work on an array instead of a scalar?
Exercise 3:
- We need to perform the following operations on input x: add 5, multiply by 2, square
- Implement it with a function that uses the threading operator as-> to reduce the number of nested parentheses
- Write a function that sums all elements of an array, without using sum.
Hint: To see the documentation and signature of a form, use :help <form>.
Glossary
Sheaf, like Clojure and other Lisp dialects, uses its own terminology. This list summarizes the "technical jargon" for developers used to imperative languages like Python.
- Application: A function call (applying arguments to a function).
- Binding: The association of a value to a symbol within a specific scope (assigning a variable in imperative talk).
- Expression: Any piece of code that, when evaluated, produces a value.
- Form: Any syntactically valid piece of code (a symbol, a literal, or a list).
- Lambda (fn): An anonymous function.
- S-Expression: The core of Lisp: a symbolic expression inside parentheses, and also the tree structure that represents both code and data.
- Symbol: A name (like x or my-func) that refers to a value or a function.