Gaussian processes: inference#

In the previous chapter we saw how to generate Gaussian process priors. What we usually want to do, however, is inference: given observed data, we would like to estimate model parameters and, potentially, make predictions at unobserved locations.

Marginalization and Conditioning#

Assume we have two sets of coordinates \(x_n = (x_n^1, x_n^2, ..., x_n^n), x_m = (x_m^1, x_m^2, ..., x_m^m)\) and a GP over them:

\[\begin{split} \begin{bmatrix} f_n \\ f_m \end{bmatrix} \sim \mathcal{N}\left( \underbrace{\begin{bmatrix} \mu_n \\ \mu_m \end{bmatrix}}_{\mu}, \underbrace{\begin{bmatrix} K_{n n} \quad K_{n m} \\ K_{m n} \quad K_{m m} \end{bmatrix}}_{K} \right). \end{split}\]

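Marginalization allows us to extract partial information from a multivariate probability distribution. For a Gaussian, the marginal of a subset of variables is obtained by keeping the corresponding entries of \(\mu\) and the corresponding block of \(K\):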

\[\begin{split} \begin{align*} f_n & \sim \mathcal{N}(\mu_n, K_{nn}),\\ f_m & \sim \mathcal{N}(\mu_m, K_{mm}). \end{align*} \end{split}\]
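As a quick numerical check of the marginalization property, the following sketch builds a small joint Gaussian, reads the marginal of the first block directly off \(\mu\) and \(K\), and compares it with the empirical moments of samples. The numbers in `mu` and `A` are arbitrary values chosen for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

# joint Gaussian over (f_n, f_m) with n = 2 and m = 1 (illustrative values)
mu = np.array([0.5, -1.0, 2.0])
A = np.array([[1.0, 0.2, 0.0],
              [0.3, 1.5, 0.4],
              [0.1, 0.0, 0.8]])
K = A @ A.T  # symmetric positive definite by construction

n = 2
# marginal of f_n: keep the matching entries of mu and the matching block of K
mu_n, K_nn = mu[:n], K[:n, :n]

# empirical check: sample the joint, discard f_m, compare moments
samples = rng.multivariate_normal(mu, K, size=200_000)
f_n = samples[:, :n]
emp_mean = f_n.mean(axis=0)
emp_cov = np.cov(f_n, rowvar=False)

print(np.round(emp_mean - mu_n, 3))  # differences are Monte Carlo error only
print(np.round(emp_cov - K_nn, 3))
```

No integration is needed: discarding the unwanted coordinates of each sample is the sampling counterpart of dropping rows and columns of \(\mu\) and \(K\).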

Conditioning allows us to determine the probability of one subset of variables given another subset. Similar to marginalization, this operation is also closed and yields a modified Gaussian distribution:

\[ f_n | f_m \sim \mathcal{N}(\mu_{n|m}, K_{n|m}), \]

where

\[\begin{split} \begin{align*} \mu_{n|m} & = \mu_n + K_{nm} K_{mm}^{-1} (f_m - \mu_m),\\ K_{n|m} & = K_{nn} - K_{nm} K_{mm}^{-1} K_{mn}. \end{align*} \end{split}\]
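The conditioning formulas translate almost line by line into code. The sketch below is one possible NumPy version for blocks of arbitrary size; the function name `condition_gaussian` and the example numbers are assumptions made for illustration. It uses a linear solve in place of an explicit inverse of \(K_{mm}\).

```python
import numpy as np

def condition_gaussian(mu, K, idx_n, idx_m, f_m):
    """Mean and covariance of f_n | f_m for a joint Gaussian N(mu, K)."""
    mu_n, mu_m = mu[idx_n], mu[idx_m]
    K_nn = K[np.ix_(idx_n, idx_n)]
    K_nm = K[np.ix_(idx_n, idx_m)]
    K_mm = K[np.ix_(idx_m, idx_m)]
    # W = K_nm K_mm^{-1}, obtained by solving K_mm X = K_mn (K_mm is symmetric)
    W = np.linalg.solve(K_mm, K_nm.T).T
    mu_cond = mu_n + W @ (f_m - mu_m)
    K_cond = K_nn - W @ K_nm.T
    return mu_cond, K_cond

# illustrative 3-dimensional example: condition the first two components on the third
mu = np.array([0.0, 1.0, -1.0])
K = np.array([[2.0, 0.6, 0.3],
              [0.6, 1.0, 0.2],
              [0.3, 0.2, 0.5]])
mu_c, K_c = condition_gaussian(mu, K, idx_n=[0, 1], idx_m=[2], f_m=np.array([0.0]))
print(mu_c)
print(K_c)
```

Note that the conditional covariance does not depend on the observed value `f_m`, only on which components were observed.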

Group Task

Write down these formulas for the \(d=2\) case, i.e. when both \(f_n\) and \(f_m\) have only one component each.

Let us visualise such a conditional.

# imports for this chapter
import numpy as np

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro.infer import init_to_median, Predictive, MCMC, NUTS
from numpyro.diagnostics import hpdi

import pickle

numpyro.set_host_device_count(4)  # Set the device count to enable parallel sampling
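def Gaussian_conditional(mean, cov, x=None, y=None):
  """
  conditional of a 2D Gaussian given exactly one of its two components

  pass x to condition on the first component (returns the conditional of the second),
  or y to condition on the second component (returns the conditional of the first)

  returns the conditional mean and the conditional standard deviation
  """
  assert (x is None) != (y is None), "provide exactly one of x and y"
  if x is not None:
    var = cov[1,1] -  cov[1,0] * cov[0,0] ** (-1) * cov[0,1]
    mu = mean[1] + cov[1,0] * cov[0,0] ** (-1) * (x - mean[0])
  else:
    var = cov[0,0] -  cov[0,1] * cov[1,1] ** (-1) * cov[1,0]
    mu = mean[0] + cov[0,1] * cov[1,1] ** (-1) * (y - mean[1])
  return mu, var**0.5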
# parameters for the 2D Gaussian distribution
mu1 = 0.7 
mu2 = 1.6 
sigma1 = 1.8 
sigma2 = 1.7 
rho = 0.529 
y1 = -0.4 
y2 = 2.2 

# generate data points from the 2D Gaussian distribution
mu = jnp.array([mu1, mu2])
K = jnp.array([[sigma1**2, rho*sigma1*sigma2],[rho*sigma2*sigma1, sigma2**2]])
num_samples = 1000
data = np.random.multivariate_normal(mu, K, num_samples)

# calculate marginal distributions
y_values = jnp.linspace(-8, 8, 300)

normal1 = dist.Normal(loc=mu1, scale=sigma1)
normal2 = dist.Normal(loc=mu2, scale=sigma2)
density_x1 = jnp.exp(normal1.log_prob(y_values))
density_x2 = jnp.exp(normal2.log_prob(y_values))

# compute conditionals
cond_mu_x1, cond_sigma_x1 = Gaussian_conditional(mu, K, x=None, y=y2)
cond_mu_x2, cond_sigma_x2 = Gaussian_conditional(mu, K, x=y1, y=None)
cond_density_x1 = jnp.exp(dist.Normal(loc=cond_mu_x1, scale=cond_sigma_x1).log_prob(y_values))
cond_density_x2 = jnp.exp(dist.Normal(loc=cond_mu_x2, scale=cond_sigma_x2).log_prob(y_values))


fig = plt.figure(figsize=(10, 10))
gs = fig.add_gridspec(3, 3)

# main plot (2D Gaussian distribution)
ax_main = fig.add_subplot(gs[1:3, :2])
ax_main.scatter(data[:, 0], data[:, 1], alpha=0.5, label='2d Gaussian')
ax_main.set_xlabel('$x_1$')
ax_main.set_ylabel('$x_2$')
ax_main.axvline(y1, lw=2, c='C2', linestyle = '--')
ax_main.axhline(y2, lw=2, c='C1', linestyle = '--')
ax_main.legend()
ax_main.grid(True)

# marginal x1 plot
ax_marginal_x = fig.add_subplot(gs[0, :2], sharex=ax_main)
ax_marginal_x.plot(y_values, density_x1, label='Marginal $x_1$')
ax_marginal_x.plot(y_values, cond_density_x1, label='Conditional $x_1$', linestyle = '--', c='C1')
ax_marginal_x.legend()
ax_marginal_x.grid(True)

# marginal x2 plot
ax_marginal_y = fig.add_subplot(gs[1:3, 2], sharey=ax_main)
ax_marginal_y.plot(density_x2, y_values, label='Marginal $x_2$' )
ax_marginal_y.plot(cond_density_x2, y_values, label='Conditional $x_2$', linestyle='--', c='C2')
ax_marginal_y.legend()
ax_marginal_y.grid(True)

plt.tight_layout()
plt.show()
_images/1045a1f1b8ef8aafa0169ecdd51e52a608c4fc93812d1d245f9ff68a97f05f20.png

Inference in GLMMs#

In a typical setting, GP enters GLMMs as a latent variable and the model has the form

\[\begin{split} \begin{align*} f(x) &\sim \mathcal{GP}(\mu(x), k(x, x')), \\ y | f &\sim \prod_i p(y_i \vert f(x_i)). \end{align*} \end{split}\]

Here \(\{(x_i, y_i)\}_{i=1}^n\) are pairs of observation locations \(x_i\) and observed values \(y_i\). The role of \(f(x)\) is to serve as a latent field capturing dependencies between locations \(x\). The expression \(\prod_i p(y_i \vert f(x_i))\) is the likelihood: it links the observed data to the model and enables parameter inference.
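To make the two-level structure concrete, here is a minimal simulation sketch of such a model with a binary outcome. The Bernoulli likelihood with a logistic link, the helper `rbf` and all parameter values are assumptions chosen for illustration; any other likelihood \(p(y_i \vert f(x_i))\) would slot into the last step in the same way.

```python
import numpy as np

rng = np.random.default_rng(1)

def rbf(x1, x2, lengthscale=0.2, sigma=1.0):
    # squared-exponential kernel for 1-D inputs
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return sigma**2 * np.exp(-0.5 * d2 / lengthscale**2)

x = np.linspace(0.0, 1.0, 40)
K = rbf(x, x) + 1e-6 * np.eye(len(x))  # jitter for numerical stability

# latent field: one draw f ~ N(0, K), generated through the Cholesky factor of K
L = np.linalg.cholesky(K)
f = L @ rng.standard_normal(len(x))

# likelihood: y_i | f ~ Bernoulli(sigmoid(f(x_i))), independent given f
p = 1.0 / (1.0 + np.exp(-f))
y = rng.binomial(1, p)

print(y)
```

The observations `y` are conditionally independent given `f`; all dependence between locations enters through the kernel matrix `K`.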

The task that we usually want to solve is twofold: we want to infer the parameters involved in the model, as well as make predictions at unobserved locations \(x_*.\)

Gaussian process regression#

The simplest case of the setting described above is Gaussian process regression, where the outcome variable is modelled as a GP with added noise \(\epsilon\). It assumes that the data consists of pairs \(\{(x_i, y_i)\}_{i=1}^n\) and the likelihood is Gaussian with variance \(\sigma^2_\epsilon:\)

\[\begin{split} \begin{align*} f(x) &\sim \mathcal{GP}(\mu(x), k(x, x')),\\ y_i &= f(x_i) + \epsilon_i, \\ \epsilon_i &\sim \mathcal{N}(0, \sigma^2_\epsilon). \end{align*} \end{split}\]

We want to obtain predictions of \(f\) at \(m\) new locations \(x_m\), given observations at the \(n\) locations \(x_n\). Recall the conditioning formula, written here for a zero mean: if

\[\begin{split} \begin{bmatrix} f_n \\ f_m \end{bmatrix} \sim \mathcal{N}\left(\begin{bmatrix} 0 \\ 0\end{bmatrix}, \begin{bmatrix} K_{n n} \quad K_{n m} \\ K_{m n} \quad K_{m m} \end{bmatrix} \right), \end{split}\]

then

\[\begin{split} \begin{align*} f_m | f_n &\sim \mathcal{N}(\mu_{m|n}, K_{m|n}),\\ \mu_{m|n} & = K_{mn} K_{nn}^{-1} f_n,\\ K_{m|n} & = K_{mm} - K_{mn} K_{nn}^{-1} K_{nm}. \end{align*} \end{split}\]

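There is an issue with this formula: because of the noise we cannot observe \(f_n\) directly, and we observe \(y_n\) instead. Since \(y_n = f_n + \epsilon\) with independent noise, \(y_n\) has covariance \(K_{nn} + \sigma^2_\epsilon I\), while its cross-covariance with \(f_m\) is still \(K_{nm}\). The conditional in this case reads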

\[\begin{split} \begin{align*} f_m | y_n &\sim \mathcal{N}(\mu_{m|n}, K_{m|n}),\\ \mu_{m|n} & = K_{mn} (K_{nn} + \sigma^2_\epsilon I )^{-1} y_n,\\ K_{m|n} & = K_{mm} - K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} K_{nm}. \end{align*} \end{split}\]

This is the predictive distribution of the Gaussian process.

Putting the two distributions side-by-side:

  • posterior distribution in the noise-free case

\[ f_m | f_n \sim \mathcal{N}( K_{mn} K_{nn}^{-1} f_n, K_{mm} - K_{mn} K_{nn} ^{-1} K_{nm}) \]
  • posterior distribution in the noisy case

\[ f_m | y_n \sim \mathcal{N}(K_{mn} (K_{nn} + I\sigma^2_\epsilon )^{-1} y_n, K_{mm} - K_{mn} (K_{nn} + I\sigma^2_\epsilon )^{-1} K_{nm}) \]

The parameters \(\theta\) of the model (the kernel parameters and the noise variance) can be learnt using the marginal log-likelihood, which in the noisy case takes the form

\[ \log p(y|\theta)= -\frac{n}{2} \log(2\pi) - \frac{1}{2}\log \vert K_{nn} + \sigma^2_\epsilon I \vert - \frac{1}{2}y^T(K_{nn} + \sigma^2_\epsilon I)^{-1}y \]
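The three terms of the marginal log-likelihood can be computed from a single Cholesky factorisation of \(K_{nn} + \sigma^2_\epsilon I\). The sketch below is one way to do this in NumPy and then compares a few candidate lengthscales on toy data. The helper names, the toy data and the grid are assumptions made for illustration; in practice one would typically optimise or sample all parameters jointly.

```python
import numpy as np

def rbf(x1, x2, lengthscale, sigma):
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return sigma**2 * np.exp(-0.5 * d2 / lengthscale**2)

def log_marginal_likelihood(x, y, lengthscale, sigma, noise_sd):
    """log p(y | theta) for a zero-mean GP with Gaussian noise."""
    n = len(x)
    Ky = rbf(x, x, lengthscale, sigma) + noise_sd**2 * np.eye(n)
    L = np.linalg.cholesky(Ky)                            # Ky = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))   # Ky^{-1} y
    log_det = 2.0 * np.sum(np.log(np.diag(L)))            # log |Ky|
    return -0.5 * n * np.log(2 * np.pi) - 0.5 * log_det - 0.5 * y @ alpha

# toy data (illustrative): a noisy sine curve
rng = np.random.default_rng(2)
x = np.linspace(0.0, 5.0, 30)
y = np.sin(x) + 0.3 * rng.standard_normal(x.shape)

# crude grid search over the lengthscale, other parameters held fixed
grid = np.array([0.05, 0.2, 0.5, 1.0, 2.0, 5.0])
lml = np.array([log_marginal_likelihood(x, y, ell, 1.0, 0.3) for ell in grid])
best = grid[np.argmax(lml)]
print(best)
```

The log-determinant term penalises overly flexible models while the quadratic term rewards data fit, so the maximiser balances the two.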

Computational complexity#

Note that the predictive distribution involves the inversion of an \(n \times n\) matrix. This computation has cubic complexity \(O(n^3)\) in the number of observed points \(n\) and is the main computational bottleneck of GP inference.
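The cubic cost is inherent to exact inference, but the explicit inverse is not. A common alternative, shown here as a sketch on arbitrary synthetic inputs, is to factorise the matrix once with a Cholesky decomposition and reuse the factor for every solve. The implementation of the predictive distribution in this chapter uses `jnp.linalg.inv` for simplicity; the two approaches give the same result up to rounding.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 200
x = np.sort(rng.uniform(0.0, 5.0, n))
d2 = (x[:, None] - x[None, :]) ** 2
Ky = np.exp(-0.5 * d2) + 0.1 * np.eye(n)  # RBF kernel plus noise variance 0.1
y = rng.standard_normal(n)

# explicit inverse: O(n^3), although the inverse itself is never needed
alpha_inv = np.linalg.inv(Ky) @ y

# Cholesky: one O(n^3) factorisation, after which each solve is O(n^2)
# when a triangular solver is used
L = np.linalg.cholesky(Ky)
alpha_chol = np.linalg.solve(L.T, np.linalg.solve(L, y))
# note: np.linalg.solve does not exploit the triangular structure of L;
# scipy.linalg.cho_solve or jax.scipy.linalg.cho_solve do

print(np.max(np.abs(alpha_inv - alpha_chol)))
```

The same factor `L` also gives the log-determinant needed for the marginal log-likelihood, as twice the sum of the logs of its diagonal.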

Let us implement the predictive distribution.

def rbf_kernel(x1, x2, lengthscale=1.0, sigma=1.0):
    """
    compute the Radial Basis Function (RBF) kernel matrix between two sets of points

    args:
    - x1 (array): array of shape (n1, d) representing the first set of points
    - x2 (array): array of shape (n2, d) representing the second set of points
    - lengthscale (float): length-scale parameter
    - sigma (float): kernel standard deviation (amplitude); the kernel variance is sigma**2

    returns:
    - K (array): kernel matrix of shape (n1, n2)
    """
    # pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
    sq_dist = jnp.sum(x1**2, axis=1).reshape(-1, 1) + jnp.sum(x2**2, axis=1) - 2 * jnp.dot(x1, x2.T)
    K = sigma**2 * jnp.exp(-0.5 / lengthscale**2 * sq_dist)
    return K
def plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true=None):
    """
    plots the Gaussian process predictive distribution

    args:
    - x_obs: training inputs, shape (n_train_samples, n_features)
    - y_obs: training targets, shape (n_train_samples,)
    - x_pred: test input points, shape (n_test_samples, n_features)
    - mean: mean of the predictive distribution, shape (n_test_samples,)
    - variance: variance of the predictive distribution, shape (n_test_samples,)
    - f_true: optional true function values at x_pred, shape (n_test_samples,)
    """
    plt.figure(figsize=(8, 6))
    if f_true is not None:
        plt.plot(x_pred, f_true, label='True Function', color='purple')
    plt.scatter(x_obs, y_obs, c='orangered', label='Training Data')
    plt.plot(x_pred, mean, label='Mean Prediction', color='teal')
    # shaded band: mean plus/minus one predictive standard deviation
    plt.fill_between(x_pred.squeeze(), mean - jnp.sqrt(variance), mean + jnp.sqrt(variance), color='teal', alpha=0.3, label='Uncertainty')
    plt.title('Gaussian Process Predictive Distribution')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    plt.show()
def predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=1.0, sigma=1.0, jitter=1e-8):
    """
    predicts the mean and variance of the Gaussian process at test points

    observations are treated as noise-free: only a small jitter is added to the diagonal

    args:
    - x_obs: training inputs, shape (n_train_samples, n_features)
    - y_obs: training targets, shape (n_train_samples,)
    - x_pred: test input points, shape (n_test_samples, n_features)
    - length_scale: length-scale parameter of the RBF kernel
    - sigma: standard deviation (amplitude) parameter of the RBF kernel
    - jitter: jitter to ensure computational stability

    returns:
    - mean: Mean of the predictive distribution, shape (n_test_samples,).
    - variance: Variance of the predictive distribution, shape (n_test_samples,).
    """
    K = rbf_kernel(x_obs, x_obs, length_scale, sigma) + jitter * jnp.eye(len(x_obs))
    K_inv = jnp.linalg.inv(K)
    K_star = rbf_kernel(x_obs, x_pred, length_scale, sigma)
    K_star_star = rbf_kernel(x_pred, x_pred, length_scale, sigma)

    mean = jnp.dot(K_star.T, jnp.dot(K_inv, y_obs))
    # only the diagonal of the predictive covariance is computed
    variance = jnp.diag(K_star_star) - jnp.sum(jnp.dot(K_star.T, K_inv) * K_star.T, axis=1)

    return mean, variance
# example 
x_obs = jnp.array([[1], [2], [3]])               # training inputs
y_obs = jnp.array([3, 4, 2])                     # training targets
x_pred = jnp.linspace(0, 5, 100).reshape(-1, 1)  # test inputs

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred)
plot_gp(x_obs, y_obs, x_pred, mean, variance)
_images/99146bae96a3b93f29bab11b18b620e93829d8e8253d96ac95b0a65531502c5c.png

This seems to work well. Let’s take a look at some noisy data.

#  the true pattern
def f(x):
    return (6 * x/5 - 2)**2 * jnp.sin(12 * x/5 - 4)

# generate true function values
x = jnp.linspace(0, 5, 100)
f_true = f(x)

# sample noise with a normal distribution
noise = 2*jax.random.normal(jax.random.PRNGKey(0), shape=f_true.shape)

# add noise to the true function values to obtain the observed values (y)
y = f_true + noise

# indices of observed data points
idx_obs = jnp.array([[10], [50], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x_obs, y_obs, color='orangered', label='Training Data')
plt.grid()
plt.legend()
plt.show()
_images/922ab04b5f441abfb0a19cd846c125e6c2a52f512bdea02ab9bf338bbb25a91b.png
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
_images/c6feab481286f07caaf676a4216993849847adb131a7c00b66f18e384a3a7f43.png

Let’s add more points and repeat the procedure.

# indices of observed data points
idx_obs = jnp.array([[10], [12],[20], [68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
_images/e476fc15ea16999c68277797ba774502e5f0fbad0ac92f3da9bdd543527aeb34.png
# Indices of observed data points
idx_obs = jnp.array([[5], [10], [12],[20], [25], [41], [55],[68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
_images/76bb1e387715974a1f6f797760ab94b6e9ef5d1e282c1a5c46c25f0407282a52.png

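As the number of observed data points increases, we observe two phenomena. Firstly, the uncertainty bounds shrink, as expected. Secondly, rather than trying to generalise, our model starts passing through every point. This happens because `predict_gaussian_process` adds only a tiny jitter to the diagonal of the kernel matrix, i.e. it treats the observations as noise-free instead of using \(K_{nn} + \sigma^2_\epsilon I\).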
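The effect of the noise term can be checked numerically. The sketch below is a NumPy variant of the predictive distribution with \(\sigma^2_\epsilon I\) added to the diagonal; the function name `predict_noisy` is an assumption, and the data are regenerated with NumPy's random generator, so the noise realisation differs from the one plotted above. It compares the predictive mean at the training inputs for a nearly noise-free fit and for a fit that uses the noise level with which the data were simulated.

```python
import numpy as np

def rbf(x1, x2, lengthscale, sigma):
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return sigma**2 * np.exp(-0.5 * d2 / lengthscale**2)

def predict_noisy(x_obs, y_obs, x_pred, lengthscale, sigma, noise_sd):
    """Predictive mean and variance of f at x_pred given noisy observations."""
    Ky = rbf(x_obs, x_obs, lengthscale, sigma) + noise_sd**2 * np.eye(len(x_obs))
    K_star = rbf(x_obs, x_pred, lengthscale, sigma)   # shape (n, m)
    mean = K_star.T @ np.linalg.solve(Ky, y_obs)
    v = np.linalg.solve(Ky, K_star)                   # Ky^{-1} K_nm
    var = sigma**2 - np.sum(K_star * v, axis=0)       # diagonal of K_m|n
    return mean, var

def f(x):
    return (6 * x / 5 - 2) ** 2 * np.sin(12 * x / 5 - 4)

rng = np.random.default_rng(0)
x = np.linspace(0.0, 5.0, 100)
y = f(x) + 2.0 * rng.standard_normal(x.shape)
idx = np.array([5, 10, 12, 20, 25, 41, 55, 68, 73, 85, 90])

# nearly noise-free fit (noise variance 1e-8) versus a fit that models the noise
m_interp, _ = predict_noisy(x[idx], y[idx], x[idx], 0.5, 4.0, noise_sd=1e-4)
m_smooth, v_smooth = predict_noisy(x[idx], y[idx], x[idx], 0.5, 4.0, noise_sd=2.0)

print(np.max(np.abs(m_interp - y[idx])))  # close to zero: the mean interpolates the data
print(np.max(np.abs(m_smooth - y[idx])))  # clearly non-zero: the mean smooths the data
```

With the noise term included, the predictive mean is no longer forced through the observations, and the predictive variance of \(f\) at the training inputs stays above zero.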

Non-Gaussian likelihoods#

We have seen that in the Normal-Normal setting the posterior distribution is available analytically.

For non-conjugate likelihood models this is not the case. When the likelihood is non-Gaussian, the resulting posterior distribution is no longer Gaussian, and obtaining a closed-form expression for the predictive distribution may be challenging or impossible.

In such cases, computational methods such as Markov Chain Monte Carlo or Variational Inference are often used to approximate the predictive distribution. We will use Numpyro and its MCMC engine for this purpose.

Draw GP priors using Numpyro functionality#

In the previous chapter we saw how to draw GP priors numerically. While we don’t necessarily need Numpyro to do inference in the case of Gaussian likelihood, we will soon see that we will need Numpyro for inference in other cases.

Let us start building that code by drawing GP priors using Numpyro functionality.

def plot_gp_samples(x, samples, ttl="", num_samples=10):
    
    plt.figure(figsize=(6, 4))
    for i in range(num_samples):
        plt.plot(x, samples[i], label=f'Sample {i}')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title(ttl)
    plt.legend()
    plt.tight_layout()
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    plt.show()
def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5, noise=0.5):
    """
    Gaussian process prior with a Gaussian likelihood, written as a Numpyro model

    args:
    - x (jax.numpy.ndarray): input data points of shape (n, d), where n is the number of points and d is the number of dimensions
    - y (jax.numpy.ndarray): observed values of shape (n,); if None, "y" is sampled rather than conditioned on
    - kernel_func (function): kernel function to use
    - lengthscale (float): length-scale parameter
    - jitter (float): jitter for numerical stability
    - noise (float): standard deviation of the observation noise, assumed known

    sample sites:
    - "f": latent function values at the input points
    - "y": noisy observations of "f"
    """

    n = x.shape[0]

    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)

    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))
    
    numpyro.sample("y", dist.Normal(f, noise), obs=y)
n_points = 50
num_samples = 1000

# input locations
x = jnp.linspace(0, 1, n_points).reshape(-1, 1)

# no data and no posterior samples are passed, so Predictive draws from the prior
prior_predictive = Predictive(model, num_samples=num_samples)
samples = prior_predictive(jax.random.PRNGKey(0), x=x)
f_samples = samples['f']

# calculate mean and standard deviation
mean = jnp.mean(f_samples, axis=0)
std = jnp.std(f_samples, axis=0)


plt.figure(figsize=(8, 6))
for i in range(10):
    plt.plot(x, f_samples[i], color='teal', lw=0.3)
plt.plot(x, mean, '-', label='Mean', color='teal')  # point-wise mean of all samples
plt.fill_between(x.squeeze(), mean - 1.96 * std, mean + 1.96 * std, color='teal', alpha=0.3, label='Uncertainty Bounds')  # uncertainty bounds
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.title('GP samples obtained using Numpyro')    
plt.grid()
plt.legend()
plt.show()
_images/3ad3cd1f31cb6554b3bd2fb501f98c39668076c2209662be01867fd7364a0ada.png

Inference with Numpyro: known noise#

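For inference, we will use the data simulated in the previous step as input: one prior draw of \(f\) serves as the true function, and the corresponding draw of \(y\) as the noisy observations.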

true_idx = 3

y_samples = samples['y']
f_true = f_samples[true_idx]
y_obs =  y_samples[true_idx]

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.grid()
plt.legend()
plt.show()
_images/6eb3803e8fd1db98dfe7b959dd9727fbcace40119b79a44d6e1ad2da9a3073b3.png

Let us see whether we can recover \(f\) from observed \(y\) values. At this stage we assume that the amount of noise is known.

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# print summary statistics of posterior
mcmc.print_summary()

# get the posterior samples
posterior_samples = mcmc.get_samples()
f_posterior = posterior_samples['f']
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      f[0]      0.98      0.26      0.97      0.56      1.42   2120.13      1.00
      f[1]      1.10      0.22      1.09      0.73      1.47   1894.31      1.00
      f[2]      1.23      0.20      1.23      0.90      1.54   1740.22      1.00
      f[3]      1.38      0.18      1.38      1.09      1.67   1692.13      1.00
      f[4]      1.54      0.17      1.54      1.26      1.81   1698.71      1.00
      f[5]      1.70      0.16      1.71      1.43      1.97   1697.40      1.00
      f[6]      1.88      0.16      1.88      1.61      2.13   1766.68      1.00
      f[7]      2.05      0.16      2.05      1.78      2.31   1743.53      1.00
      f[8]      2.21      0.16      2.21      1.95      2.47   1721.54      1.00
      f[9]      2.37      0.16      2.37      2.11      2.63   1690.00      1.00
     f[10]      2.51      0.16      2.50      2.25      2.76   1657.94      1.00
     f[11]      2.62      0.16      2.62      2.36      2.87   1632.71      1.00
     f[12]      2.70      0.16      2.70      2.44      2.95   1617.95      1.00
     f[13]      2.75      0.16      2.75      2.48      3.00   1609.96      1.00
     f[14]      2.76      0.16      2.76      2.51      3.04   1606.50      1.00
     f[15]      2.73      0.16      2.73      2.47      3.00   1608.51      1.00
     f[16]      2.65      0.16      2.65      2.40      2.93   1609.70      1.00
     f[17]      2.54      0.16      2.54      2.29      2.81   1608.39      1.00
     f[18]      2.38      0.16      2.38      2.14      2.66   1606.05      1.00
     f[19]      2.19      0.16      2.19      1.93      2.45   1600.37      1.00
     f[20]      1.98      0.16      1.97      1.71      2.23   1590.41      1.00
     f[21]      1.73      0.16      1.73      1.47      1.99   1577.27      1.00
     f[22]      1.47      0.16      1.48      1.22      1.73   1565.24      1.00
     f[23]      1.21      0.16      1.21      0.95      1.46   1541.40      1.00
     f[24]      0.94      0.16      0.94      0.68      1.20   1479.06      1.00
     f[25]      0.68      0.16      0.68      0.42      0.94   1424.89      1.00
     f[26]      0.43      0.16      0.43      0.14      0.67   1393.53      1.00
     f[27]      0.20      0.16      0.20     -0.08      0.44   1391.66      1.00
     f[28]     -0.02      0.16     -0.02     -0.28      0.23   1411.48      1.00
     f[29]     -0.22      0.16     -0.22     -0.48      0.04   1447.38      1.00
     f[30]     -0.39      0.16     -0.39     -0.65     -0.14   1472.72      1.00
     f[31]     -0.55      0.16     -0.55     -0.81     -0.30   1487.86      1.00
     f[32]     -0.68      0.16     -0.68     -0.94     -0.43   1493.16      1.00
     f[33]     -0.81      0.16     -0.81     -1.07     -0.56   1497.70      1.00
     f[34]     -0.91      0.16     -0.91     -1.16     -0.64   1501.22      1.00
     f[35]     -1.01      0.16     -1.01     -1.26     -0.75   1500.57      1.00
     f[36]     -1.09      0.16     -1.09     -1.36     -0.84   1496.60      1.00
     f[37]     -1.16      0.16     -1.16     -1.44     -0.91   1498.29      1.00
     f[38]     -1.22      0.16     -1.22     -1.49     -0.96   1507.57      1.00
     f[39]     -1.27      0.16     -1.27     -1.53     -1.00   1522.21      1.00
     f[40]     -1.31      0.16     -1.31     -1.57     -1.03   1543.47      1.00
     f[41]     -1.33      0.16     -1.33     -1.61     -1.07   1565.48      1.00
     f[42]     -1.33      0.16     -1.33     -1.60     -1.07   1593.17      1.00
     f[43]     -1.31      0.16     -1.31     -1.57     -1.04   1618.24      1.00
     f[44]     -1.27      0.16     -1.27     -1.53     -1.00   1653.18      1.00
     f[45]     -1.21      0.17     -1.21     -1.48     -0.94   1696.87      1.00
     f[46]     -1.13      0.18     -1.13     -1.43     -0.85   1769.31      1.00
     f[47]     -1.03      0.20     -1.02     -1.35     -0.71   1903.47      1.00
     f[48]     -0.92      0.22     -0.91     -1.29     -0.56   2086.05      1.00
     f[49]     -0.79      0.26     -0.79     -1.24     -0.39   2301.05      1.00

Number of divergences: 0
# calculate posterior mean and 95% highest posterior density interval (HPDI)
f_mean = jnp.mean(f_posterior, axis=0)
f_hpdi = hpdi(f_posterior, 0.95)

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.plot(x, f_mean, color='teal', label='Mean Prediction')
plt.fill_between(x.squeeze(), f_hpdi[0], f_hpdi[1], color='teal', alpha=0.3, label='HPDI')  # uncertainty bounds
plt.grid()
plt.legend()
plt.show()
_images/70584076f3130b239ffcdc9aa7c9b9bb64b06d6b82d9465b19c928cb17e38caa.png

This looks good. In reality, however, we rarely know the true level of noise.

Inference with Numpyro: estimating noise#

def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5):
    """
    Gaussian process prior with a Gaussian likelihood and unknown noise level

    args:
    - x (jax.numpy.ndarray): input data points of shape (n, d), where n is the number of points and d is the number of dimensions
    - y (jax.numpy.ndarray): observed values of shape (n,); if None, "y" is sampled rather than conditioned on
    - kernel_func (function): kernel function 
    - lengthscale (float): length-scale parameter 
    - jitter (float): jitter for numerical stability

    sample sites:
    - "f": latent function values at the input points
    - "sigma": standard deviation of the observation noise
    - "y": noisy observations of "f"
    """

    n = x.shape[0]

    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)

    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
    
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# Print summary statistics of posterior
mcmc.print_summary()

# Get the posterior samples
posterior_samples = mcmc.get_samples()
sigma_posterior = posterior_samples['sigma']

plt.figure(figsize=(6, 4))
sns.kdeplot(sigma_posterior, fill=True)
plt.axvline(x=0.5, color='r', linestyle='--', label='True value')
# raw strings avoid the invalid escape sequence '\s'
plt.xlabel(r'$\sigma$')
plt.ylabel('Density')
plt.title(r'Posterior of parameter $\sigma$')
plt.legend()

# Show the plot
plt.show()
_images/7b73a2c0ec12ae65437d429c708529823689ff2b15fe7578b2ab979b28a2475c.png

We have inferred the noise standard deviation \(\sigma\) successfully. Estimating the lengthscale is harder, especially for less smooth kernels. One more issue is the non-identifiability of the lengthscale-variance pair: different combinations of the two can fit the data almost equally well. Hence, if both really need to be inferred, strong priors are beneficial.