Intro to Jax
Abstract
Multiple packages have emerged to address Python's speed problem in scientific computing. The de facto standard is of course numpy. Think also of numba, jax, torch, or even new languages that go beyond what numpy offers. This post is a quick intro to my personal preference as of Jan 2025.
Numpy operations
The part of jax that is straightforward to use for numpy users is the jax.numpy module, which has an (almost) identical API. Like other parts of jax, these operations are built on top of XLA compiler intermediate representations for high-performance numerical computation.
There is a distinction between abstract and concrete arrays: the former is called ShapedArray and captures only dtype and shape, while the latter is called Array and carries the values too.
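One convenient way to see the abstract view is jax.eval_shape, which evaluates a function on abstract inputs and returns only shapes and dtypes; a minimal sketch (the function f below is just an illustration):
from jax import eval_shape, numpy as jnp
def f(x): return jnp.stack([x, x])
x = jnp.ones((3,))        # concrete Array: shape, dtype and values
out = eval_shape(f, x)    # abstract result: shape and dtype only, no values
print(out)                # ShapeDtypeStruct(shape=(2, 3), dtype=float32)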
The caveat of using jax.Array is that it is immutable. Here is the syntax for "mutation" if desired:
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
y = x.at[1:].set(0)   # returns a new array [1, 0, 0]; x is untouched
print(y)
z = x.at[1:].add(1)   # returns a new array [1, 3, 4]
print(z)
Random numbers
Be aware that jax follows the functional programming paradigm. This implies explicit key handling for samplers. The samplers can be composed with vmap to achieve vectorization across all parameters, e.g. random.t has two parameters (key and df), so one can supply one or two arrays to the vectorized sampler vmap(random.t, ...). By default, in_axes=0, which means vectorizing all leaf nodes (arrays) with the leading dimension being the batch dimension. See the dedicated section below for more.
from jax import random, vmap, numpy as jnp
k = random.key(42)
k, k1 = random.split(k)   # never reuse a key: split first, then consume the subkey
out1 = random.normal(k1, (3,))
k, k2 = random.split(k)
dfs = jnp.array([1, 2])   # degrees of freedom of two t-distributions
out2 = vmap(random.t, in_axes=(None, 0))(k2, dfs)   # same key, vectorized over df
print(out1)
print(out2)
Notice that we provide an array dfs as input because jax expects arrays for vectorized arguments. Providing a list (a pytree!) wouldn't work.
See details in the official docs.
vmap
More control over the operating axes of vmap is possible. Here in_axes=(0, None) imposes that vectorization occurs over the first argument of the function, with batch axis 0. Without specifying in_axes, vmap(f) would expect all its arguments to be arrays of rank (at least) 1 with the same size along the leading dimension, which is not the case for our inputs.
Notice that numpy-style broadcasting is still performed inside the base function (before vmap). The effect of vmap is to np.stack the individual results of the function along the new out_axes (in this example, the columns). Using vmap avoids manual batching, manual stacking, etc.
from jax import vmap, numpy as jnp
def f(x, y): return x + y   # y of shape (2,) broadcasts against each scalar x
xs = jnp.array([0, 1, 2, 3])
y = jnp.array([4, 5])
out = vmap(f, in_axes=(0, None), out_axes=1)(xs, y)   # shape (2, 4): results stacked as columns
print(out)
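To make the stacking analogy concrete, here is a minimal sketch comparing vmap against a manual loop plus jnp.stack:
from jax import vmap, numpy as jnp
def f(x, y): return x + y
xs = jnp.array([0, 1, 2, 3])
y = jnp.array([4, 5])
manual = jnp.stack([f(x, y) for x in xs], axis=1)     # manual batching + stacking
auto = vmap(f, in_axes=(0, None), out_axes=1)(xs, y)  # what vmap does for you
print(jnp.allclose(manual, auto))                     # True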
jit if you can
In eager mode, jax transformations/operators run sequentially, one at a time.
With jit compilation, the jax program, more precisely the underlying computation graph, is optimized (e.g. operations are rearranged and fused) by the XLA compiler so that it runs faster.
The idea of jit compilation is to trace/compile the program in Python once and cache the compiled program for repeated evaluations. Because of the overhead of compilation, it is best if similar inputs do not trigger re-compilation. To this end, jax transformations must be agnostic of the values of the inputs, and they must know the shape and dtype of the inputs and outputs to comply with XLA's requirement of being compile-time static.
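A small sketch of the caching behaviour; the print below is a side effect that only runs while tracing, so it appears once per new input shape/dtype rather than on every call:
from jax import jit, numpy as jnp
@jit
def f(x):
    print("tracing for shape", x.shape)   # runs only during tracing
    return (x ** 2).sum()
f(jnp.ones(3))    # prints: tracing for shape (3,)
f(jnp.zeros(3))   # same shape/dtype: cached, no print, no re-compilation
f(jnp.ones(4))    # new shape: re-traced and re-compiled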
It may sound like one should jit everything, but this is not always possible. Consider this.
# NOT WORKING!
from jax import jit, numpy as jnp
@jit
def f(x):
    if x > 0: return x
    else: return jnp.stack([x, x])
try: f(3)
except Exception as e: print(e)
Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at <code block: n4>:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
This function is not jit'able. It errors out because the value of x must be known upfront to determine the shape of the output. Of course one can get around the if statement with jnp.where (which makes both branches explicit), as long as both branches have the same shape. Conveniently, ALL jax.numpy operations are jittable.
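For instance, a value-dependent branch whose two sides share a shape can be written with jnp.where; a minimal sketch (not the function above, whose branches have different shapes):
from jax import jit, numpy as jnp
@jit
def relu(x):
    return jnp.where(x > 0, x, 0.0)   # both branches evaluated, selection is element-wise
print(relu(jnp.array([-1.0, 2.0, 3.0])))   # [0. 2. 3.]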
But consider a python while loop whose number of iterations depends on an input:
# NOT WORKING!
from jax import jit
@jit
def g(x, n):
    i = 0
    while i < n:   # python while on a traced value
        i += 1
    return x + i
try: g(1, 5)
except Exception as e: print(e)
Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at <code block: n5>:4 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
This function is not jit-compatible either. XLA really does not like python while loops because they are fully dynamic: intermediate states could (they don't necessarily have to, but they could) change dtype and shape depending on the values, so XLA has no hope of analyzing them statically.
Two fixes are possible. One is to make the control flow static by marking the number of iterations as a static argument, i.e. treating n as a constant at compile time. Effectively the loop is unrolled during tracing.
from functools import partial
from jax import jit
@partial(jit, static_argnums=1)   # n is compile-time static; a new n triggers re-compilation
def g(x, n):
    i = 0
    while i < n: i += 1           # plain python loop, unrolled during tracing
    return x + i
print(g(1, 5))
Another is to use the (dynamic) structured control flow jax.lax.while_loop. The number of iterations is allowed to be dynamic (with static dtype and shape of course), but the structure of the condition and body functions is static.
from jax import jit
from jax.lax import while_loop
def cond_fun(val):
    i, n = val
    return i < n          # keep looping while i < n
def body_fun(val):
    i, n = val
    return i + 1, n       # carry keeps a fixed structure, dtype and shape
@jit
def g(x, n):
    end, _ = while_loop(cond_fun, body_fun, (0, n))
    return x + end
print(g(1, 5))
Details here. More control flow operators: check this page.
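For instance, jax.lax.fori_loop follows the same pattern; a minimal sketch (the function below is just an illustration):
from jax import jit
from jax.lax import fori_loop
@jit
def cumsum_upto(n):
    # fori_loop(lower, upper, body, init); body receives (index, carry)
    return fori_loop(0, n, lambda i, acc: acc + i, 0)
print(cumsum_upto(5))   # 0 + 1 + 2 + 3 + 4 = 10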
grad
grad! When you can compose grad and jacobian to get the Hessian, you know automatic differentiation in jax is done right.
Define your function and grad your way out with respect to any variable you are interested in. Compose it with jit for performance. Some obvious caveats:
- functions must be scalar-valued (there is jax.jacobian for vector-valued funcs)
- inputs must be continuous (e.g. float)
- functions must be differentiable (indexing, argmax etc. are not ok)
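A minimal sketch of composing the transformations (the function f below is just an illustration):
from jax import grad, jacobian, jit, numpy as jnp
def f(x): return jnp.sum(x ** 3)    # scalar-valued, float inputs, differentiable
hess = jit(jacobian(grad(f)))       # Hessian as the jacobian of the gradient
x = jnp.array([1.0, 2.0])
print(grad(f)(x))                   # 3 * x**2 -> [ 3. 12.]
print(hess(x))                      # diag(6 * x) -> [[6. 0.] [0. 12.]]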
Simple profiling/testing
Use .block_until_ready() on the output jax.Array of a function when measuring its execution time: jax dispatches work asynchronously and would otherwise return before the computation finishes.
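A minimal timing sketch (the function and the matrix size are arbitrary):
import time
from jax import jit, numpy as jnp
@jit
def f(x):
    return (x @ x.T).sum()
x = jnp.ones((1000, 1000))
f(x).block_until_ready()      # warm-up call, includes compilation
start = time.perf_counter()
f(x).block_until_ready()      # block so the clock stops only after the work is done
print(time.perf_counter() - start)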
There is no testing module in jax
. Use np.testing.assert_allclose
to check the results.
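For instance, a sketch checking that a jitted function matches its eager counterpart:
import numpy as np
from jax import jit, numpy as jnp
def f(x): return jnp.exp(x).sum()
x = jnp.linspace(0.0, 1.0, 5)
np.testing.assert_allclose(jit(f)(x), f(x), rtol=1e-6)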
Type hints
Use jax.typing.ArrayLike
for array input and jax.Array
for array output.
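A minimal sketch (the function scale is just an illustration):
import jax
from jax.typing import ArrayLike
import jax.numpy as jnp
def scale(x: ArrayLike, factor: float = 2.0) -> jax.Array:
    # ArrayLike accepts jax arrays, numpy arrays and python scalars
    return jnp.asarray(x) * factor
print(scale([1.0, 2.0]))   # [2. 4.]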
Reference
https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
https://jax.readthedocs.io/en/latest/working-with-pytrees.html#example-of-jax-tree-map-with-ml-model-parameters
https://jax.readthedocs.io/en/latest/working-with-pytrees.html#custom-pytree-nodes
https://jax.readthedocs.io/en/latest/stateful-computations.html#simple-worked-example-linear-regression