
Intro to Jax

Abstract

Multiple packages have emerged to address the speed issue of Python in scientific computing. The de facto standard is of course numpy. Think also numba, jax, torch, or even a new language 🔥 that goes beyond what numpy offers. This post is a quick intro to my personal preference as of Jan 2025.

Numpy operations

The part of jax that's straightforward for numpy users is the jax.numpy module, which has an (almost) identical API. Like other parts of jax, these operations are built on top of the XLA compiler's intermediate representation for high-performance numerical computation.

There is a distinction between abstract and concrete arrays: the former is called ShapedArray and captures only dtype and shape, while the latter is called Array and carries the values too.
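A quick way to see the abstract level in action is jax.eval_shape, which evaluates a function using only shapes and dtypes, without computing any values. A minimal sketch:

import jax
from jax import numpy as jnp

x = jnp.ones((2, 3))  # concrete Array: shape, dtype and values
out = jax.eval_shape(lambda a: a @ a.T, x)  # abstract evaluation only
print(out)
ShapeDtypeStruct(shape=(2, 2), dtype=float32)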

The caveat of jax.Array is that it is immutable. Here is the syntax for "mutation" if desired:

import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = x.at[1:].set(0)
print(y)

z = x.at[1:].add(1)
print(z)
[1 0 0]
[1 3 4]

Random numbers

Be aware that jax follows the functional programming paradigm. This implies explicit key handling for samplers. Samplers can be composed with vmap to achieve vectorization across all parameters; e.g. random.t has two parameters (key and df), so one can supply arrays for either or both parameters to the vectorized sampler vmap(random.t, ...). By default in_axes=0, which vectorizes all leaf nodes (arrays) with the leading dimension being the batch dimension. See the dedicated section below for more.

from jax import random, vmap, numpy as jnp

k = random.key(42)
k, k1 = random.split(k)
out1 = random.normal(k1, (3,))

k, k2 = random.split(k)
dfs = jnp.array([1, 2]) # degrees of freedom of two t-distributions
out2 = vmap(random.t, in_axes=(None, 0))(k2, dfs)

print(out1)
print(out2)
[ 0.60576403  0.7990441  -0.908927  ]
[-1.8834321  -0.83157206]

Notice that we provide an array dfs as input because jax expects an array for vectorized arguments. Providing a list (a pytree!) wouldn't work: vmap would treat the list as a container and try to map over each scalar leaf separately.

See details in the official docs.

vmap

More control over the operating axes of vmap is possible. Here in_axes=(0, None) imposes that vectorization occurs over the first argument of the function, with batch axis 0. Without specifying in_axes, vmap(f) would expect all its arguments to be arrays of rank (at least) 1 sharing the same leading dimension, which is not the case for our inputs.

Notice that broadcasting à la numpy is performed inside the base function (before vmap). The effect of vmap is to stack the individual results of the function, as jnp.stack would, along the new out_axes (in this example the columns). Using vmap avoids manual batching, manual stacking, etc.

from jax import vmap, numpy as jnp

def f(x, y): return x + y

xs = jnp.array([0, 1, 2, 3])
y = jnp.array([4, 5])
out = vmap(f, in_axes=(0, None), out_axes=1)(xs, y)
print(out)
[[4 5 6 7]
 [5 6 7 8]]
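For contrast, here is a minimal sketch of the default behaviour: with in_axes=0, every argument is mapped over its leading axis, so all arguments must share that leading dimension.

from jax import vmap, numpy as jnp

def f(x, y): return x + y

xs = jnp.array([0, 1, 2, 3])
ys = jnp.array([4, 5, 6, 7])
print(vmap(f)(xs, ys))  # default in_axes=0: both args batched along axis 0
[ 4  6  8 10]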
More details here.

jit if you can

In eager mode, jax transformations/operators run sequentially, one at a time. With jit compilation, the jax program, more precisely the underlying computation graph, is optimized (e.g. operations rearranged and fused) by the XLA compiler so that it runs faster.

The idea of jit compilation is to trace/compile the program in Python once, and cache the compiled program for repeated evaluations. Because of the overhead of compilation, it is best if similar inputs do not trigger re-compilation. To this end, jax transformations must be agnostic to the values of the inputs, and they must know the shape and dtype of the inputs and outputs to comply with XLA's requirement that these be compile-time static.
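A minimal sketch of this caching behaviour: a Python-side print only fires while tracing, so it reveals when (re-)compilation is triggered.

from jax import jit, numpy as jnp

@jit
def f(x):
    print("tracing!")  # Python side effect: runs only during tracing
    return jnp.sin(x) ** 2

f(jnp.arange(3.))      # prints "tracing!", compiles, then runs
f(jnp.arange(3.) + 1)  # same shape/dtype: cached, no re-tracing
f(jnp.arange(4.))      # new shape: prints "tracing!" again, recompiles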

It may sound like one should jit everything, but this is not always possible. Consider this.

# NOT WORKING!
from jax import jit, numpy as jnp

@jit
def f(x):
  if x > 0: return x
  else: return jnp.stack([x,x])

try: f(3)
except Exception as e: print(e)
Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at <code block: n4>:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

This function is not jit'able. It errors out because the value of x must be known upfront to determine the shape of the output. Of course one can get around the if statement with jnp.where, which makes both branches explicit, provided the branches share a shape. Conveniently, ALL jax.numpy operations are jittable.
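Here is a minimal sketch of the jnp.where workaround on a same-shape conditional (an abs-like function h, made up for illustration):

from jax import jit, numpy as jnp

@jit
def h(x):
    # both branches are traced; jnp.where selects element-wise,
    # so the output shape is known at compile time
    return jnp.where(x > 0, x, -x)

print(h(jnp.array([3., -2.])))
[3. 2.]

But now consider this python while loop: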

# NOT WORKING!
from jax import jit

@jit
def g(x, n):
    i = 0
    while i < n: i += 1
    return x + i

try: g(1, 5)
except Exception as e: print(e)
Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

This function is not jit compatible either. XLA really does not like python while loops because they are fully dynamic. Intermediate states COULD (they don't necessarily have to, but they could) change dtype and shape depending on the values, so XLA has no hope of analyzing the loop statically.

Two fixes are possible. One is to make the control flow static by specifying the number of iterations, i.e. treating n as a constant at compile time. Effectively, the loop is unrolled at compile time.

from jax import jit
from functools import partial

@partial(jit, static_argnums=1)
def g(x, n):
    i = 0
    while i < n: i += 1
    return x + i

print(g(1, 5))
6
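Note that n is baked into the compiled program: each distinct value of n triggers a fresh compilation.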

Another is to use the (dynamic) structured control flow jax.lax.while_loop. The number of iterations is allowed to be dynamic (with static dtype and shape, of course), but the structure of the condition and body functions is static.

from jax import jit
from jax.lax import while_loop

def cond_fun(val):
    i, n = val
    return i < n

def body_fun(val):
    i, n = val
    return i + 1, n

@jit
def g(x, n):
    end, _ = while_loop(cond_fun, body_fun, (0, n))
    return x + end

print(g(1, 5))
6

Details here. More control flow operators: check this page.

grad grad!

When you can compose grad and jacobian to get a Hessian, you know jax's automatic differentiation is done right. Define your function and grad your way out with respect to any variable you are interested in. Compose with jit for performance. Some obvious caveats (see the sketch after this list):

  • functions must be scalar-valued for grad (there is jax.jacobian for vector-valued funcs)
  • inputs must be continuous (e.g. float)
  • functions must be differentiable (differentiating through argmax, or with respect to integer indices, is not OK)
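A minimal sketch of the grad/jacobian composition (jax also ships jax.hessian, which does exactly this):

from jax import grad, jacobian, jit, numpy as jnp

def f(x): return jnp.sum(x ** 3)  # scalar-valued

hessian = jit(jacobian(grad(f)))  # Jacobian of the gradient

x = jnp.array([1., 2.])
print(grad(f)(x))   # 3 x^2
print(hessian(x))   # diag(6 x)
[ 3. 12.]
[[ 6.  0.]
 [ 0. 12.]]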

Simple profiling/testing

Use .block_until_ready() on the output jax.Array of a function to measure its execution time; because of jax's asynchronous dispatch, timing without it only measures the dispatch, not the computation. There is no public testing module in jax. Use np.testing.assert_allclose to check the results.
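A minimal sketch, where f is an arbitrary stand-in workload:

import time
import numpy as np
from jax import jit, numpy as jnp

f = jit(lambda x: jnp.sin(x) @ jnp.cos(x).T)  # toy workload
x = jnp.ones((1000, 1000))

f(x).block_until_ready()   # warm-up call: exclude compile time
t0 = time.perf_counter()
f(x).block_until_ready()   # wait for the result, not just the dispatch
print(time.perf_counter() - t0)

ref = np.sin(1.0) * np.cos(1.0) * 1000  # analytic value of every entry
np.testing.assert_allclose(f(x), np.full((1000, 1000), ref), rtol=1e-4)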

Type hints

Use jax.typing.ArrayLike for array input and jax.Array for array output.
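A minimal sketch (squared_norm is a made-up example):

import jax
from jax import numpy as jnp
from jax.typing import ArrayLike

def squared_norm(x: ArrayLike) -> jax.Array:
    # ArrayLike covers jax arrays, numpy arrays and Python scalars;
    # jnp.asarray makes the jax.Array return annotation hold
    x = jnp.asarray(x)
    return jnp.sum(x ** 2)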
