Python Interoperability

Sheaf code compiles to pure JAX functions, so it works directly alongside Python: Sheaf functions can be called from Python, the compilation registry can be inspected, and Sheaf and Python code can be mixed in the same project.


Basic Usage: Inline Code

The simplest way to use Sheaf from Python is to load code directly as a string:

from sheaf import Sheaf

shf = Sheaf()

# Load Sheaf code as a string
shf.load("""
  (defn add-five [x]
    (+ x 5))

  (defn square [x]
    (* x x))
""")

# Call Sheaf functions as Python methods
result = shf.add_five(10)      # => 15
result = shf.square(3)         # => 9

# Works with arrays too
import jax.numpy as jnp
arr = jnp.array([1.0, 2.0, 3.0])
result = shf.add_five(arr)     # => Array([6., 7., 8.], dtype=float32)

Python Naming Convention

Sheaf function names use kebab-case (hyphens). Hyphens are not valid in Python identifiers, so Sheaf exposes each function as a Python method with the hyphens converted to underscores:

(defn train-step [params grad]
  ...)

(defn compute-loss [x y]
  ...)

# Access with underscores instead of hyphens
shf.train_step(params, grad)
shf.compute_loss(x, y)
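This page does not describe how the conversion is implemented. The following is a minimal sketch that assumes a plain underscore-to-hyphen substitution performed on attribute lookup. KebabNamespace is a hypothetical class for illustration and is not part of Sheaf. Plain Python lambdas stand in for compiled functions.

```python
# Hypothetical sketch: NOT Sheaf's actual implementation.
# Maps attribute access like `ns.train_step` onto a dict keyed by
# kebab-case names such as 'train-step'.


class KebabNamespace:
    """Expose a dict of kebab-case callables as snake_case attributes."""

    def __init__(self, registry):
        # registry: dict mapping kebab-case names to callables
        self._registry = registry

    def __getattr__(self, name):
        # Python only calls __getattr__ when normal attribute lookup fails.
        kebab = name.replace("_", "-")
        try:
            return self._registry[kebab]
        except KeyError:
            raise AttributeError(
                f"no function named {kebab!r} (accessed as {name!r})"
            ) from None


# Plain Python stand-ins for compiled Sheaf functions
ns = KebabNamespace({
    "add-five": lambda x: x + 5,
    "compute-loss": lambda x, y: (x - y) ** 2,
})
print(ns.add_five(10))        # 15
print(ns.compute_loss(3, 1))  # 4
```

A straight substitution has one consequence, at least in this sketch. A Sheaf name that itself contains an underscore is not reachable through attribute access. Looking the function up by its original name in shf.registry avoids the conversion entirely.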

File-Based Code

For larger projects, Sheaf code can be loaded from files:

Loading a Single File

from sheaf import Sheaf

# Load from a relative path (relative to calling script)
shf = Sheaf("model.shf")

# Or load after creation
shf = Sheaf()
shf.load_from_path("path/to/model.shf")

Relative Path Resolution

Paths are resolved relative to the calling script's directory, not the current working directory:

# In /project/train/runner.py
shf = Sheaf("../models/nn.shf")  # Resolves to /project/models/nn.shf

# Works correctly regardless of where the script is run from:
# $ cd /project && python train/runner.py
# $ cd /project/train && python runner.py
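One way to get this behavior is to resolve the path against the directory of the caller's source file rather than os.getcwd(). The following is a minimal sketch under that assumption. resolve_relative and resolve_from_caller are hypothetical helpers, not Sheaf functions.

```python
# Hypothetical sketch: NOT Sheaf's actual implementation.
# Resolves a relative path against the directory of the calling file
# instead of the current working directory.
import inspect
import os


def resolve_relative(path, caller_file):
    """Resolve `path` relative to the directory containing `caller_file`."""
    if os.path.isabs(path):
        return os.path.normpath(path)
    base = os.path.dirname(os.path.abspath(caller_file))
    return os.path.normpath(os.path.join(base, path))


def resolve_from_caller(path):
    """Resolve `path` relative to the source file of this function's caller."""
    caller_file = inspect.stack()[1].filename  # frame that called us
    return resolve_relative(path, caller_file)


print(resolve_relative("../models/nn.shf", "/project/train/runner.py"))
# /project/models/nn.shf  (on POSIX systems)
```

Here, inspect.stack()[1] is the frame that called resolve_from_caller. A real library would have to skip past its own internal frames to reach the user's script. The printed result assumes POSIX path separators.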

Accessing the Registry

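The registry contains all compiled functions and can be inspected from Python in two forms. shf.registry maps each Sheaf name (in kebab-case) to its compiled function. shf.get_registry() returns metadata about each function, such as its parameter names and source: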

Getting Registry Metadata

from sheaf import Sheaf

shf = Sheaf()
shf.load("""
  (defn forward [x params]
    (+ (@ x (:W params)) (:b params)))

  (defn loss [pred target]
    (mean (* (- pred target) (- pred target))))
""")

# Get metadata about all functions
registry = shf.get_registry()

print(registry)
# Output:
# {
#   'forward': {'params': ['x', 'params'], 'source': '(defn forward ...'},
#   'loss': {'params': ['pred', 'target'], 'source': '(defn loss ...'}
# }

# Check what functions are available
if 'forward' in shf.registry:
    forward_fn = shf.registry['forward']
    # forward_fn can now be called with appropriate arguments
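The two views serve different purposes. shf.registry maps names to callables, while get_registry() returns descriptive metadata. The sketch below illustrates that relationship, with ordinary Python functions standing in for compiled Sheaf functions. It derives a {'params': [...]} entry from each function's signature. describe is a hypothetical helper, and this page does not say how Sheaf actually builds its metadata.

```python
# Hypothetical sketch: plain Python functions stand in for compiled Sheaf
# functions. `callables` plays the role of shf.registry (name -> function);
# describe() produces metadata shaped like part of get_registry()'s output.
import inspect


def forward(x, params):
    return x @ params["W"] + params["b"]


def loss(pred, target):
    return ((pred - target) ** 2).mean()


callables = {"forward": forward, "loss": loss}


def describe(registry):
    """Return {name: {'params': [...]}} for every callable in `registry`."""
    return {
        name: {"params": list(inspect.signature(fn).parameters)}
        for name, fn in registry.items()
    }


meta = describe(callables)
print(meta)
# {'forward': {'params': ['x', 'params']}, 'loss': {'params': ['pred', 'target']}}
```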

Common Registry Tasks

# List all loaded functions
print(list(shf.registry.keys()))

# Get a function by name
fn = shf.registry['my-function']
result = fn(arg1, arg2)

# Get function parameters
meta = shf.get_registry()
params = meta['my-function']['params']
print(f"Parameters: {params}")

# Call with keyword arguments (names must match the Sheaf parameter names)
fn = shf.registry['my-function']
result = fn(x=value1, y=value2)
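This page does not specify how compiled functions treat keyword arguments. An adapter helps in two cases: a function that accepts only positional arguments, and a parameter name containing a hyphen (which Python keyword syntax cannot express). A small adapter can reorder keywords according to the params list from get_registry(). The following is a minimal sketch, and call_with_kwargs is a hypothetical helper.

```python
# Hypothetical helper: validates keyword arguments against a parameter list
# (such as the 'params' entry returned by get_registry()) and then calls
# the function positionally in the declared order.
def call_with_kwargs(fn, param_names, **kwargs):
    """Call `fn` positionally, taking each argument from `kwargs` by name."""
    unknown = set(kwargs) - set(param_names)
    missing = [p for p in param_names if p not in kwargs]
    if unknown or missing:
        raise TypeError(f"unknown={sorted(unknown)}, missing={missing}")
    # Reorder the keyword arguments into the declared parameter order.
    return fn(*(kwargs[p] for p in param_names))


# Stand-in for shf.registry['my-function'] with params ['x', 'y']
def my_function(x, y):
    return x - y


print(call_with_kwargs(my_function, ["x", "y"], y=2, x=10))  # 8

# Hyphenated names can be passed through dict unpacking:
print(call_with_kwargs(lambda a: a * 2, ["learning-rate"], **{"learning-rate": 3}))  # 6
```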

Accessing the Environment

The environment (shf.env) contains the built-in functions, such as random-key, along with all variables and values defined or computed during execution:

Viewing Environment Contents

from sheaf import Sheaf

shf = Sheaf()

# List what's available in the environment
print(list(shf.env.keys()))  # Shows built-in functions

shf.load("""
  (defn create-model [key]
    {:W (random-normal key '[10 5])
     :b (zeros '[5])})

  (defn get-config []
    {:learning-rate 0.001
     :batch-size 32})
""")

# Call functions
model = shf.create_model(shf.env['random-key'](0))
config = shf.get_config()

# Get all user-defined functions with metadata
registry = shf.get_registry()

print(registry.keys())
# Output: dict_keys(['create-model', 'get-config'])

# Access config values from returned dictionary
lr = config['learning-rate']
print(f"Learning rate: {lr}")
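shf.env also lists built-ins. To see what a particular load call added, one option is to snapshot the keys before loading and diff them afterwards. The following is a minimal sketch using plain dicts in place of shf.env. It assumes the environment is dict-like and that loaded definitions are added to it. new_names is a hypothetical helper.

```python
# Hypothetical sketch: plain dicts stand in for shf.env before and after
# a load. The diff of the key sets shows which names the load introduced.
def new_names(env_before, env_after):
    """Return the sorted names present after a load but not before."""
    return sorted(set(env_after) - set(env_before))


builtins_env = {"+": object(), "random-key": object(), "zeros": object()}
after_load = {**builtins_env, "create-model": object(), "get-config": object()}

print(new_names(builtins_env, after_load))  # ['create-model', 'get-config']
```

Under the same assumptions, a real instance would take before = set(shf.env), call shf.load(...), and then call new_names(before, shf.env).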

Working with Tensors

shf.load("""
  (defn create-weights [key]
    {:W (random-normal key '[3 4])
     :b (zeros '[4])})
""")

# Create weights
weights = shf.create_weights(shf.env['random-key'](0))

# Access tensor values
W = weights['W']
print(f"Shape: {W.shape}")      # (3, 4)
print(f"Dtype: {W.dtype}")      # float32

# Derive new tensors with JAX operations (the original W is not modified)
new_W = W * 0.5
print(new_W)  # JAX Array
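weights['W'] is an ordinary JAX array, so JAX's immutability rules apply: item assignment such as W[0, 0] = 1.0 raises a TypeError. Single entries are instead updated functionally with the .at[...] syntax, which returns a new array. The following is a minimal sketch, using a zero array as a stand-in for weights['W'].

```python
# JAX arrays are immutable: every "modification" produces a new array.
import jax.numpy as jnp

W = jnp.zeros((3, 4))          # stand-in for weights['W']

scaled = W * 0.5               # elementwise op returns a new array
updated = W.at[0, 0].set(1.0)  # functional update of a single entry

print(W[0, 0], updated[0, 0])        # 0.0 1.0  (W itself is unchanged)
print(updated.shape, updated.dtype)  # (3, 4) float32
```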

PyTree Conversion

Sheaf provides conversion utilities for working with JAX pytrees:

Converting Sheaf Values to PyTrees

from sheaf import Sheaf
import jax

shf = Sheaf()
shf.load("""
  (defn create-params [key]
    (let [[k1 k2] (random-split key 2)]
      {:layer1 {:W (random-normal k1 '[3 4]) :b (zeros '[4])}
       :layer2 {:W (random-normal k2 '[4 1]) :b (zeros '[1])}}))
""")

params = shf.create_params(shf.env['random-key'](42))

# Convert to pure JAX pytree
pytree = shf.to_pytree(params)

# Now use with JAX transformations
def loss_fn(pytree_params):
    # pytree_params is a pure pytree here
    return jax.numpy.sum(pytree_params['layer1']['W'])

loss = loss_fn(pytree)
grads = jax.grad(loss_fn)(pytree)  # Gradients share the pytree's structure

Converting PyTrees Back to Sheaf

# Apply a function to every leaf of the pytree (here, scale every array by 0.1)
import jax.tree_util as tree_util
pytree_result = tree_util.tree_map(lambda x: x * 0.1, pytree)

# Convert back to Sheaf-compatible structure
sheaf_result = shf.from_pytree(pytree_result)

print(sheaf_result['layer1']['W'].shape)  # (3, 4)
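This page does not document what to_pytree and from_pytree do internally. For reference, a nested dict of arrays is already a standard JAX pytree. The generic jax.tree_util flatten/unflatten round trip is shown below. The arrays are stand-ins, not values produced by Sheaf.

```python
# Nested dicts of arrays are standard JAX pytrees. This shows the generic
# flatten/unflatten round trip; it does not reproduce Sheaf's own conversion.
import jax
import jax.numpy as jnp

params = {
    "layer1": {"W": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"W": jnp.ones((4, 1)), "b": jnp.zeros(1)},
}

leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 4 arrays, in a deterministic (sorted-key) order

# Transform the leaves, then rebuild the same nested structure.
scaled = jax.tree_util.tree_unflatten(treedef, [x * 0.1 for x in leaves])
print(scaled["layer1"]["W"].shape)  # (3, 4)

# Total parameter count across all leaves: 3*4 + 4 + 4*1 + 1
n_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(n_params)  # 21
```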

Introspecting Language Features

Available Special Forms

# See all special forms (defn, let, vmap, scan, reduce, etc.)
special_forms = shf.get_special_forms()
print(special_forms)
# ['and', 'assoc', 'def', 'defn', 'dissoc', 'let', 'reduce', 'scan', 'vmap', ...]

Function Inspection

# Load neural network library
shf.load("(use nn)")

# Get detailed info about a function
registry = shf.get_registry()

if 'layer-norm' in registry:
    info = registry['layer-norm']
    print(f"Parameters: {info['params']}")  # ['x', 'p', 'axis']
    print(f"Source:\n{info['source']}")
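If each registry entry has the params list shown above, the metadata is enough to print a quick usage line for every loaded function. format_call is a hypothetical helper. The registry dict below is a stand-in for the output of shf.get_registry().

```python
# Hypothetical helper: render registry metadata as a Sheaf-style call form.
def format_call(name, info):
    """Return e.g. '(layer-norm x p axis)' from {'params': ['x', 'p', 'axis']}."""
    return "(" + " ".join([name, *info["params"]]) + ")"


# Stand-in for shf.get_registry() output
registry = {"layer-norm": {"params": ["x", "p", "axis"], "source": "..."}}
for name, info in sorted(registry.items()):
    print(format_call(name, info))  # (layer-norm x p axis)
```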

Mixing Python and Sheaf Code

Data preprocessing can be done in Python, then passed to Sheaf for computation:

from sheaf import Sheaf
import jax.numpy as jnp
import numpy as np

shf = Sheaf()
shf.load("""
  (defn forward [x params]
    (+ (@ x (get params :W)) (get params :b)))
""")

# Data preprocessing in Python
raw_data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
normalized = (raw_data - raw_data.mean()) / raw_data.std()

# Computation in Sheaf
params = {'W': jnp.array([[1.0, 2.0], [3.0, 4.0]]), 'b': jnp.array([0.0, 0.0])}
output = shf.forward(normalized, params)

print(output)  # JAX array of shape (3, 2)
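forward is just an affine map, so its result can be cross-checked against a plain NumPy computation. The following is a minimal sketch that repeats the preprocessing above. The final comparison against output is commented out because it assumes the Sheaf example has already run.

```python
# Pure NumPy reference for the affine map computed by `forward` above,
# useful for sanity-checking the Sheaf result.
import numpy as np

raw_data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
normalized = (raw_data - raw_data.mean()) / raw_data.std()

W = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([0.0, 0.0])

reference = normalized @ W + b
print(reference.shape)  # (3, 2)

# Assumes `output` from the Sheaf example is in scope:
# np.testing.assert_allclose(np.asarray(output), reference, rtol=1e-5)
```

The relative tolerance allows for JAX's default float32 precision versus NumPy's float64.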

Error Handling

Sheaf provides detailed error messages that include line numbers and context:

from sheaf import Sheaf

shf = Sheaf()

try:
    shf.load("""
      (defn broken [x]
        (undefined-function x))
    """)
    result = shf.broken(5)
except Exception as e:
    print(f"Error: {e}")
    # Error: Symbol not found: 'undefined-function' (line 3)
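A misspelled or missing definition only shows up when the name is used. A script can therefore check up front that every function it relies on was loaded. The following is a minimal sketch that assumes the registry supports "in" lookups, as in the registry examples. require_functions is a hypothetical helper, and the dict below stands in for shf.registry.

```python
# Hypothetical helper: fail at startup if any required function is missing.
def require_functions(registry, names):
    """Raise KeyError listing every name in `names` missing from `registry`."""
    missing = [n for n in names if n not in registry]
    if missing:
        raise KeyError(f"missing Sheaf functions: {missing}")


# Stand-in registry; in practice this would be shf.registry
registry = {"forward": lambda *a: None, "loss": lambda *a: None}
require_functions(registry, ["forward", "loss"])  # passes silently
```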

Type Safety and JAX Integration

Sheaf functions are pure JAX functions. They:

  • Can be compiled by XLA through jax.jit for performance
  • Support automatic differentiation via jax.grad and jax.value_and_grad
  • Compose with JAX transformations such as jax.jit, jax.vmap, and jax.lax.scan

from sheaf import Sheaf
import jax

shf = Sheaf()
shf.load("""
  (defn model [x params]
    (sigmoid (+ (@ x (:W params)) (:b params))))
""")

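# Extract the JAX function
model_fn = shf.registry['model']

# Use JAX transformations. jax.grad needs a scalar-valued function, so wrap
# the model in one and differentiate with respect to params (argument index 1)
def scalar_out(x, params):
    return jax.numpy.sum(model_fn(x, params))

grad_fn = jax.grad(scalar_out, argnums=1)
jitted_fn = jax.jit(model_fn)

# Compose them
fast_grad = jax.jit(grad_fn)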
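A typical use of these transformations is a gradient step on the parameters. The sketch below reimplements the model above in pure JAX as a stand-in. Substituting shf.registry['model'] assumes it takes (x, params) and accepts a Python dict with keys 'W' and 'b'. This page implies that but does not show it. The update rule is plain SGD, chosen for brevity.

```python
# Pure JAX stand-in for the Sheaf `model` above (same math), used to show
# one SGD step with jax.value_and_grad, differentiating w.r.t. params only.
import jax
import jax.numpy as jnp

LR = 0.1  # learning rate


def model(x, params):
    return jax.nn.sigmoid(x @ params["W"] + params["b"])


def loss(params, x, y):
    # Mean squared error; jax.grad needs a scalar output.
    return jnp.mean((model(x, params) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    value, grads = jax.value_and_grad(loss)(params, x, y)
    # Plain SGD: subtract LR * gradient from every leaf of the params pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, value


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))
y = jnp.ones((8, 1))
params = {"W": jnp.zeros((3, 1)), "b": jnp.zeros(1)}

params, loss0 = train_step(params, x, y)  # loss at the initial params
params, loss1 = train_step(params, x, y)  # loss after one update
print(float(loss0), float(loss1))         # the second value should be lower
```

train_step returns the loss evaluated before its own update. The first value is therefore the loss at the all-zero initial parameters, where sigmoid(0) = 0.5 gives a mean squared error of 0.25.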

Summary

| Operation | Code | Notes |
|---|---|---|
| Inline code | shf.load("(defn ...)") | Simple, interactive |
| File-based | shf = Sheaf("model.shf") | Better for larger projects |
| Call function | shf.function_name(args) | Python naming (underscores) |
| Get registry | shf.get_registry() | Metadata about functions |
| Get environment | shf.get_env() | Metadata about variables |
| Access function | shf.registry['name'] | Direct JAX function |
| Access variable | shf.env['name'] | Direct value |
| Convert to pytree | shf.to_pytree(value) | For JAX operations |
| Convert from pytree | shf.from_pytree(pytree) | Reconstruct Sheaf value |
| Introspect forms | shf.get_special_forms() | Available language features |

For detailed function signatures and examples, see the Function Reference.