JAX patterns
Since NumPyro uses JAX [Bradbury et al., 2018] as its backend, it is important to know how to work with JAX efficiently. JAX is a high-performance numerical computing library designed for machine learning research. It uses XLA (Accelerated Linear Algebra) to compile and optimise numerical computations, which lets it execute large-scale array operations efficiently on hardware accelerators such as GPUs and TPUs.
Array assignments
Even though `jax.numpy` (usually imported as `jnp`) is meant to replicate the functionality of `numpy` (imported as `np`), and in many cases behaves indistinguishably from it, there are a few differences. The reason is that, although `jax.numpy` provides nearly the same API as `numpy`, it runs on a different (JAX) backend.
JAX arrays are immutable, so JAX adapts NumPy’s indexing syntax to fit immutable arrays. Direct assignments like `array[index] = value` are not supported. Instead, JAX provides the `.at` property for updates in a purely functional style. For example, setting the value of an element is done with `array = array.at[index].set(value)`, and incrementing an element uses `array = array.at[index].add(value)`.
Each of these calls returns a new array with the desired change and leaves the original array unchanged. Immutability lets JAX treat functions as pure, which its transformations such as `jax.grad`, `jax.jit` and `vmap` rely on; inside `jax.jit`-compiled code, XLA can often perform such updates in place when the original array is no longer needed.
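A minimal sketch of the `.at` update pattern (the array `x` and its values are only for illustration). It shows that plain item assignment raises a `TypeError`, and that every `.at[...]` call returns a new array while the input stays unchanged.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 1.0
except TypeError as err:
    print("TypeError:", err)

# Each .at[...] update returns a new array and leaves its input untouched.
y = x.at[0].set(10.0)       # y[0] = 10
y = y.at[1:3].add(2.0)      # y[1] += 2 and y[2] += 2
y = y.at[0].multiply(0.5)   # y[0] *= 0.5

print(x)  # [0. 0. 0. 0. 0.]
print(y)  # [5. 2. 2. 0. 0.]
```

Besides `set`, `add` and `multiply`, `.at[...]` also offers updates such as `min`, `max` and `get`, and the brackets accept the usual NumPy index forms (integers, slices, index arrays).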
vmap
`vmap` in JAX is a vectorizing map: it transforms a function written for a single data point into one that operates on a batch of inputs, vectorizing the computation across the data without manual changes to the function’s code.
`vmap` returns a new function which, by default, applies the original function across the leading axis of an array.
Let’s say you have a simple function that computes the square of a number. Using `vmap`, you can easily extend this function to square an entire array of numbers in one go.
```python
import jax
import jax.numpy as jnp
from jax import vmap
from jax import lax

def square(x):
    return x * x

array = jnp.array([jnp.repeat(1, 5), jnp.repeat(2, 5)])
squared_array = vmap(square)(array)
print(squared_array)
```

```
[[1 1 1 1 1]
 [4 4 4 4 4]]
```
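`vmap` changes how a function is applied, not what it computes. The sketch below (`value_range` is a hypothetical helper, not part of the original) checks that mapping over the leading axis gives the same result as a Python loop over the rows followed by stacking. Unlike `square`, `value_range` is not element-wise, so the axis being mapped actually matters.

```python
import jax.numpy as jnp
from jax import vmap

def value_range(x):
    # Written for a single 1-D vector: largest entry minus smallest entry.
    return jnp.max(x) - jnp.min(x)

batch = jnp.array([[1.0, 4.0, 2.0],
                   [7.0, 3.0, 5.0]])

# vmap maps value_range over the leading axis (the rows) of batch...
mapped = vmap(value_range)(batch)

# ...which matches an explicit loop over the rows, stacked afterwards.
looped = jnp.stack([value_range(row) for row in batch])

print(mapped)                         # [3. 4.]
print(jnp.allclose(mapped, looped))   # True
```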
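We can use the `in_axes` option to specify which axis `vmap` maps over. Note the difference from the `axis` argument of built-in reductions in `numpy` and `jax.numpy`: `axis=k` names the dimension that is collapsed, whereas `in_axes=k` names the dimension that is iterated over, so each call of the function receives a slice with that dimension removed.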
```python
# here `axis=k` says which dimension to collapse
print(jnp.sum(array, axis=0))
print(jnp.sum(array, axis=1))
```

```
[3 3 3 3 3]
[ 5 10]
```
```python
def sum_array(x):
    return jnp.sum(x)

# in_axes=0 (the default) maps over the rows, so each row is summed
print(jax.vmap(sum_array, in_axes=0)(array))
# in_axes=1 maps over the columns, so each column is summed
print(jax.vmap(sum_array, in_axes=1)(array))
```

```
[ 5 10]
[3 3 3 3 3]
```
Consider a scenario where you have a function that calculates the Euclidean distance between two points, and you want to apply it to pairs of points stored in two separate arrays. The function is written for a single pair of points; `vmap` turns it into a batched version without any change to its body.
```python
def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points_1 = jnp.array([[1, 2], [3, 4], [5, 6]])
points_2 = jnp.array([[6, 5], [4, 3], [2, 1]])
distances = vmap(euclidean_distance, in_axes=(0, 0))(points_1, points_2)
print(distances)
```

```
[5.8309517 1.4142135 5.8309517]
```
The `in_axes=(0, 0)` argument tells `vmap` to apply the function across the first dimension (rows) of both inputs.
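Changing `in_axes` turns the same function into a pairwise-distance computation. A sketch with two nested `vmap` calls (`one_vs_all` and `pairwise` are illustrative names): `None` means "do not map this argument", so the inner `vmap` compares one point with every row of `points_2`, and the outer `vmap` repeats that for every row of `points_1`. Entry `[i, j]` is the distance between `points_1[i]` and `points_2[j]`, so the diagonal reproduces the row-by-row distances computed above.

```python
import jax.numpy as jnp
from jax import vmap

def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points_1 = jnp.array([[1, 2], [3, 4], [5, 6]])
points_2 = jnp.array([[6, 5], [4, 3], [2, 1]])

# Inner map: one point x against every row of the second array -> shape (3,)
one_vs_all = vmap(euclidean_distance, in_axes=(None, 0))
# Outer map: repeat for every row of points_1 -> shape (3, 3)
pairwise = vmap(one_vs_all, in_axes=(0, None))(points_1, points_2)

print(pairwise.shape)  # (3, 3)
```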
Let’s look at another example, in which not every input is mapped.
```python
def add_and_multiply_scalar(x, y, scalar):
    return (x + y) * scalar

def add_and_multiply_vector(x, y, scalar):
    return vmap(add_and_multiply_scalar, (0, 0, None))(x, y, scalar)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
scalar = jnp.array([2.0, 3.0, 4.0])
result = add_and_multiply_vector(x, y, scalar)
print(result)
```

```
[[10. 15. 20.]
 [14. 21. 28.]
 [18. 27. 36.]]
```
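Here `in_axes=(0, 0, None)` tells `vmap` to map `add_and_multiply_scalar` over the first dimension of `x` and `y`, while `None` means that `scalar` is not mapped: every call receives the whole array `[2., 3., 4.]`. The i-th call therefore computes `(x[i] + y[i]) * scalar`, a vector of length 3, and stacking the three results gives the 3×3 matrix above.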
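To make the role of `None` concrete, the sketch below compares the `vmap` result with explicit broadcasting (a column of `x + y` times a row of `scalar`), and with the variant where `scalar` is mapped as well (`in_axes=(0, 0, 0)`), which yields one value per position instead of a 3×3 matrix. The names `broadcast` and `diagonal` are only illustrative.

```python
import jax.numpy as jnp
from jax import vmap

def add_and_multiply_scalar(x, y, scalar):
    return (x + y) * scalar

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
scalar = jnp.array([2.0, 3.0, 4.0])

# scalar is not mapped: each call sees the whole array -> shape (3, 3)
result = vmap(add_and_multiply_scalar, in_axes=(0, 0, None))(x, y, scalar)
# Same numbers via broadcasting: (3, 1) column times (1, 3) row
broadcast = (x + y)[:, None] * scalar[None, :]

# scalar mapped too: call i uses scalar[i] -> shape (3,)
diagonal = vmap(add_and_multiply_scalar, in_axes=(0, 0, 0))(x, y, scalar)

print(jnp.allclose(result, broadcast))  # True
print(diagonal)                         # [10. 21. 36.]
```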
`vmap` applies a function independently to each element along the leading axis of an array. What if some quantity has to be carried over, so that each computation depends on the result of the previous one? Then we use `lax.scan`.
lax.scan
`lax.scan` is a function for looping over a sequence while carrying state across iterations. It is called as `lax.scan(f, init, xs)`, where:
- `f` is a function that takes the current state (the carry) and one element of the input sequence, and returns a tuple: the new carry for the next iteration and the output for the current iteration;
- `init` is the initial carry, which is passed through each iteration of the loop;
- `xs` is the input sequence to iterate over: an array, or a pytree of arrays such as a tuple of arrays, scanned along its leading axis.

`lax.scan` then iterates over the sequence, applying `f` at each step and updating the carry accordingly; the carry must keep the same shape and dtype at every step. The per-step outputs are stacked into an array, and `lax.scan` returns the pair `(final_carry, stacked_outputs)`.
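The JAX documentation describes the semantics of `lax.scan` through an equivalent Python loop. `python_scan` below is a hypothetical helper following that description for a single array `xs` (no pytrees, `length` or `reverse` handling), and `running_max` is an illustrative step function. The loop is what `lax.scan` computes; the difference is that `lax.scan` traces `f` once and compiles the whole loop, which is why the carry's shape and dtype must not change between steps.

```python
import jax.numpy as jnp
from jax import lax

def python_scan(f, init, xs):
    """Plain-Python reference loop with lax.scan semantics for one array xs."""
    carry = init
    ys = []
    for x in xs:                  # iterate over the leading axis of xs
        carry, y = f(carry, x)    # f returns (new carry, output for this step)
        ys.append(y)
    return carry, jnp.stack(ys)   # (final carry, stacked outputs)

def running_max(carry, x):
    new_carry = jnp.maximum(carry, x)
    return new_carry, new_carry

xs = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])
init = xs[0]

final_ref, ys_ref = python_scan(running_max, init, xs)
final_jax, ys_jax = lax.scan(running_max, init, xs)

print(final_jax, ys_jax)  # 9 [3 3 4 4 5 9 9 9]
```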
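Example: cumulative sum
```python
def cumsum_fn(prev_sum, next_value):
    new_sum = prev_sum + next_value
    # (new carry, output for this step)
    return new_sum, new_sum

sequence = jnp.array([1, 2, 3, 4, 5])
initial_sum = jnp.array(0)
# lax.scan returns (final carry, stacked per-step outputs)
total, cumulative_sum = lax.scan(cumsum_fn, initial_sum, sequence)
print("Total:", total)
print("Cumulative sum:", cumulative_sum)
```

```
Total: 15
Cumulative sum: [ 1  3  6 10 15]
```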
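As a sanity check, the stacked per-step outputs of this scan should coincide with NumPy-style `jnp.cumsum`, and the final carry should equal the last output. A small self-contained sketch (`running` and `total` are illustrative names):

```python
import jax.numpy as jnp
from jax import lax

def cumsum_fn(prev_sum, next_value):
    new_sum = prev_sum + next_value
    return new_sum, new_sum  # (carry for the next step, output for this step)

sequence = jnp.array([1, 2, 3, 4, 5])
total, running = lax.scan(cumsum_fn, jnp.array(0), sequence)

# The per-step outputs are exactly what jnp.cumsum computes directly.
print(jnp.array_equal(running, jnp.cumsum(sequence)))  # True
# The final carry is the last per-step output.
print(total == running[-1])                             # True
```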
Example: moving average
```python
def moving_average(state, next_value):
    # Exponentially weighted average: 2/3 of the previous average plus 1/3 of the new value.
    # The second element of the state (the previous value) is carried but not used.
    prev_avg, _ = state
    new_avg = (prev_avg * 2 + next_value) / 3
    return (new_avg, next_value), new_avg

initial_state = (0.0, 0.0)
sensor_data = jnp.array([10.0, 12.0, 14.0, 16.0, 18.0])
final_state, moving_averages = lax.scan(moving_average, initial_state, sensor_data)
print("Moving averages:", moving_averages)
```

```
Moving averages: [ 3.3333335 6.222223 8.8148155 11.209877 13.473251 ]
```
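Since `moving_average` never reads the second element of its state, the carry can be reduced to the running average alone. A minimal sketch, assuming the same weights as above (1/3 on the newest reading); `ALPHA` and `ema_step` are hypothetical names introduced here. The outputs should agree with the values printed above up to float32 rounding.

```python
import jax.numpy as jnp
from jax import lax

ALPHA = 1.0 / 3.0  # weight on the newest value, matching (2 * prev + x) / 3

def ema_step(prev_avg, next_value):
    # Carry only the running average; it is also the per-step output.
    new_avg = (1.0 - ALPHA) * prev_avg + ALPHA * next_value
    return new_avg, new_avg

sensor_data = jnp.array([10.0, 12.0, 14.0, 16.0, 18.0])
final_avg, averages = lax.scan(ema_step, jnp.array(0.0), sensor_data)
print(averages)
```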
Random keys with jax.random.PRNGKey(seed)
We have already used JAX’s pseudorandom number generator in a few places. `jax.random.PRNGKey(seed)` creates a random number generator key from a given integer seed. This key is used for reproducible, functional-style random number generation. JAX’s random system is stateless, so you need to manage the key yourself and pass it explicitly to any function that uses randomness. You can use the key to generate random numbers or split it into subkeys. Recent JAX versions also provide `jax.random.key(seed)`, which returns a typed key; `PRNGKey` returns the raw `uint32` array printed below.
```python
key = jax.random.PRNGKey(42)
print(key)
```

```
[ 0 42]
```

```python
subkey1, subkey2 = jax.random.split(key)
print(subkey1, subkey2)
```

```
[2465931498 3679230171] [255383827 267815257]
```
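Keys are consumed by the sampling functions in `jax.random`. A small sketch, using `jax.random.normal` as an example sampler: the same key always produces the same numbers, while a different subkey produces different ones. This is why a key should not be reused for two draws that are meant to be independent.

```python
import jax

key = jax.random.PRNGKey(42)
subkey1, subkey2 = jax.random.split(key)

# The same key always yields the same numbers...
a = jax.random.normal(subkey1, shape=(3,))
b = jax.random.normal(subkey1, shape=(3,))
# ...while a different subkey yields different ones.
c = jax.random.normal(subkey2, shape=(3,))

print(bool((a == b).all()))  # True
print(bool((a == c).all()))  # False
```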
Splitting into multiple keys can be done using the `num` argument:
```python
keys = jax.random.split(key, num=5)
print(keys)
```

```
[[2765691542 1385194879]
 [ 831049250 3807460095]
 [3616728933  824333390]
 [1482326074 1083977345]
 [2713995981 2812206153]]
```
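The stacked keys returned by `split(key, num=...)` combine naturally with `vmap`: mapping a sampling function over the leading axis of `keys` draws one sample per key. A sketch under the assumption of the default PRNG, where each key is a row of two `uint32` values; `draw_one` is an illustrative helper. Because everything is derived from the seed, repeating the computation reproduces the samples exactly.

```python
import jax

key = jax.random.PRNGKey(42)
keys = jax.random.split(key, num=5)   # shape (5, 2) with the default PRNG

def draw_one(k):
    # One standard-normal sample from a single key.
    return jax.random.normal(k)

samples = jax.vmap(draw_one)(keys)    # maps over the leading axis of keys
print(samples.shape)                  # (5,)

# Same seed, same keys, same samples.
again = jax.vmap(draw_one)(jax.random.split(jax.random.PRNGKey(42), num=5))
print(bool((samples == again).all()))  # True
```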