JAX patterns#

Since NumPyro uses JAX as a backend, it is important to know how to work with JAX efficiently. JAX is a high-performance numerical computing library designed for machine learning research. It uses XLA (Accelerated Linear Algebra) to compile and optimize numerical computations, which lets JAX run large-scale mathematical operations efficiently on hardware accelerators such as GPUs and TPUs.

Array assignments#

Although jax.numpy (conventionally imported as jnp) is meant to replicate the functionality of numpy (np), and in many cases jnp behaves indistinguishably from np, there are a few differences. Array assignment is one of them: JAX arrays are immutable.

In JAX, array indexing adapts NumPy’s syntax to fit immutable arrays. Direct assignments like array[index] = value are not supported and raise an error. Instead, JAX arrays expose an .at property for updates, allowing modifications in a purely functional style. For example, setting an element is done with

array = array.at[index].set(value),

and incrementing an element uses

array = array.at[index].add(value).

Each of these calls returns a new array with the desired change and leaves the original array unchanged. This immutability keeps functions free of side effects, which JAX relies on for automatic differentiation and compilation.
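As a minimal sketch of the update pattern described above (the array names `a`, `b`, `c` and the values are illustrative only), the following shows that in-place assignment is rejected while `.at[...]` updates produce new arrays:

```python
import jax.numpy as jnp

a = jnp.zeros(5)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    a[0] = 1.0
    in_place_failed = False
except TypeError:
    in_place_failed = True

# Functional updates return a new array; `a` itself is left untouched.
b = a.at[0].set(1.0)  # b[0] is 1.0
c = b.at[0].add(2.0)  # c[0] is 3.0

print(in_place_failed)
print(a)
print(b)
print(c)
```

Because every update returns a fresh array, the result has to be bound to a name (often the same one, as in `array = array.at[index].set(value)`), otherwise the change is lost.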


vmap#

vmap in JAX is a vectorizing map: it transforms a function written for a single input into one that applies over batched inputs. This makes it easy to extend single-data-point functions to whole batches without rewriting them, and it is usually much faster than an explicit Python loop.

By default, the function returned by vmap is applied across the leading axis of an array.

Let’s say you have a simple function that computes the square of a number. Using vmap, you can easily extend this function to square an entire array of numbers in one go.

import jax
import jax.numpy as jnp
from jax import vmap
from jax import lax 

def square(x):
    return x * x

array = jnp.array([jnp.repeat(1,5), jnp.repeat(2,5)])

squared_array = vmap(square)(array)
print(squared_array)

[[1 1 1 1 1]
 [4 4 4 4 4]]

We can use the in_axes option to specify along which axis to apply the function. Note the difference from the axis option of the built-in operations in numpy and jax.numpy:

# here `axis=k` says which dimension to collapse
print(jnp.sum(array, axis=0))

print(jnp.sum(array, axis=1))

[3 3 3 3 3]
[ 5 10]

def sum_array(x):
    return jnp.sum(x)

# here `in_axes=k` says which dimension to map over (the default is 0);
# sum_array then sums each slice taken along that dimension
print(jax.vmap(sum_array, in_axes=0)(array))

print(jax.vmap(sum_array, in_axes=1)(array))

[ 5 10]
[3 3 3 3 3]
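One way to read in_axes is that vmap behaves like a Python loop over slices of the chosen axis, with the results stacked. The sketch below illustrates this under the assumption of the same 2×5 array as above (rebuilt here so the snippet is self-contained); the explicit loops are for illustration only and are not how vmap is implemented.

```python
import jax
import jax.numpy as jnp

def sum_array(x):
    return jnp.sum(x)

array = jnp.array([[1, 1, 1, 1, 1], [2, 2, 2, 2, 2]])

# in_axes=0: the function sees one row at a time.
rows_loop = jnp.stack([sum_array(array[i]) for i in range(array.shape[0])])
rows_vmap = jax.vmap(sum_array, in_axes=0)(array)

# in_axes=1: the function sees one column at a time.
cols_loop = jnp.stack([sum_array(array[:, j]) for j in range(array.shape[1])])
cols_vmap = jax.vmap(sum_array, in_axes=1)(array)

print(rows_loop, rows_vmap)
print(cols_loop, cols_vmap)
```

So `axis` in jnp.sum names the dimension that disappears, while `in_axes` in vmap names the dimension that is kept as the batch dimension.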

Consider a scenario where you have a function that calculates the Euclidean distance between two points, but you want to apply this function across pairs of points stored in two separate arrays. This function involves more complex operations, including subtraction and squaring, which makes vmap particularly beneficial for vectorizing such computations efficiently over batches.

def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points_1 = jnp.array([[1, 2], [3, 4], [5, 6]])
points_2 = jnp.array([[6, 5], [4, 3], [2, 1]])

distances = vmap(euclidean_distance, in_axes=(0, 0))(points_1, points_2)
print(distances)

[5.8309517 1.4142135 5.8309517]

The in_axes=(0, 0) argument tells vmap to apply the function across the first dimension (rows) of both inputs.
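The entries of in_axes do not all have to be 0. As a hedged sketch, the variation below measures the distance from every point to one fixed reference point by passing None for the second argument, which tells vmap not to map over it. The arrays `points` and `origin` are hypothetical values chosen for illustration.

```python
import jax.numpy as jnp
from jax import vmap

def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points = jnp.array([[3.0, 4.0], [6.0, 8.0], [0.0, 0.0]])
origin = jnp.array([0.0, 0.0])  # hypothetical reference point

# Map over the rows of `points`; pass `origin` unchanged to every call.
distances_to_origin = vmap(euclidean_distance, in_axes=(0, None))(points, origin)
print(distances_to_origin)  # distances 5, 10 and 0
```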

Let’s look at another example of a multidimensional input.

def add_and_multiply_scalar(x, y, scalar):
    return (x + y) * scalar

def add_and_multiply_vector(x, y, scalar):
    return vmap(add_and_multiply_scalar, (0, 0, None))(x, y, scalar)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
scalar = jnp.array([2.0, 3.0, 4.0])

result = add_and_multiply_vector(x, y, scalar)
print(result)

[[10. 15. 20.]
 [14. 21. 28.]
 [18. 27. 36.]]

Here (0, 0, None) tells vmap to map add_and_multiply_scalar over the first dimension of x and y, while None means the third argument is not mapped: the whole scalar array is passed unchanged to every call. Each call therefore returns a vector of three values, and the stacked result is a 3×3 matrix.
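To make the effect of None concrete, the sketch below compares the (0, 0, None) call with an explicit broadcasting expression and with a call that maps over all three arguments. The names `via_vmap`, `via_broadcast` and `paired` are introduced here for illustration.

```python
import jax.numpy as jnp
from jax import vmap

def add_and_multiply_scalar(x, y, scalar):
    return (x + y) * scalar

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
scalar = jnp.array([2.0, 3.0, 4.0])

# Unmapped third argument: every call receives all of `scalar`.
via_vmap = vmap(add_and_multiply_scalar, in_axes=(0, 0, None))(x, y, scalar)

# The same result written with explicit broadcasting.
via_broadcast = (x + y)[:, None] * scalar[None, :]

# Mapping over all three arguments pairs the elements up instead.
paired = vmap(add_and_multiply_scalar, in_axes=(0, 0, 0))(x, y, scalar)

print(via_vmap.shape, paired.shape)  # (3, 3) (3,)
```

With in_axes=(0, 0, 0) the i-th sum is multiplied only by the i-th entry of scalar, so the result is a vector rather than a matrix.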

vmap applies a function to the elements of the leading axis of an array independently. What if some quantity has to be carried over from one computation to the next? Then we use lax.scan.


lax.scan#

lax.scan is a function used for looping over a sequence while carrying state across iterations. Its inputs are as follows:

  • an input sequence that you want to iterate over: an array (or a pytree of arrays), which is scanned along its leading axis,

  • an initial state which is carried through each iteration of the loop,

  • a function that takes the current state and an element from the input sequence and returns a tuple: the new state for the next iteration and the output for the current iteration.

lax.scan then iterates over the sequence, applying the function at each step and updating the state accordingly. It collects the outputs from each iteration into a stacked array. Called as lax.scan(f, init, xs), it returns a tuple whose first element is the final state and whose second element is the sequence of outputs.
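The behaviour described above can be sketched as an ordinary Python loop. The helper `scan_reference` and the step function `step` below are hypothetical names written for this illustration; the loop mirrors the semantics of lax.scan for a single array input, not its compiled implementation.

```python
import jax.numpy as jnp
from jax import lax

def scan_reference(f, init, xs):
    """Pure-Python sketch of what lax.scan computes (no compilation)."""
    carry = init
    ys = []
    for x in xs:  # iterate along the leading axis
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

def step(carry, x):
    new_carry = carry + x
    # (new state, output for this step)
    return new_carry, new_carry * 2

xs = jnp.array([1.0, 2.0, 3.0])
init = jnp.float32(0.0)

ref_carry, ref_ys = scan_reference(step, init, xs)
scan_carry, scan_ys = lax.scan(step, init, xs)

print(scan_carry, scan_ys)
```

Both versions return the final state first and the stacked outputs second; the difference is that lax.scan traces `step` once and compiles the loop, whereas the Python loop runs it step by step.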

Example: cumulative sum#

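def cumsum_fn(prev_sum, next_value):
    new_sum = prev_sum + next_value
    # return (new carry, output for this step)
    return new_sum, new_sum

sequence = jnp.array([1, 2, 3, 4, 5])

initial_sum = jnp.array(0)

# lax.scan returns (final carry, stacked per-step outputs)
final_sum, cumulative_sum = lax.scan(cumsum_fn, initial_sum, sequence)

print("Final sum:", final_sum)
print("Cumulative Sum:", cumulative_sum)

Final sum: 15
Cumulative Sum: [ 1  3  6 10 15]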

Example: moving average#

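def moving_average(state, next_value):
    # state = (previous average, previous raw value); only the average is used
    prev_avg, _ = state
    # exponentially weighted update: 2/3 of the old average, 1/3 of the new value
    new_avg = (prev_avg * 2 + next_value) / 3
    return (new_avg, next_value), new_avg

initial_state = (0.0, 0.0)

sensor_data = jnp.array([10.0, 12.0, 14.0, 16.0, 18.0])

# first result is the final state, second is the per-step averages
final_state, moving_averages = lax.scan(moving_average, initial_state, sensor_data)

print("Moving averages:", moving_averages)

Moving averages: [ 3.3333335  6.222223   8.8148155 11.209877  13.473251 ]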