# JAX patterns

Since NumPyro uses JAX as a backend, it is important to know how to work with JAX efficiently. JAX is a high-performance numerical computing library designed for machine learning research; it leverages XLA (Accelerated Linear Algebra) to optimize and compile numerical computations. This combination enables JAX to execute large-scale mathematical operations efficiently on hardware accelerators, boosting performance and scalability.

## Array assignments

Even though `jax.numpy` (conventionally imported as `jnp`) is meant to replicate the functionality of `numpy` (`np`), and `jnp` indeed behaves indistinguishably from `np` in many cases, there are a few differences. Array assignment is one of them: JAX arrays are immutable.

In JAX, array indexing adapts NumPy's syntax to fit immutable arrays. Direct assignments like `array[index] = value` are not supported. Instead, JAX uses the `.at` method for updates, allowing modifications in a purely functional style. For example, setting an element is done with `array = array.at[index].set(value)`, and incrementing an element uses `array = array.at[index].add(value)`. This method returns a new array with the desired change and leaves the original array unchanged, which is crucial for JAX's efficiency in automatic differentiation and optimization.
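A minimal sketch of the `.at` syntax; note that `.at` also accepts slices, not just single indices:

```python
import jax.numpy as jnp

arr = jnp.zeros(5)          # [0. 0. 0. 0. 0.]
arr = arr.at[0].set(1.0)    # returns a NEW array: [1. 0. 0. 0. 0.]
arr = arr.at[0].add(2.0)    # in-place-looking increment: [3. 0. 0. 0. 0.]
arr = arr.at[1:3].set(7.0)  # slices work too: [3. 7. 7. 0. 0.]
print(arr)
```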

## `vmap`

`vmap` in JAX is a vectorizing map that automatically transforms a function to apply it over batched inputs, effectively parallelizing computations across data. This simplifies extending single-data-point functions to operate on batches, improving efficiency and performance without manual adjustments to the code. `vmap` returns a function that is applied across the leading axis of an array.

Let's say you have a simple function that computes the square of a number. Using `vmap`, you can easily extend this function to square an entire array of numbers in one go.

```
import jax
import jax.numpy as jnp
from jax import vmap
from jax import lax

def square(x):
    return x * x

array = jnp.array([jnp.repeat(1, 5), jnp.repeat(2, 5)])
squared_array = vmap(square)(array)
print(squared_array)
```

```
[[1 1 1 1 1]
[4 4 4 4 4]]
```

We can use the option `in_axes` to specify along which axis to apply the function. Note the difference from the `axis` option of the built-in reductions in `numpy` and `jax.numpy`:
```
# here `axis=k` says which dimension to collapse
print(jnp.sum(array, axis=0))
print(jnp.sum(array, axis=1))
```

```
[3 3 3 3 3]
[ 5 10]
```

```
def sum_array(x):
    return jnp.sum(x)

# in_axes=0 (the default) maps over the first axis: each call sums one row
print(jax.vmap(sum_array, in_axes=0)(array))
# in_axes=1 maps over the second axis: each call sums one column
print(jax.vmap(sum_array, in_axes=1)(array))
```

```
[ 5 10]
[3 3 3 3 3]
```

Consider a scenario where you have a function that calculates the Euclidean distance between two points, but you want to apply it across pairs of points stored in two separate arrays. This function involves more complex operations, including subtraction and squaring, which makes `vmap` particularly beneficial for vectorizing such computations efficiently over batches.

```
def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points_1 = jnp.array([[1, 2], [3, 4], [5, 6]])
points_2 = jnp.array([[6, 5], [4, 3], [2, 1]])
distances = vmap(euclidean_distance, in_axes=(0, 0))(points_1, points_2)
print(distances)
```

```
[5.8309517 1.4142135 5.8309517]
```

The `in_axes=(0, 0)` argument tells `vmap` to apply the function across the first dimension (rows) of both inputs.
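`None` in `in_axes` marks an argument that should not be batched at all. As a sketch, we can reuse the same distance function to measure every point against one fixed reference point (the `origin` name here is just for illustration):

```python
import jax.numpy as jnp
from jax import vmap

def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
origin = jnp.array([0.0, 0.0])

# Map over the rows of `points` only; `origin` is passed unchanged to every call.
dists = vmap(euclidean_distance, in_axes=(0, None))(points, origin)
print(dists)
```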

Let's look at another example with a multidimensional input.

```
def add_and_multiply_scalar(x, y, scalar):
    return (x + y) * scalar

def add_and_multiply_vector(x, y, scalar):
    return vmap(add_and_multiply_scalar, (0, 0, None))(x, y, scalar)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
scalar = jnp.array([2.0, 3.0, 4.0])
result = add_and_multiply_vector(x, y, scalar)
print(result)
```

```
[[10. 15. 20.]
[14. 21. 28.]
[18. 27. 36.]]
```

Here `(0, 0, None)` tells `vmap` to apply the function `add_and_multiply_scalar` element-wise across the first dimension of `x` and `y`, while keeping `scalar` fixed for each element-wise operation.
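Note that in this example `scalar` is itself an array, so `(x + y) * scalar` broadcasts each element of `x + y` against the whole array, producing a matrix. With a genuine scalar under `None`, the output stays one-dimensional; a quick sketch:

```python
import jax.numpy as jnp
from jax import vmap

def add_and_multiply_scalar(x, y, scalar):
    return (x + y) * scalar

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# A 0-dimensional scalar under `None` yields a 1-D result: (x + y) * 2
result = vmap(add_and_multiply_scalar, in_axes=(0, 0, None))(x, y, 2.0)
print(result)  # [10. 14. 18.]
```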

`vmap` applies a function to the elements along the leading axis of an array independently. What if there are carried-over quantities that each computation needs from the last one? Then we use `lax.scan`.

## `lax.scan`

`lax.scan` is a function for looping over sequences while carrying state across iterations. Its inputs are as follows:

- a function that takes the current state and an element from the input sequence and returns a tuple: the new state for the next iteration and the output for the current iteration,
- an initial state, which is carried through each iteration of the loop,
- an input sequence that you want to iterate over: a list, tuple, or array.

`lax.scan` then iterates over the sequence, applying the function at each step and updating the state accordingly. It collects the outputs from each iteration into a sequence, and returns both the final state after iterating over all elements and the stacked sequence of outputs.
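These semantics can be sketched as an equivalent plain-Python loop (a simplified model: real `lax.scan` also handles pytrees and compiles the loop):

```python
def scan_sketch(f, init, xs):
    """Reference semantics of lax.scan, simplified to plain Python lists."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # f returns (new state, per-step output)
        ys.append(y)
    return carry, ys  # (final state, collected outputs)

# Running sum: the carry is the total so far; we also emit it at every step.
final_carry, ys = scan_sketch(lambda c, x: (c + x, c + x), 0, [1, 2, 3])
print(final_carry, ys)  # 6 [1, 3, 6]
```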

## Example: cumulative sum

```
def cumsum_fn(prev_sum, next_value):
    new_sum = prev_sum + next_value
    return new_sum, new_sum  # (carry, per-step output)

sequence = jnp.array([1, 2, 3, 4, 5])
initial_sum = jnp.array(0)
# lax.scan returns (final carry, stacked outputs); here the final carry is the total
cumulative_sum, _ = lax.scan(cumsum_fn, initial_sum, sequence)
print("Cumulative Sum:", cumulative_sum)
```

```
Cumulative Sum: 15
```
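The discarded second return value holds the running sums collected at each step, which can be cross-checked against `jnp.cumsum`:

```python
import jax.numpy as jnp
from jax import lax

def cumsum_fn(prev_sum, next_value):
    new_sum = prev_sum + next_value
    return new_sum, new_sum  # (carry, per-step output)

sequence = jnp.array([1, 2, 3, 4, 5])
total, running = lax.scan(cumsum_fn, jnp.array(0), sequence)
print(total)                 # 15
print(running)               # [ 1  3  6 10 15]
print(jnp.cumsum(sequence))  # same running sums
```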

## Example: moving average

```
def moving_average(state, next_value):
    prev_avg, _ = state
    # weight the previous average 2:1 against the new reading
    new_avg = (prev_avg * 2 + next_value) / 3
    return (new_avg, next_value), new_avg

initial_state = (0.0, 0.0)
sensor_data = jnp.array([10.0, 12.0, 14.0, 16.0, 18.0])
final_state, moving_averages = lax.scan(moving_average, initial_state, sensor_data)
print("Moving averages:", moving_averages)
```

```
Moving averages: [ 3.3333335 6.222223 8.8148155 11.209877 13.473251 ]
```
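Finally, the two transformations compose: wrapping the scan in a function and `vmap`-ing it smooths a whole batch of traces at once. A sketch under the assumption that each row of `batch` is an independent sensor trace (the `smooth` helper is hypothetical, introduced only for this example):

```python
import jax.numpy as jnp
from jax import lax, vmap

def moving_average(state, next_value):
    prev_avg, _ = state
    new_avg = (prev_avg * 2 + next_value) / 3
    return (new_avg, next_value), new_avg

def smooth(trace):
    # scan over one trace; keep only the per-step averages
    _, avgs = lax.scan(moving_average, (0.0, 0.0), trace)
    return avgs

batch = jnp.array([[10.0, 12.0, 14.0],
                   [20.0, 24.0, 28.0]])
smoothed = vmap(smooth)(batch)  # one scan per row, vectorized
print(smoothed)
```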