JAX patterns#

Since NumPyro uses JAX [Bradbury et al., 2018] as a backend, it is important to know how to work with JAX efficiently. JAX is a high-performance numerical computing library designed for machine learning research. It relies on XLA (Accelerated Linear Algebra) to optimize and compile numerical computations, which lets it execute large-scale mathematical operations efficiently on hardware accelerators such as GPUs and TPUs.

Array assignments#

Even though jax.numpy (conventionally imported as jnp) is meant to replicate the functionality of numpy (np), and jnp does behave indistinguishably from np in many cases, there are a few differences. The reason is that, although jax.numpy provides nearly the same API as numpy, it uses a different (JAX) backend.

JAX arrays are immutable, so array indexing adapts NumPy’s syntax to fit immutable arrays. Direct assignments like array[index] = value are not supported and raise an error. Instead, JAX provides the .at property for updates, which allows modifications in a purely functional style. For example, setting the value of an element is done with

array = array.at[index].set(value),

and incrementing an element uses

array = array.at[index].add(value).

Each of these calls returns a new array with the desired change and leaves the original array unchanged. This functional behaviour is crucial for JAX’s function transformations, such as automatic differentiation and compilation.
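As a minimal sketch of the update pattern described above (the array values and variable names are illustrative assumptions, not taken from the text), the following shows that in-place assignment is rejected, while .at[...] returns a new array and leaves the original untouched:

```python
import jax.numpy as jnp

original = jnp.zeros(4)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    original[1] = 5.0
    in_place_failed = False
except TypeError:
    in_place_failed = True

# Functional updates return new arrays instead of modifying `original`.
after_set = original.at[1].set(5.0)   # element 1 is set to 5.0
after_add = after_set.at[1].add(2.5)  # element 1 is incremented by 2.5

print(in_place_failed)
print(original)   # still all zeros
print(after_set)
print(after_add)
```

Because every update produces a new array, the result must be bound to a name (often the same one, as in array = array.at[index].set(value)) for the change to be visible.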

vmap#

vmap in JAX is a vectorizing map that automatically transforms a function to apply it over batched inputs, effectively parallelizing computations across data. This tool helps simplify the process of extending single-data-point functions to operate on batches, improving efficiency and performance without manual adjustments to the code.

vmap returns a new function which, by default, maps the original function over the leading axis of its array arguments.

Let’s say you have a simple function that computes the square of a number. Using vmap, you can easily extend this function to square an entire array of numbers in one go.

import jax
import jax.numpy as jnp
from jax import vmap
from jax import lax 


def square(x):
    return x * x

array = jnp.array([jnp.repeat(1,5), jnp.repeat(2,5)])

squared_array = vmap(square)(array)

print(squared_array)
[[1 1 1 1 1]
 [4 4 4 4 4]]

We can use the in_axes option to specify along which axis to map the function. Note the difference from the axis argument of the built-in operations in numpy and jax.numpy:

# here `axis=k` says which dimension to collapse 
print(jnp.sum(array, axis=0))

print(jnp.sum(array, axis=1))
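[3 3 3 3 3]
[ 5 10]

def sum_array(x):
    return jnp.sum(x)

# here `in_axes=k` says which dimension to map over: the function is applied to each slice along it
# in_axes=0 (the default) maps over rows, so each row is summed
print(jax.vmap(sum_array, in_axes=0)(array))

# in_axes=1 maps over columns, so each column is summed
print(jax.vmap(sum_array, in_axes=1)(array))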
[ 5 10]
[3 3 3 3 3]

Consider a scenario where you have a function that calculates the Euclidean distance between two points, but you want to apply this function across pairs of points stored in two separate arrays. This function involves more complex operations, including subtraction and squaring, which makes vmap particularly beneficial for vectorising such computations efficiently over batches.

def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points_1 = jnp.array([[1, 2], [3, 4], [5, 6]])
points_2 = jnp.array([[6, 5], [4, 3], [2, 1]])

distances = vmap(euclidean_distance, in_axes=(0, 0))(points_1, points_2)

print(distances)
[5.8309517 1.4142135 5.8309517]

The in_axes=(0, 0) argument tells vmap to apply the function across the first dimension (rows) of both inputs.

Let’s look at another example, in which only some of the arguments are mapped.

def add_and_multiply_scalar(x, y, scalar):
    return (x + y) * scalar

def add_and_multiply_vector(x, y, scalar):
    return vmap(add_and_multiply_scalar, (0, 0, None))(x, y, scalar)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
scalar = jnp.array([2.0, 3.0, 4.0])

result = add_and_multiply_vector(x, y, scalar)
print(result)
[[10. 15. 20.]
 [14. 21. 28.]
 [18. 27. 36.]]

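Here (0, 0, None) tells vmap to map add_and_multiply_scalar over the first dimension of x and y, while None means that the third argument is not mapped: it is passed unchanged to every call. Note that, despite its name, scalar is a length-3 array in this example, so each call returns a length-3 vector and the overall result has shape (3, 3).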
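To make the role of None in in_axes more concrete, here is a small sketch that compares the vmap call with an explicit Python loop and with the case where all three arguments are mapped. The function name add_and_multiply and the variable factors are illustrative choices, not names from the example above.

```python
import jax.numpy as jnp
from jax import vmap


def add_and_multiply(x, y, factors):
    return (x + y) * factors


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
factors = jnp.array([2.0, 3.0, 4.0])

# Map over x and y; pass `factors` whole to every call.
batched = vmap(add_and_multiply, in_axes=(0, 0, None))(x, y, factors)

# Equivalent explicit loop over the leading axis of x and y.
looped = jnp.stack(
    [add_and_multiply(x[i], y[i], factors) for i in range(x.shape[0])]
)

# Map over all three arguments: element i of each input is combined.
paired = vmap(add_and_multiply, in_axes=(0, 0, 0))(x, y, factors)

print(batched.shape)                 # one row per element of x and y
print(jnp.allclose(batched, looped))
print(paired.shape)                  # one value per element
print(paired)
```

With in_axes=(0, 0, 0), each call receives a single element of factors, so the result is a vector rather than a matrix.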

vmap applies a function to the elements along the leading axis of an array independently. What if some quantities need to be carried over, so that each computation depends on the result of the previous one? Then we use lax.scan.

lax.scan#

lax.scan is a function used for looping over sequences while carrying state across iterations. Its inputs are, in order:

  • a function that takes the current state and an element of the input sequence and returns a tuple: the new state for the next iteration and the output for the current iteration,

  • an initial state, which is carried through each iteration of the loop,

  • an input sequence to iterate over: an array (or a pytree of arrays), which is scanned along its leading axis.

lax.scan then iterates over the sequence, applying the function at each step and updating the state accordingly. It collects the outputs from each iteration into a stacked array. It returns a tuple: first the final state after iterating over all elements of the input sequence, then the stacked per-iteration outputs.
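The loop described above can be sketched in plain Python. The helper scan_reference and the step function running_product below are hypothetical names introduced only for illustration; the sketch mirrors what lax.scan computes, not how it is compiled.

```python
import jax.numpy as jnp
from jax import lax


def scan_reference(f, init, xs):
    """Plain-Python illustration of the values lax.scan computes."""
    carry = init
    outputs = []
    for x in xs:                # iterate over the leading axis of xs
        carry, y = f(carry, x)  # update the state, emit one output
        outputs.append(y)
    return carry, jnp.stack(outputs)


def running_product(carry, x):
    new_carry = carry * x
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
init = jnp.array(1.0)

final_ref, outputs_ref = scan_reference(running_product, init, xs)
final_scan, outputs_scan = lax.scan(running_product, init, xs)

print(final_scan)    # final state
print(outputs_scan)  # stacked per-step outputs
print(jnp.allclose(outputs_ref, outputs_scan))
```

Unlike the Python loop, lax.scan traces the step function once and compiles the whole loop, which is why the state must keep the same shape and dtype across iterations.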

Example: cumulative sum#

def cumsum_fn(prev_sum, next_value):
    return prev_sum + next_value, prev_sum + next_value

sequence = jnp.array([1, 2, 3, 4, 5])

initial_sum = jnp.array(0)

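# lax.scan returns (final state, stacked outputs); the cumulative sum is the second element
final_sum, cumulative_sum = lax.scan(cumsum_fn, initial_sum, sequence)

print("Final sum:", final_sum)
print("Cumulative Sum:", cumulative_sum)
Final sum: 15
Cumulative Sum: [ 1  3  6 10 15]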

Example: moving average#

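def moving_average(state, next_value):
    # the state is (previous average, previous value); only the average is used here
    prev_avg, _ = state
    # exponentially weighted update: weight 2/3 on the previous average, 1/3 on the new value
    new_avg = (prev_avg * 2 + next_value) / 3
    return (new_avg, next_value), new_avg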

initial_state = (0.0, 0.0)

sensor_data = jnp.array([10.0, 12.0, 14.0, 16.0, 18.0])

final_state, moving_averages = lax.scan(moving_average, initial_state, sensor_data)

print("Moving averages:", moving_averages)
Moving averages: [ 3.3333335  6.222223   8.8148155 11.209877  13.473251 ]

Random keys with jax.random.PRNGKey(seed)#

We have already used JAX’s pseudorandom number generator in a few places. jax.random.PRNGKey(seed) creates a random number generator key from a given seed (an integer). This key is used for reproducible, functional-style random number generation. JAX’s random system is stateless, so you need to manage this key yourself and pass it explicitly to any function that uses randomness. You can use the key to generate random numbers or split it into sub-keys:

key = jax.random.PRNGKey(42)
print(key)
[ 0 42]
subkey1, subkey2 = jax.random.split(key)
print(subkey1, subkey2)
[2465931498 3679230171] [255383827 267815257]

Splitting into multiple keys can be done using the num argument:

keys = jax.random.split(key, num=5) 
print(keys)
[[2765691542 1385194879]
 [ 831049250 3807460095]
 [3616728933  824333390]
 [1482326074 1083977345]
 [2713995981 2812206153]]
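As a short sketch of how such keys are typically consumed (the shapes and variable names below are illustrative assumptions), the same key always yields the same draws, while split keys yield different ones. A batch of keys can also be combined with vmap to draw one sample per key.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Reusing a key reproduces exactly the same numbers.
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# Splitting gives fresh keys; keep one to continue with, use the others.
key, subkey1, subkey2 = jax.random.split(key, num=3)
c = jax.random.normal(subkey1, shape=(3,))
d = jax.random.normal(subkey2, shape=(3,))

# One key per batch element: map the sampler over the leading axis of `keys`.
keys = jax.random.split(key, num=5)
draws = jax.vmap(lambda k: jax.random.normal(k, shape=(2,)))(keys)

print(jnp.array_equal(a, b))  # identical draws from the same key
print(jnp.array_equal(c, d))  # different draws from different sub-keys
print(draws.shape)
```

A common convention is to never reuse a key after it has been passed to a sampler or to split, so that separate random draws stay statistically independent.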