Agent-based models (ABMs)#

Agent-Based Models (ABMs), or “individual-based models”, are used to simulate the behaviors and interactions of individuals, or “agents”. In epidemiology, ABMs are particularly useful for understanding the dynamics of infectious disease transmission, as they can capture heterogeneity in individual behavior and characteristics, such as mobility, susceptibility and other individual traits. Unlike compartmental models, which rely on aggregate populations, ABMs represent each individual as a discrete agent whose state evolves based on predefined rules and interactions with others. These models enable us to explore how micro-level processes, such as social contact patterns or intervention strategies, influence macro-level outcomes, such as epidemic spread or herd immunity. By incorporating randomness at the individual level, ABMs provide a flexible and granular approach to studying complex epidemiological phenomena and testing public health interventions.

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

import numpy as np
/opt/hostedtoolcache/Python/3.11.10/x64/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

SI-ABM: simulation#

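Let us construct a forward simulation with two possible states - susceptible (S) and infected (I). The initial infections (initial_infections) are assigned at random: each agent is independently infected at the start with probability initial_infected, i.e. its initial state is drawn from a Bernoulli distribution. At every subsequent step, each susceptible agent becomes infected with probability p_infect, and infected agents stay infected.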

def abm_SI(num_agents, num_steps, p_infect, initial_infected, key=None):
    """Simulate a non-spatial susceptible-infected (SI) agent-based model.

    Each agent is 0 (susceptible) or 1 (infected). At the start every agent is
    independently infected with probability ``initial_infected``. At each of the
    ``num_steps`` steps every still-susceptible agent independently becomes
    infected with probability ``p_infect``; infected agents stay infected.

    Args:
        num_agents: number of agents.
        num_steps: number of simulation steps.
        p_infect: per-step infection probability of a susceptible agent.
        initial_infected: probability that an agent is infected at t = 0.
        key: optional JAX PRNG key; defaults to ``jax.random.PRNGKey(0)`` so
            repeated calls with the same parameters give the same result.

    Returns:
        ``(final_states, S_list, I_list)``: the final 0/1 state vector and two
        lists of ``num_steps + 1`` counts (the initial state plus one entry
        after every step). The counts are only needed for visualisation.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    # One independent key for the initial draw and one per step. Reusing the
    # initial key for step 0 would recycle the very same uniforms that decided
    # the initial infections (bernoulli(key, p) == uniform(key) < p).
    init_key, *step_keys = jax.random.split(key, num_steps + 1)

    states = jnp.zeros(num_agents)
    states = states + jax.random.bernoulli(init_key, initial_infected, (num_agents,))

    S_list = []
    I_list = []
    for step_key in step_keys:
        S_list.append(jnp.sum(1 - states))  # susceptibles before this step
        I_list.append(jnp.sum(states))      # infected before this step

        random_vals = jax.random.uniform(step_key, (num_agents,))
        # only susceptible agents (1 - states == 1) can become newly infected
        new_infections = (random_vals < p_infect) * (1 - states)
        states = states + new_infections

    S_list.append(jnp.sum(1 - states))  # final counts
    I_list.append(jnp.sum(states))

    return jnp.clip(states, 0, 1), S_list, I_list
# ground-truth parameters used to generate the synthetic data
num_agents = 100
num_steps = 10
true_p_infect = 0.1
true_initial_infected = 0.05

# forward-simulate the SI-ABM with the true parameters
final_states, S_counts, I_counts = abm_SI(num_agents, num_steps, true_p_infect, true_initial_infected)

# S and I trajectories (num_steps + 1 points, including t = 0)
steps_axis = range(num_steps + 1)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(steps_axis, S_counts, label="Susceptible (S)", marker='o')
ax.plot(steps_axis, I_counts, label="Infected (I)", marker='x')
ax.set_xlabel("Time (steps)")
ax.set_ylabel("Number of Agents")
ax.set_title("S and I over Time in ABM S-I Model")
ax.legend()
ax.grid(True)
plt.show()
_images/b9bcfdc2d17da52e1903a0a32b8ff6f2d0e0eb856c37d4e474da8e4a36b4d56e.png

Group Task

Experiment with the following inputs in the above model:

  • increase and decrease the number of time steps num_steps. How does that change the result?

  • increase and decrease the infection probability true_p_infect. How does that change the result?

  • increase and decrease the number of agents num_agents. How does that affect the result?

SI-ABM: inference#

There is a lot of flexibility that ABM simulations allow. But can we perform inference for them as easily as we did for ODE-based models? Yes! Thanks to automatic differentiation.

def numpyro_abm_SI(num_agents, num_steps, data=None):
    """NumPyro model for the SI-ABM observed through its final agent states.

    Places Beta(2, 6) and Beta(2, 10) priors on ``p_infect`` and
    ``initial_infected``, runs ``abm_SI`` forward with the sampled values and,
    when ``data`` is given, scores it with a per-agent Bernoulli likelihood
    whose probabilities are the simulated final 0/1 states.

    NOTE(review): the simulated states come from random threshold comparisons
    inside ``abm_SI``, so they change in discrete jumps as the parameters vary.
    The likelihood is therefore piecewise constant, and its gradient is
    (almost everywhere) zero. Check NUTS diagnostics before trusting the
    posterior.
    """
    p_infect = numpyro.sample("p_infect", dist.Beta(2, 6))
    initial_infected = numpyro.sample("initial_infected", dist.Beta(2, 10))

    # only the final states enter the likelihood; S/I trajectories are ignored
    simulated_states = abm_SI(num_agents, num_steps, p_infect, initial_infected)[0]

    if data is None:
        return
    numpyro.sample("obs", dist.Bernoulli(simulated_states), obs=data)
# inference: NUTS over (p_infect, initial_infected), conditioned on the simulated final states
nuts_kernel = NUTS(numpyro_abm_SI)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), num_agents=num_agents, num_steps=num_steps, data=final_states)

# Draws from the priors for comparison (same Beta(2, 6) / Beta(2, 10) as in numpyro_abm_SI).
# Each parameter gets its own key: drawing both from one key would make the samples dependent.
prior_key_p, prior_key_init = jax.random.split(jax.random.PRNGKey(0))
p_infect_prior = jax.random.beta(prior_key_p, 2, 6, (1000,))
initial_infected_prior = jax.random.beta(prior_key_init, 2, 10, (1000,))

samples = mcmc.get_samples()
p_infect_samps = samples['p_infect']
initial_infected_samps = samples['initial_infected']

# one panel per parameter: prior vs posterior density, with the true value as a dashed line
panels = [
    ("p_infect", p_infect_prior, p_infect_samps, true_p_infect),
    ("initial_infected", initial_infected_prior, initial_infected_samps, true_initial_infected),
]

plt.figure(figsize=(8, 3))
for panel_idx, (param_name, prior_draws, posterior_draws, true_value) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_idx)
    sns.kdeplot(prior_draws, fill=True, label="prior")
    sns.kdeplot(posterior_draws, fill=True, label="posterior")
    plt.axvline(x=true_value, color='red', linestyle='--', label=f'True {param_name}')
    plt.xlabel(param_name)
    plt.ylabel("density")
    plt.title(f"Posterior distribution of {param_name}")
    plt.legend()

plt.tight_layout()
plt.show()
_images/fd5351e4316b94e145d8ffb7a079350df551b04c1f0c6cda46287228ed68e4d8.png
num_posterior_draws = len(p_infect_samps)

# Re-run the simulator once per posterior draw and keep only the (S_t, I_t)
# trajectories; the final agent states returned first are discarded.
trajectories = [
    abm_SI(num_agents, num_steps, p_draw, init_draw)[1:]
    for p_draw, init_draw in zip(p_infect_samps, initial_infected_samps)
]

# stack into (num_posterior_draws, num_steps + 1) arrays
S_t_draws = jnp.array([S_traj for S_traj, _ in trajectories])
I_t_draws = jnp.array([I_traj for _, I_traj in trajectories])


def _band(draws):
    """Mean and 5th/95th percentiles across posterior draws (axis 0)."""
    return (
        draws.mean(axis=0),
        jnp.percentile(draws, 5, axis=0),
        jnp.percentile(draws, 95, axis=0),
    )


S_t_mean, S_t_lower, S_t_upper = _band(S_t_draws)
I_t_mean, I_t_lower, I_t_upper = _band(I_t_draws)

# mean curve plus shaded 90% band for each compartment
time_steps = range(num_steps + 1)
plt.figure(figsize=(12, 6))
for mean_curve, lower, upper, colour, mean_label, band_label in (
    (S_t_mean, S_t_lower, S_t_upper, "purple", "Mean $S_t$", "90% CI $S_t$"),
    (I_t_mean, I_t_lower, I_t_upper, "orangered", "Mean $I_t$", "90% CI $I_t$"),
):
    plt.plot(time_steps, mean_curve, label=mean_label, color=colour)
    plt.fill_between(time_steps, lower, upper, color=colour, alpha=0.3, label=band_label)

plt.xlabel("time steps")
plt.ylabel("Number of agents")
plt.title("Mean and 90% CI of $S_t$ and $I_t$ of an SI ABM model")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
_images/eacf335f3c7baa6dff02a0cd0e0d9a44af23090c580968da03855fe613fbf255.png

Task

Repeat the procedure shown above, i.e. 1) simulation, 2) inference, 3) visualisation of trajectories, for an SIR (susceptible–infected–recovered) ABM.

Spatial SI-ABM on a lattice#

To extend the model to a spatial one, on a lattice, we represent agents as nodes on a 2D grid (lattice). Each agent interacts only with its neighbors, and the infection spreads spatially based on proximity.

def abm_spatial_SI(grid_size, num_steps, p_infect, initial_infected, neighbors=None, key=None):
    """Simulate a spatial SI agent-based model on a square lattice (naive loop version).

    Every cell of a ``grid_size x grid_size`` grid holds one agent: 0 means
    susceptible and 1 means infected. At t = 0 each agent is independently
    infected with probability ``initial_infected``. At each step, a susceptible
    agent with at least one infected in-bounds neighbour becomes infected with
    probability ``p_infect``. There is one draw per cell per step, no matter
    how many neighbours are infected. Updates are synchronous: every cell reads
    the previous step's grid.

    Args:
        grid_size: side length of the square lattice.
        num_steps: number of simulation steps.
        p_infect: infection probability of an exposed susceptible agent.
        initial_infected: probability that an agent is infected at t = 0.
        neighbors: optional iterable of ``(dx, dy)`` offsets. The default is
            the 4 diagonal offsets. Pass all 8 offsets for a Moore neighbourhood.
        key: optional JAX PRNG key; defaults to ``jax.random.PRNGKey(0)``.

    Returns:
        ``(final_states, S_list, I_list)``. ``final_states`` is the final grid;
        the two lists hold ``num_steps + 1`` counts (initial state plus one
        entry after every step).
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    if neighbors is None:
        # Diagonal-only neighbourhood: each move changes i + j by -2, 0 or +2,
        # so infection stays on the same parity ("colour") of the lattice as
        # the initially infected cells.
        neighbors = [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    offsets = [(int(dx), int(dy)) for dx, dy in neighbors]

    # Separate keys for initialisation and dynamics; every (step, cell) pair then
    # gets its own key via fold_in, so draws never collide across cells or steps.
    init_key, run_key = jax.random.split(key)
    states = jnp.zeros((grid_size, grid_size))
    states = states + jax.random.bernoulli(init_key, initial_infected, (grid_size, grid_size))

    S_list = []
    I_list = []

    for step in range(num_steps):
        S_list.append(jnp.sum(1 - states))  # susceptibles before this step
        I_list.append(jnp.sum(states))      # infected before this step

        step_key = jax.random.fold_in(run_key, step)
        new_states = states  # JAX arrays are immutable; .at[].set returns a copy
        for i in range(grid_size):
            for j in range(grid_size):
                if states[i, j] == 1:  # already infected
                    continue
                exposed = any(
                    0 <= i + dx < grid_size
                    and 0 <= j + dy < grid_size
                    and states[i + dx, j + dy] == 1
                    for dx, dy in offsets
                )
                if not exposed:
                    continue
                cell_key = jax.random.fold_in(step_key, i * grid_size + j)
                if jax.random.uniform(cell_key) < p_infect:
                    new_states = new_states.at[i, j].set(1)
        states = new_states

    S_list.append(jnp.sum(1 - states))  # final counts
    I_list.append(jnp.sum(states))

    return states, S_list, I_list

# parameters
grid_size = 20
num_steps = 10
true_p_infect = 0.2
true_initial_infected = 0.1

# run forward spatial ABM S-I model
final_states, S_counts, I_counts = abm_spatial_SI(grid_size, num_steps, true_p_infect, true_initial_infected)

# Colours indexed by agent state: colors[0] = susceptible (0), colors[1] = infected (1).
# ListedColormap maps the lowest value to the first colour, so the same list keeps the
# grid image and the line plot consistent.
colors = ["lightblue", "darkorange"]

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# final spatial state; pin vmin/vmax so an all-S or all-I grid still gets the right colour
cmap = ListedColormap(colors)
im = axes[0].imshow(final_states, cmap=cmap, vmin=0, vmax=1, interpolation="nearest")
axes[0].set_title("Final spatial state of the grid")

# S_t and I_t over time (one point per recorded count, including t = 0)
time_axis = range(len(S_counts))
axes[1].plot(time_axis, S_counts, label="Susceptible $S_t$", marker='o', color=colors[0])
axes[1].plot(time_axis, I_counts, label="Infected $I_t$", marker='x', color=colors[1])
axes[1].set_xlabel("time steps")
axes[1].set_ylabel("number of agents")
axes[1].set_title("$S_t$ and $I_t$ in lattice SI-ABM model")
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()
_images/8b0e69df172148ea3f8afb16f95a9329b47ae00aa4b44e6a1db9c5999a69c3cf.png

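We have a model implemented in JAX now. But would it work, as is, for HMC inference with NumPyro? Not really! It relies on Python control flow over array values (for example `if states[i, j] == 1` and `if jax.random.uniform(...) < p_infect`). That breaks when NumPyro compiles the model with `jax.jit` and the parameters become abstract traced values. Its nested Python loops over every cell also make each simulation very slow.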

Let us build a spatial SI-ABM model which would preserve the same logic, but be HMC-friendly.

def abm_spatial_SI(grid_size, num_steps, p_infect, initial_infected, neighbors=None, key=None):
    """
    A JAX-compatible ABM spatial S-I model.

    Same rule as the loop version, written with array operations so it can be
    traced by JAX (``lax.scan``) inside a NumPyro model. Every cell of a
    ``grid_size x grid_size`` grid is 0 (susceptible) or 1 (infected). At t = 0
    each cell is infected with probability ``initial_infected``. At each step, a
    susceptible cell with at least one infected in-bounds neighbour becomes
    infected if its single uniform draw for that step is below ``p_infect``.
    Updates are synchronous.

    Args:
        grid_size: side length of the square lattice (a Python int).
        num_steps: number of simulation steps (a Python int).
        p_infect: infection probability of an exposed susceptible cell.
        initial_infected: probability that a cell is infected at t = 0.
        neighbors: optional iterable of integer ``(dx, dy)`` offsets. The
            default is the 4 diagonal offsets.
        key: optional JAX PRNG key; defaults to ``jax.random.PRNGKey(0)``.

    Returns:
        ``(final_states, S_list, I_list)``: the final grid and two arrays of
        ``num_steps + 1`` counts (initial state plus one entry after every step).

    NOTE(review): infections are decided by comparisons (``draws < p_infect``,
    Bernoulli draws), so the outputs are piecewise constant in the parameters.
    Gradients are zero almost everywhere, so check NUTS diagnostics.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    if neighbors is None:
        # diagonal-only neighbourhood (infection stays on one parity of i + j)
        neighbors = ((-1, -1), (-1, 1), (1, -1), (1, 1))
    offsets = [(int(dx), int(dy)) for dx, dy in neighbors]
    # pad wide enough that every shifted window stays inside the padded grid
    pad = max((max(abs(dx), abs(dy)) for dx, dy in offsets), default=0)

    # independent keys for the initial draw and for the dynamics
    init_key, run_key = jax.random.split(key)
    states = jnp.zeros((grid_size, grid_size))
    states = states + jax.random.bernoulli(init_key, initial_infected, (grid_size, grid_size))

    def exposed_mask(grid):
        """True where at least one in-bounds neighbour is infected."""
        # zero padding == susceptible, so off-grid positions never infect anyone
        padded = jnp.pad(grid, pad)
        exposed = jnp.zeros(grid.shape, dtype=bool)
        for dx, dy in offsets:
            window = padded[pad + dx:pad + dx + grid_size, pad + dy:pad + dy + grid_size]
            exposed = exposed | (window == 1)
        return exposed

    def step_fn(grid, step_key):
        draws = jax.random.uniform(step_key, grid.shape)  # one draw per cell per step
        newly_infected = exposed_mask(grid) & (draws < p_infect)
        next_grid = jnp.where(newly_infected, 1.0, grid)  # infected cells stay infected
        return next_grid, (jnp.sum(1 - next_grid), jnp.sum(next_grid))

    step_keys = jax.random.split(run_key, num_steps)
    final_states, (S_steps, I_steps) = jax.lax.scan(step_fn, states, step_keys)

    # prepend the t = 0 counts so the series match the loop version (num_steps + 1 entries)
    S_list = jnp.concatenate([jnp.sum(1 - states)[None], S_steps])
    I_list = jnp.concatenate([jnp.sum(states)[None], I_steps])

    return final_states, S_list, I_list


# Numpyro model for inference
def numpyro_spatial_abm(grid_size, num_steps, data=None):
    """NumPyro model for the lattice SI-ABM, observed through its infected-count series.

    Samples ``p_infect`` ~ Beta(2, 5) and ``initial_infected`` ~ Beta(1, 10),
    simulates with ``abm_spatial_SI`` and, when ``data`` is given, scores the
    observed infected counts with a Normal likelihood (scale 5.0) centred on
    the simulated counts.
    """
    p_infect = numpyro.sample("p_infect", dist.Beta(2, 5))
    initial_infected = numpyro.sample("initial_infected", dist.Beta(1, 10))

    # third output of the simulator = infected counts over time
    infected_counts = abm_spatial_SI(grid_size, num_steps, p_infect, initial_infected)[2]

    if data is None:
        return
    numpyro.sample("obs", dist.Normal(infected_counts, 5.0), obs=data)

# ground-truth parameters for the synthetic lattice data
grid_size, num_steps = 10, 10
true_p_infect, true_initial_infected = 0.2, 0.1

# simulate the observed trajectories; only the infected counts enter the likelihood
_, S_counts, I_counts = abm_spatial_SI(grid_size, num_steps, true_p_infect, true_initial_infected)

# NUTS inference conditioned on the observed infected-count series
nuts_kernel = NUTS(numpyro_spatial_abm)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, progress_bar=False)
mcmc.run(
    jax.random.PRNGKey(0),
    grid_size=grid_size,
    num_steps=num_steps,
    data=I_counts,
)

posterior_samples = mcmc.get_samples()

# one KDE panel per parameter, with the true value marked by a dashed red line
panel_specs = (
    ("p_infect", true_p_infect, "skyblue", "True p_infect", "Posterior Distribution of p_infect"),
    ("initial_infected", true_initial_infected, "lightgreen", "True initial_infected",
     "Posterior distribution of initial_infected"),
)

plt.figure(figsize=(10, 4))
for position, (param, truth, shade, truth_label, heading) in enumerate(panel_specs, start=1):
    plt.subplot(1, 2, position)
    sns.kdeplot(posterior_samples[param], fill=True, color=shade)
    plt.axvline(x=truth, color='red', linestyle='--', label=truth_label)
    plt.xlabel(param)
    plt.title(heading)
    plt.legend()

plt.tight_layout()
plt.show()
_images/c08324cc60e6aedad48829662b1f3127baa3620510908ad3951f6712ebce595e.png