1. An Introduction to JAX#

JAX is an open-source library developed by Google Research, aimed at numerical computing and machine learning. It provides composable transformations of Python functions, enabling automatic differentiation, efficient execution on accelerators like GPUs and TPUs, and seamless interoperability with NumPy. With its functional programming model and powerful features, JAX offers a versatile platform for high-performance computing tasks.
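As a first taste of these transformations, here is a minimal sketch (illustrative only, not part of the lecture's example code) showing automatic differentiation with jax.grad and just-in-time compilation with jax.jit:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2)        # a simple scalar-valued function

grad_f = jax.grad(f)            # automatic differentiation
fast_f = jax.jit(f)             # just-in-time compilation via XLA

x = jnp.arange(3.0)
print(grad_f(x))                # gradient of sum(x**2) is 2*x
print(fast_f(x))                # same value as f(x), compiled on first call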

1.1. JAX as a NumPy Replacement#

One way to use JAX is as a drop-in NumPy replacement. Let’s look at the similarities and differences.

1.1.1. Similarities#

The following import is standard, replacing import numpy as np:

import jax
import jax.numpy as jnp

Now we can use jnp in place of np for the usual array operations:

a = jnp.asarray((1.0, 3.2, -1.5))
print(a)
[ 1.   3.2 -1.5]
print(jnp.sum(a))
2.7
print(jnp.mean(a))
0.90000004
print(jnp.dot(a, a))
13.490001

However, the array object a is not a NumPy array:

a
Array([ 1. ,  3.2, -1.5], dtype=float32)
type(a)
jaxlib.xla_extension.ArrayImpl

Even scalar-valued maps on arrays return JAX arrays.

jnp.sum(a)
Array(2.7, dtype=float32)

JAX arrays are also called “device arrays,” where the term “device” refers to a hardware accelerator (a GPU or TPU).

(In the terminology of GPUs, the “host” is the machine that launches GPU operations, while the “device” is the GPU itself.)
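To see which backend JAX is using, one option (a small sketch, not part of the original code) is to query JAX directly:

print(jax.devices())            # available devices, e.g. [CpuDevice(id=0)] on a laptop
print(jax.default_backend())    # 'cpu', 'gpu' or 'tpu'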

Operations on higher-dimensional arrays are also similar to NumPy:

A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B
Array([[1., 1.],
       [1., 1.]], dtype=float32)
from jax.numpy import linalg
linalg.inv(B)   # Inverse of identity is identity
Array([[1., 0.],
       [0., 1.]], dtype=float32)
linalg.eigh(B)  # Computes eigenvalues and eigenvectors
EighResult(eigenvalues=Array([1., 1.], dtype=float32), eigenvectors=Array([[1., 0.],
       [0., 1.]], dtype=float32))

1.1.2. Differences#

One difference between NumPy and JAX is that JAX currently uses 32-bit floats by default.

This is standard for GPU computing and can lead to significant speed gains with only a small loss of precision.

However, for some calculations precision matters. In these cases, 64-bit floats can be enforced via the command

jax.config.update("jax_enable_x64", True)

Let’s check this works:

jnp.ones(3)
Array([1., 1., 1.], dtype=float64)

A more significant difference from NumPy is that JAX arrays are treated as immutable.

For example, with NumPy we can write

import numpy as np
a = np.linspace(0, 1, 3)
a
array([0. , 0.5, 1. ])

and then mutate the data in memory:

a[0] = 1
a
array([1. , 0.5, 1. ])

In JAX this fails:

a = jnp.linspace(0, 1, 3)
a
Array([0. , 0.5, 1. ], dtype=float64)
a[0] = 1
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 1
----> 1 a[0] = 1

File /usr/share/miniconda3/envs/jt/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:278, in _unimplemented_setitem(self, i, x)
    273 def _unimplemented_setitem(self, i, x):
    274   msg = ("'{}' object does not support item assignment. JAX arrays are "
    275          "immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` "
    276          "or another .at[] method: "
    277          "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html")
--> 278   raise TypeError(msg.format(type(self)))

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In line with immutability, JAX does not support in-place operations:

a = np.array((2, 1))
a.sort()
a
array([1, 2])
a = jnp.array((2, 1))
a_new = a.sort()
a, a_new
(Array([2, 1], dtype=int64), Array([1, 2], dtype=int64))

The designers of JAX chose to make arrays immutable because JAX uses a functional programming style. More on this below.
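To give a flavor of that style, here is a small sketch (an illustration added here, not part of the lecture code): functions take arrays as inputs and return new arrays, rather than modifying their arguments.

def standardize(x):
    # return a new, standardized array; the input x is left unchanged
    return (x - jnp.mean(x)) / jnp.std(x)

a = jnp.array((1.0, 2.0, 3.0))
b = standardize(a)              # a still holds (1., 2., 3.)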

Note that, while in-place mutation is not supported, element updates are possible via the at property, which returns a new array with the update applied and leaves the original untouched:

a = jnp.linspace(0, 1, 3)
id(a)
74062384
a
Array([0. , 0.5, 1. ], dtype=float64)
a.at[0].set(1)
Array([1. , 0.5, 1. ], dtype=float64)

The original array a is not modified by this call; it is still the same object as before:

id(a)
74062384
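If you want the updated values to be available under the original name, the standard pattern is simply to rebind the name to the new array (a short sketch using the same at syntax as above):

a = jnp.linspace(0, 1, 3)
a = a.at[0].set(1)              # rebind a to the new, updated array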

1.2. Random Numbers#

Random number generation is also a bit different in JAX than in NumPy: typically, the state of the random number generator needs to be controlled explicitly.

import jax.random as random

First we produce a key, which seeds the random number generator.

key = random.PRNGKey(1)
type(key)
jaxlib.xla_extension.ArrayImpl
print(key)
[0 1]

Now we can use the key to generate some random numbers:

x = random.normal(key, (3, 3))
x
Array([[-1.35247421, -0.2712502 , -0.02920518],
       [ 0.34706456,  0.5464053 , -1.52325812],
       [ 0.41677264, -0.59710138, -0.5678208 ]], dtype=float64)

If we use the same key again, we initialize at the same seed, so the random numbers are the same:

random.normal(key, (3, 3))
Array([[-1.35247421, -0.2712502 , -0.02920518],
       [ 0.34706456,  0.5464053 , -1.52325812],
       [ 0.41677264, -0.59710138, -0.5678208 ]], dtype=float64)

To produce a (quasi-) independent draw, best practice is to “split” the existing key:

key, subkey = random.split(key)
random.normal(key, (3, 3))
Array([[ 1.85374374, -0.37683949, -0.61276867],
       [-1.91829718,  0.27219409,  0.54922246],
       [ 0.40451442, -0.58726839, -0.63967753]], dtype=float64)
random.normal(subkey, (3, 3))
Array([[-0.4300635 ,  0.22778552,  0.57241269],
       [-0.15969178,  0.46719192,  0.21165091],
       [ 0.84118631,  1.18671326, -0.16607783]], dtype=float64)

The function below produces k (quasi-) independent random n x n matrices using this procedure.

def gen_random_matrices(key, n, k):
    matrices = []
    for _ in range(k):
        # split off a fresh subkey for each draw, keeping the updated key
        key, subkey = random.split(key)
        matrices.append(random.uniform(subkey, (n, n)))
    return matrices
matrices = gen_random_matrices(key, 2, 2)
for A in matrices:
    print(A)
[[0.97440813 0.3838544 ]
 [0.9790686  0.99981046]]
[[0.3473302  0.17157842]
 [0.89346686 0.01403153]]
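As an aside, random.split can also produce several subkeys in one call. The variant below is a sketch of the same idea (the name gen_random_matrices_alt is ours, not from the lecture):

def gen_random_matrices_alt(key, n, k):
    subkeys = random.split(key, k)    # k subkeys from a single split
    return [random.uniform(subkey, (n, n)) for subkey in subkeys]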

One point to remember is that JAX expects tuples to describe array shapes, even for one-dimensional arrays. Hence, to get a flat array of normal random draws we pass a one-element tuple such as (5,) for the shape, as in

random.normal(key, (5, ))
Array([-0.64377279,  0.76961857, -0.29809604,  0.47858776, -2.00591299],      dtype=float64)

In conclusion, this introductory lecture has provided a glimpse into the capabilities and features of JAX. We’ve seen how JAX can serve as a near drop-in replacement for NumPy in numerical computing, as well as the key ways in which it differs: 32-bit floats by default, immutable arrays, and explicit control of random number generator state.
