
Intro to JAX

Abstract

Multiple packages have emerged to address Python's speed limitations in scientific computing. The de facto standard is of course NumPy, but think also of Numba, JAX, PyTorch, or even a new language 🔥 that go beyond what NumPy offers. This post is a quick intro to JAX, my personal preference as of Jan 2025.

PyTree

At the center of JAX is undoubtedly jax.Array. A nested container of these objects, e.g. a list of dicts of arrays, is called a PyTree, and the arrays are its leaves. It is also possible to register a custom class as a PyTree node.
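A minimal sketch of working with PyTrees, assuming a recent JAX version: jax.tree_util.tree_leaves flattens a nested container into its leaves, tree_map applies a function to every leaf while keeping the structure, and register_pytree_node_class turns a class into a PyTree node. The Point class is a hypothetical example, not part of JAX.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map

# A list of dicts of arrays is a PyTree with three array leaves.
params = [{"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, {"w": jnp.ones(3)}]
leaves = tree_leaves(params)
doubled = tree_map(lambda a: 2 * a, params)   # same structure, every leaf doubled


@register_pytree_node_class
class Point:  # hypothetical custom PyTree node
    def __init__(self, x, y):
        self.x, self.y = x, y

    def tree_flatten(self):
        # (children that JAX traverses, auxiliary static data)
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


p = tree_map(lambda a: a + 1, Point(jnp.array(1.0), jnp.array(2.0)))
print(len(leaves), p.x, p.y)  # 3 2.0 3.0
```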

Numpy operations

The jax.numpy module is easy to use for NumPy users because it intentionally mirrors the NumPy API almost exactly. Operations in jax.numpy are built on top of XLA compiler primitives (exposed in jax.lax) for high-performance numerical computation.

JAX distinguishes between the abstract structure and the concrete value of an array. A jax.ShapeDtypeStruct object captures only the shape and dtype, whereas a jax.Array also carries the concrete values.
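As a small illustration, jax.eval_shape runs a function on abstract ShapeDtypeStruct inputs and returns the output's shape and dtype without computing any values. The lambda below is just an example.

```python
import jax
import jax.numpy as jnp

spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)     # shape and dtype only, no data
out = jax.eval_shape(lambda a: a.sum(axis=0), spec)  # nothing is actually computed
print(out.shape, out.dtype)  # (4,) float32

x = jnp.ones((3, 4))                                  # a concrete array with values
print(x.shape, x.dtype)  # (3, 4) float32
```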

The caveat of jax.Array is that it is immutable. The "mutation" syntax differs from NumPy: instead of modifying an array in place, x.at[...] returns a new array.

import jax.numpy as jnp

x = jnp.array([1, 2, 3])

y = x.at[1:].set(0)
print(y)

z = x.at[1:].add(1)
print(z)
[1 0 0]
[1 3 4]
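To make the immutability concrete, a short sketch: NumPy-style item assignment raises a TypeError, while the .at[...] methods (set, add, mul and others) return new arrays and leave the original untouched. With repeated indices, .add accumulates, like a scatter-add.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

try:
    x[0] = 10                # NumPy-style in-place assignment
    assigned = True
except TypeError:            # JAX arrays do not support item assignment
    assigned = False

w = x.at[0].mul(10)                  # new array [10  2  3]
m = x.at[jnp.array([0, 0])].add(1)   # index 0 hit twice: [3 2 3]
print(assigned, w, m, x)             # x itself is still [1 2 3]
```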

Broadcasting, arguably one of the most elegant designs of the NumPy array API, works the same way for JAX arrays.
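For example, a column of shape (3, 1) and a row of shape (4,) broadcast to a (3, 4) table, exactly as they would in NumPy:

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape(3, 1)   # shape (3, 1)
row = jnp.arange(4)                 # shape (4,)
table = 10 * col + row              # both broadcast to shape (3, 4)
print(table)
# [[ 0  1  2  3]
#  [10 11 12 13]
#  [20 21 22 23]]
```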

vmap

vmap vectorizes a function by mapping it over a batch dimension, by default the leading axis, of every leaf of the function's arguments.
In other words, vmap(f) behaves like calling f on each slice of the batch and combining the individual results with np.stack. We use vmap to avoid manual batching, manual stacking, etc.

Since the arguments of the function can be arbitrary PyTrees, axis 0 of every array leaf is treated as the batch dimension, so all leaves must have the same size along that axis.
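A minimal sketch of the np.stack analogy with a PyTree argument: the hypothetical function affine takes a dict of parameters, every leaf of the dict and of xs has leading size 3, and vmap gives the same result as an explicit Python loop followed by jnp.stack.

```python
import jax
import jax.numpy as jnp


def affine(params, x):  # hypothetical example function
    return params["w"] * x + params["b"]


params = {"w": jnp.arange(3.0), "b": jnp.ones(3)}  # every leaf has leading size 3
xs = jnp.array([10.0, 20.0, 30.0])                 # leading size 3 as well

batched = jax.vmap(affine)(params, xs)

# The same computation by hand: slice every leaf, call, then stack.
looped = jnp.stack([
    affine(jax.tree_util.tree_map(lambda a: a[i], params), xs[i])
    for i in range(3)
])
print(batched)  # [ 1. 21. 61.]
```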

The in_axes argument gives more control over which axes vmap maps over. Below, in_axes=(None,0) maps over the leading axis of the second positional argument only, while the first argument is passed unchanged to every call. If we do not specify in_axes (it defaults to 0, equivalent to (0,0) here), vmap(f) maps over the leading axis of every leaf of every argument. That is not compatible with the inputs in our example: beta.shape is (5,) and x.shape is (2,5), so their leading dimensions differ.

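import jax
import jax.numpy as jnp
import jax.random as jr

def li(beta, x):
    """linear predictor with multiple regressors"""
    return beta.dot(x)

beta = jnp.arange(5)                  # shape (5,), shared by every call
X = jr.normal(jr.key(1), (2, 5))      # 2 observations with 5 regressors each

# map over the rows of X only; beta is passed unchanged to every call
jax.vmap(li, in_axes=(None, 0))(beta, X)   # result has shape (2,)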

Additionally, out_axes specifies the axis along which the results are stacked, to use our np.stack analogy again. In the example above, each call returns a scalar, so the stacked output is 1-D and out_axes=0, the default, is the only sensible choice.
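When each call returns a vector, out_axes does make a difference. In this sketch the hypothetical function scale returns shape (3,), and mapping over two scalars s stacks the results either as rows (the default) or as columns.

```python
import jax
import jax.numpy as jnp


def scale(v, s):  # hypothetical example: returns a vector of shape (3,)
    return v * s


v = jnp.arange(3.0)
s = jnp.array([1.0, 2.0])

stacked_rows = jax.vmap(scale, in_axes=(None, 0))(v, s)              # shape (2, 3)
stacked_cols = jax.vmap(scale, in_axes=(None, 0), out_axes=1)(v, s)  # shape (3, 2)
print(stacked_rows.shape, stacked_cols.shape)  # (2, 3) (3, 2)
```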

More details about vmap here

Random numbers

Be aware that JAX follows the functional programming paradigm, so samplers take an explicit PRNG key instead of drawing from hidden global state. Samplers can be composed with vmap to vectorize over any of their parameters.

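For example, jax.random.t takes a key and the degrees of freedom df. To generate a collection of t-distributed random variables that share the same degrees of freedom, supply an array of keys and leave df unmapped. Note that df is passed positionally: in_axes applies only to positional arguments, and vmap always maps keyword arguments over axis 0.

keys = jr.split(jr.key(0), 5)                  # 5 keys -> 5 draws (jr = jax.random)
jax.vmap(jr.t, in_axes=(0, None))(keys, 2.0)   # df=2 shared by all draws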

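To generate a collection where each variable has its own degrees of freedom, map over both arguments. The three calls below are identical:

keys = jr.split(jr.key(0), 5)   # jr = jax.random
dfs = jnp.arange(1.0, 6.0)      # one df per key: 1, 2, ..., 5
# the three calls below are identical
jax.vmap(jr.t)(keys, dfs)
jax.vmap(jr.t, in_axes=0)(keys, dfs)
jax.vmap(jr.t, in_axes=(0, 0))(keys, dfs)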
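A quick, self-contained check of these calls (assuming a recent JAX version): the three spellings give identical draws, df can be shared by passing it positionally with in_axes None, and passing a scalar df as a keyword fails because vmap maps keyword arguments over axis 0.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

keys = jr.split(jr.key(0), 5)
dfs = jnp.arange(1.0, 6.0)

a = jax.vmap(jr.t)(keys, dfs)
b = jax.vmap(jr.t, in_axes=0)(keys, dfs)
c = jax.vmap(jr.t, in_axes=(0, 0))(keys, dfs)

shared = jax.vmap(jr.t, in_axes=(0, None))(keys, 2.0)  # df passed positionally: OK

# Keyword arguments are always mapped over axis 0, so a scalar df=2.0 fails.
try:
    jax.vmap(jr.t, in_axes=(0,))(keys, df=2.0)
    keyword_failed = False
except ValueError:
    keyword_failed = True
print(keyword_failed)  # True
```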

Here is an example of working with PRNG keys. The official docs warn: NEVER reuse a key; split it instead.

import jax
import jax.random as jr
import jax.numpy as jnp

key = jr.key(42)
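out1 = jr.normal(key, (10,))
# caution: `key` feeds out1 and is split again below; in real code, split first
# and consume each subkey exactly once
out2 = jax.vmap(jr.normal)(jr.split(key, 10)) # vmap maps over the leading axis of the 10 keys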
out3 = jax.vmap(jr.t)(
    key = jr.split(jr.key(21),10),
    df = jnp.arange(10)+1 # degree of freedom of 10 t-distributions
)

print(out1)
print(out2)
print(out3)
[-0.02830462  0.46713185  0.29570296  0.15354592 -0.12403282  0.21692315
 -1.440879    0.7558599   0.52140963  0.9101704 ]
[ 0.07592554  0.60576403  0.4323065  -0.2818947   0.6549178  -0.2166012
 -0.25440374  0.2886397   0.14384735 -1.3462586 ]
[ 0.6342837   0.6981538   1.959329    0.35705897  0.95073795 -2.2646627
  0.93203527  0.50947154  1.1138752  -0.03552625]
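For contrast, a sketch of the split-then-use pattern: each draw consumes a fresh subkey, and reusing a key reproduces exactly the same numbers, which is why reuse silently correlates samples.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.key(42)

key, subkey = jr.split(key)       # keep `key` for later, consume `subkey` now
a = jr.normal(subkey, (3,))

key, subkey = jr.split(key)       # fresh subkey for the next draw
b = jr.normal(subkey, (3,))

again = jr.normal(subkey, (3,))   # reusing a key gives identical numbers
print(jnp.array_equal(a, b), jnp.array_equal(b, again))  # False True
```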

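Again, df is provided as a jax.Array with one entry per key, because vmap maps over the leading axis of array leaves. A Python list of the same length would not work: a list is a PyTree (!) whose leaves are scalars, and scalars have no leading axis to map over.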

See details in the official docs.

jit if you can

In eager mode, JAX operations are dispatched one at a time. With jit compilation, the whole function is traced into a computation graph that the XLA compiler optimizes (e.g. by reordering and fusing operations) so that it runs faster.

The idea of jit compilation is to trace and compile the Python function once, then cache the compiled program for repeated evaluations. Compilation introduces overhead, so calls with inputs of the same shape and dtype should not trigger re-compilation. JAX achieves this by tracing the function with abstract values that carry only the shape and dtype of each operand, and optimizing the resulting computational graph without knowing the exact values.
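One way to observe this caching, as a sketch: a Python side effect inside a jitted function runs only while JAX traces it, so counting those runs shows when (re)compilation happens. square_sum and trace_count are hypothetical names.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter of how often the function is traced


@jax.jit
def square_sum(x):  # hypothetical example function
    global trace_count
    trace_count += 1          # Python side effect: runs only during tracing
    return (x ** 2).sum()


square_sum(jnp.ones(3))    # shape (3,), float32: traced and compiled
square_sum(jnp.zeros(3))   # same shape and dtype: cached program, no retrace
square_sum(jnp.ones(4))    # new shape (4,): traced and compiled again
print(trace_count)  # 2
```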

This approach has some implications. Under tracing transformations such as jit (and vmap), the program cannot depend on the concrete values of its inputs, and the shape and dtype of every input, intermediate and output must be known in advance, because XLA requires them to be static at compile time.

Ideally we would jit everything for speed, but this is not always possible. Consider this:

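# NOT WORKING!
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    if x > 0:            # x is an abstract tracer here, so this bool() fails
        return x         # shape ()
    else:
        return jnp.stack([x,x])   # shape (2,): output shape depends on the value of x

try: f(3)
except Exception as e: print(e)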

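This function cannot be jitted: during tracing x is an abstract value without a concrete number, so Python's if cannot evaluate x > 0, and the output shape (a scalar or a length-2 vector) would depend on the value of x, which XLA does not allow. When both branches produce the same shape, the if can be replaced by jnp.where, which evaluates both branches and selects element-wise, or by jax.lax.cond. Conveniently, nearly all jax.numpy operations are jittable; functions whose output shape depends on the data, such as jnp.unique or jnp.nonzero, need an explicit size argument under jit.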
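For functions whose branches return the same shape, here is a sketch of the two jit-friendly alternatives; relu_where and abs_cond are hypothetical examples. jnp.where computes both branches and selects element-wise, while jax.lax.cond picks one branch based on a scalar predicate.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_where(x):  # hypothetical example
    # both branches are evaluated, then selected element-wise
    return jnp.where(x > 0, x, 0.0)


@jax.jit
def abs_cond(x):  # hypothetical example
    # scalar predicate; both branches must return the same shape and dtype
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


print(relu_where(jnp.array([-1.0, 2.0])))  # [0. 2.]
print(abs_cond(-3.0))  # 3.0
```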

More broadly, Python control flow that depends on the values of traced inputs is not jit compatible. Consider this function:

def xfibonacci(x, n):

    a, b = 0, 1
    for _ in range(1, n):
        a, b = b, a + b

    return x*b

We cannot jit compile this function because the number of iterations cannot be deduced from the shape and dtype of the input n.
JAX needs to know the exact number of operations to include in the computational graph before compilation.

There are two possible fixes. First, we can use static_argnums to mark a positional argument as static, i.e. to treat n as a compile-time constant. The loop is then effectively unrolled at compile time, and each new value of n triggers a recompilation.

from jax import jit
from functools import partial

@partial(jit,static_argnums=1)
def xfibonacci(x, n):

    a, b = 0, 1
    for _ in range(1, n):
        a, b = b, a + b

    return x*b

Or we can use the JAX control flow primitive jax.lax.fori_loop, which keeps n as a traced value.

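import jax
from jax.lax import fori_loop

@jax.jit
def xfibonacci(x, n):

    def body_fun(i, val):
        a, b = val          # i (the loop index) is unused
        return b, a + b

    _, out = fori_loop(1, n, body_fun, (0, 1))   # n may be a traced value

    return x*out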
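Because the fori_loop version treats n as an ordinary traced value, it can even be vmapped over an array of n's, which the static_argnums version cannot do, since a static argument must be a hashable Python value. A sketch, repeating the definition so the block runs on its own:

```python
import jax
import jax.numpy as jnp
from jax.lax import fori_loop


@jax.jit
def xfibonacci(x, n):
    def body_fun(i, val):
        a, b = val
        return b, a + b

    _, out = fori_loop(1, n, body_fun, (0, 1))
    return x * out


ns = jnp.arange(1, 8)                                    # n = 1, ..., 7
fibs = jax.vmap(xfibonacci, in_axes=(None, 0))(1.0, ns)  # one compiled program
print(fibs)  # values 1, 1, 2, 3, 5, 8, 13
```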

For while loops, the number of iterations can depend on the input values, but we can still use the JAX while-loop primitive. Consider Newton's method for the square root:

def newton(N, tol=1e-6):
    x = N/2.0
    error = tol + 1.
    while error>tol:
        x_next = 0.5*(x+N/x)
        error, x = abs(x - x_next), x_next
    return x

We can rewrite this function using the JAX control flow primitive jax.lax.while_loop, which takes a condition function, a body function, and an initial value for the loop state.

import jax
import jax.numpy as jnp
from jax.lax import while_loop

@jax.jit
def newton_sqrt(N, tol=1e-6):

    def cond_fun(val):
        x, e = val
        return e > tol

    def body_fun(val):
        x, _ = val
        x_next = 0.5 * (x + N / x)
        e = abs(x - x_next)
        return x_next, e

    init = N / 2.0, tol + 1.0 # Ensure the loop runs at least once

    final_estimate, _ = while_loop(cond_fun, body_fun, init)
    return final_estimate

Details here. More on control flow operators: check this page.

grad grad!

When you can compose grad and jacobian to get a Hessian, you know automatic differentiation in JAX is done right. Define your function, differentiate it with respect to whichever arguments you are interested in (selected with argnums), and compose with jit for performance. Some obvious caveats:

  • grad requires a scalar-valued function (use jax.jacobian for vector-valued functions)
  • the differentiated inputs must be floating point (or complex); integer inputs raise an error
  • the function should be differentiable in those inputs: piecewise-constant operations such as argmax, rounding or comparisons contribute zero gradient, whereas indexing itself is fine
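A minimal sketch of these points, with f and loss as hypothetical example functions: jax.grad of a scalar function, jax.jacobian composed with jax.grad to get the Hessian (jax.hessian is the built-in shortcut), argnums to pick the variables, a gradient through indexing, and the TypeError raised for integer inputs.

```python
import jax
import jax.numpy as jnp


def f(x):  # hypothetical scalar-valued example
    return jnp.sum(x ** 3)


x = jnp.array([1.0, 2.0])
g = jax.grad(f)(x)                  # 3 x^2     -> [3, 12]
H = jax.jacobian(jax.grad(f))(x)    # diag(6 x) -> [[6, 0], [0, 12]]
H2 = jax.hessian(f)(x)              # same Hessian via the built-in helper


def loss(w, b, x):  # hypothetical example: differentiate w.r.t. w and b only
    return (w * x + b) ** 2


dw, db = jax.grad(loss, argnums=(0, 1))(1.0, 0.5, 2.0)   # 10.0 and 5.0

# Indexing is differentiable: the gradient is one-hot at the selected entry.
gi = jax.grad(lambda v: 2.0 * v[1])(x)                   # [0, 2]

try:
    jax.grad(f)(jnp.array([1, 2]))  # integer input
    int_ok = True
except TypeError:
    int_ok = False
```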

Simple profiling/testing

A common pitfall when profiling JAX code is overlooking asynchronous dispatch: JAX returns control to Python as soon as a computation has been dispatched, so subsequent code can run before a heavy computation finishes.

For instance, below a heavy matrix multiplication is dispatched first, followed by a cheap query of the array's shape. The shape is printed right away, before the product has actually been computed, because

  1. JAX does not wait for x.dot(x) to finish before executing the print statement;
  2. even though out is not yet computed, its shape and dtype are already known and can be passed to the print statement.

import jax.random as jr

x = jr.normal(jr.key(1), (10000,10000))

def fun(x):
    out = x.dot(x)
    print(out.shape, out.dtype)
    return out

fun(x)

Therefore, when measuring time, we should block the Python process until the computation has finished, like so:

fun(x).block_until_ready()
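A minimal timing sketch along these lines, using time.perf_counter from the standard library and a hypothetical jitted function matmul: the first call is run separately as a warm-up so compilation time is not measured, and block_until_ready makes the clock stop only after the result exists. The matrix size is arbitrary.

```python
import time

import jax
import jax.random as jr


@jax.jit
def matmul(a):  # hypothetical function to benchmark
    return a @ a


x = jr.normal(jr.key(1), (500, 500))
matmul(x).block_until_ready()        # warm-up: the first call includes compilation

start = time.perf_counter()
out = matmul(x).block_until_ready()  # wait for the result before stopping the clock
elapsed = time.perf_counter() - start
print(f"{elapsed * 1e3:.3f} ms")
```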

JAX has no general-purpose testing module like numpy.testing. Use np.testing.assert_allclose, which accepts JAX arrays directly, to check the results.
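For example, comparing a JAX result against a NumPy reference. Since JAX defaults to float32 while NumPy computes in float64, a looser rtol than the default 1e-7 is usually needed; the value 1e-5 here is an arbitrary but comfortable choice for this sketch.

```python
import numpy as np
import jax.numpy as jnp

x = np.linspace(0.0, 1.0, 5)                                    # float64 NumPy input
np.testing.assert_allclose(jnp.exp(x), np.exp(x), rtol=1e-5)   # passes

# A genuine mismatch still raises AssertionError.
try:
    np.testing.assert_allclose(jnp.exp(x), np.exp(x) + 1e-3, rtol=1e-5)
    mismatch_caught = False
except AssertionError:
    mismatch_caught = True
```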

Type hints

Use jax.typing.ArrayLike for array input and jax.Array for array output.
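A short sketch of that convention, with standardize as a hypothetical example: the argument is annotated as ArrayLike and converted with jnp.asarray, and the return value is annotated as jax.Array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


def standardize(x: ArrayLike) -> jax.Array:  # hypothetical example
    """Take anything array-like (e.g. jax.Array, np.ndarray), return a jax.Array."""
    x = jnp.asarray(x)
    return (x - x.mean()) / x.std()


z = standardize(np.array([1.0, 2.0, 3.0]))   # NumPy input, JAX output
print(type(z).__name__, z)
```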

Reference