Gaussian processes: inference#
In the previous chapter we saw how to generate draws from Gaussian process priors. In practice, however, we usually want to do inference: given observed data, we would like to estimate model parameters.
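In a typical setting, a GP enters a generalised linear mixed model (GLMM) as a latent variable, and the model has the form
\[
\begin{aligned}
\theta &\sim p(\theta), \\
f(x) &\sim \mathcal{GP}(\mu_\theta(x), k_\theta(x, x')), \\
y \mid f &\sim \Pi_i \, p(y_i \vert f(x_i)).
\end{aligned}
\]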
Here \(\{(x_i, y_i)\}_{i=1}^n\) are pairs of observations \(y_i\), and locations of those observations \(x_i\). The role of \(f(x)\) now is to serve as a latent field capturing dependencies between locations \(x\). The expression \(\Pi_i p(y_i \vert f(x_i))\) provides a likelihood, allowing us to link observed data to the model and enabling parameter inference.
The task we usually want to solve is twofold: infer the parameters involved in the model, and make predictions at unobserved locations \(x_*\).
Gaussian process regression#
The simplest case of the setting described above is Gaussian process regression, where the outcome variable is modelled as a GP with added noise \(\epsilon\). It assumes that the data consist of pairs \(\{(x_i, y_i)\}_{i=1}^n\) and that the likelihood is Gaussian with variance \(\sigma^2_\epsilon\):
\[
y_i = f(x_i) + \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, \sigma^2_\epsilon), \quad f \sim \mathcal{GP}(\mu(x), k(x, x')).
\]
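We want to obtain estimates \(f_m\) of the latent function at \(m\) new locations, given the \(n\) observed ones. Recall the conditioning formula: if
\[
\begin{pmatrix} f_n \\ f_m \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \mu_n \\ \mu_m \end{pmatrix}, \begin{pmatrix} K_{nn} & K_{nm} \\ K_{mn} & K_{mm} \end{pmatrix} \right),
\]
then
\[
f_m \mid f_n \sim \mathcal{N}\left( \mu_m + K_{mn} K_{nn}^{-1} (f_n - \mu_n), \; K_{mm} - K_{mn} K_{nn}^{-1} K_{nm} \right).
\]
This formula cannot be applied directly: because of the noise we cannot observe \(f_n\), and we observe \(y_n\) instead. Since \(y_n = f_n + \epsilon\), the covariance of \(y_n\) is \(K_{nn} + \sigma^2_\epsilon I\), while its covariance with \(f_m\) is still \(K_{nm}\). The conditional in this case reads as
\[
f_m \mid y_n \sim \mathcal{N}\left( \mu_m + K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} (y_n - \mu_n), \; K_{mm} - K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} K_{nm} \right).
\]
This is the predictive distribution of the Gaussian process.
Putting the two distributions side-by-side:
posterior distribution in the noise-free case: \(f_m \mid f_n \sim \mathcal{N}\left( \mu_m + K_{mn} K_{nn}^{-1} (f_n - \mu_n), \; K_{mm} - K_{mn} K_{nn}^{-1} K_{nm} \right)\)
posterior distribution in the noisy case: \(f_m \mid y_n \sim \mathcal{N}\left( \mu_m + K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} (y_n - \mu_n), \; K_{mm} - K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} K_{nm} \right)\)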
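As a minimal sketch of the noisy-case predictive distribution, the function below computes the mean \(K_{mn}(K_{nn} + \sigma^2_\epsilon I)^{-1} y_n\) and the pointwise variance, assuming a zero prior mean and an RBF kernel. The names rbf, gp_predict_noisy and noise_sd are illustrative choices rather than part of the chapter's code, and a Cholesky factorisation is used here in place of an explicit matrix inverse.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(x1, x2, lengthscale=1.0, sigma=1.0):
    """RBF kernel matrix between the rows of x1 (n1, d) and x2 (n2, d)."""
    sq_dist = (
        jnp.sum(x1**2, axis=1)[:, None]
        + jnp.sum(x2**2, axis=1)[None, :]
        - 2.0 * x1 @ x2.T
    )
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_predict_noisy(x_n, y_n, x_m, lengthscale=1.0, sigma=1.0, noise_sd=0.1):
    """Mean and pointwise variance of f_m given noisy y_n (zero prior mean assumed)."""
    K_nn = rbf(x_n, x_n, lengthscale, sigma)
    K_nm = rbf(x_n, x_m, lengthscale, sigma)
    # The noise variance enters only through the covariance of the observations
    K_y = K_nn + noise_sd**2 * jnp.eye(x_n.shape[0])
    chol = cho_factor(K_y, lower=True)
    mean = K_nm.T @ cho_solve(chol, y_n)
    # diag(K_mm) equals sigma^2 for the RBF kernel
    var = sigma**2 - jnp.sum(K_nm * cho_solve(chol, K_nm), axis=0)
    return mean, var


# Usage on the three-point toy data set, with an assumed noise level of 0.5
x_n = jnp.array([[1.0], [2.0], [3.0]])
y_n = jnp.array([3.0, 4.0, 2.0])
x_m = jnp.linspace(0, 5, 100).reshape(-1, 1)
mean, var = gp_predict_noisy(x_n, y_n, x_m, noise_sd=0.5)
```

Setting noise_sd close to zero recovers the noise-free formula, in which the mean passes through the observations; larger values pull the mean away from individual points.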
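Parameters \(\theta\) of the model can be learnt using the marginal log-likelihood, for example by maximising it. In the noisy case it takes the form
\[
\log p(y_n \mid \theta) = -\frac{1}{2} (y_n - \mu_n)^\top (K_{nn} + \sigma^2_\epsilon I)^{-1} (y_n - \mu_n) - \frac{1}{2} \log \left| K_{nn} + \sigma^2_\epsilon I \right| - \frac{n}{2} \log 2\pi,
\]
where \(K_{nn}\) is the kernel matrix evaluated at the observed locations and \(\mu_n\) is the prior mean at those locations; both depend on \(\theta\).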
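The marginal log-likelihood of a zero-mean GP with Gaussian noise is the log-density of \(\mathcal{N}(0, K_{nn} + \sigma^2_\epsilon I)\) evaluated at \(y_n\). The sketch below evaluates it through a Cholesky factor and compares a few candidate lengthscales on synthetic data. The function name log_marginal_likelihood, the noisy sine data and the grid of lengthscales are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(x1, x2, lengthscale=1.0, sigma=1.0):
    """RBF kernel matrix between the rows of x1 (n1, d) and x2 (n2, d)."""
    sq_dist = (
        jnp.sum(x1**2, axis=1)[:, None]
        + jnp.sum(x2**2, axis=1)[None, :]
        - 2.0 * x1 @ x2.T
    )
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def log_marginal_likelihood(x_n, y_n, lengthscale, sigma, noise_sd):
    """log p(y_n | theta) for a zero-mean GP with Gaussian noise."""
    n = x_n.shape[0]
    K_y = rbf(x_n, x_n, lengthscale, sigma) + noise_sd**2 * jnp.eye(n)
    c, lower = cho_factor(K_y, lower=True)
    alpha = cho_solve((c, lower), y_n)
    # log|K_y| is twice the sum of the log-diagonal of the Cholesky factor
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(c)))
    return -0.5 * y_n @ alpha - 0.5 * log_det - 0.5 * n * jnp.log(2.0 * jnp.pi)


# Synthetic data (an assumption for illustration): a noisy sine curve
x_n = jnp.linspace(0, 5, 20).reshape(-1, 1)
y_n = jnp.sin(x_n).squeeze() + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (20,))

# Compare candidate lengthscales with the other parameters held fixed
lengthscales = jnp.array([0.1, 0.5, 1.0, 2.0])
lml = jnp.array([log_marginal_likelihood(x_n, y_n, l, 1.0, 0.1) for l in lengthscales])
best_lengthscale = lengthscales[jnp.argmax(lml)]
```

Because the function is written in JAX, jax.grad could be applied to it to obtain gradients for a gradient-based optimiser instead of a grid.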
Computational complexity#
Note that the predictive distribution involves the inversion of an \(n \times n\) matrix. This computation has cubic complexity \(O(n^3)\) in the number of observed points \(n\) and is the main computational bottleneck of GP inference.
# imports for this chapter
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import numpyro
import numpyro.distributions as dist
from numpyro.infer import init_to_median, Predictive, MCMC, NUTS
from numpyro.diagnostics import hpdi
import pickle
numpyro.set_host_device_count(4) # Set the device count to enable parallel sampling
Let us implement the predictive distribution.
def rbf_kernel(x1, x2, lengthscale=1.0, sigma=1.0):
    """
    Compute the Radial Basis Function (RBF) kernel matrix between two sets of points.
    Args:
    - x1 (array): Array of shape (n1, d) representing the first set of points.
    - x2 (array): Array of shape (n2, d) representing the second set of points.
    - lengthscale (float): Length-scale parameter.
    - sigma (float): Amplitude (standard deviation) parameter; the kernel variance is sigma**2.
    Returns:
    - K (array): Kernel matrix of shape (n1, n2).
    """
    # Squared Euclidean distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    sq_dist = jnp.sum(x1**2, axis=1).reshape(-1, 1) + jnp.sum(x2**2, axis=1) - 2 * jnp.dot(x1, x2.T)
    K = sigma**2 * jnp.exp(-0.5 / lengthscale**2 * sq_dist)
    return K
def plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true=False):
    """
    Plots the Gaussian process predictive distribution.
    Args:
    - x_obs: Training inputs, shape (n_train_samples, n_features).
    - y_obs: Training targets, shape (n_train_samples,).
    - x_pred: Test input points, shape (n_test_samples, n_features).
    - mean: Mean of the predictive distribution, shape (n_test_samples,).
    - variance: Variance of the predictive distribution, shape (n_test_samples,).
    - f_true: Optional true function values at x_pred; False (default) skips the curve.
    """
    plt.figure(figsize=(8, 6))
    if f_true is not False:
        plt.plot(x_pred, f_true, label='True Function', color='purple')
    plt.scatter(x_obs, y_obs, c='orangered', label='Training Data')
    plt.plot(x_pred, mean, label='Mean Prediction', color='teal')
    # Shaded band: mean plus/minus one predictive standard deviation
    plt.fill_between(x_pred.squeeze(), mean - jnp.sqrt(variance), mean + jnp.sqrt(variance), color='teal', alpha=0.3, label='Uncertainty')
    plt.title('Gaussian Process Predictive Distribution')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    plt.show()
def predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=1.0, sigma=1.0, jitter=1e-8):
    """
    Predicts the mean and variance of the Gaussian process at test points.
    This implements the noise-free conditioning formula: only jitter is added to the diagonal.
    Args:
    - x_obs: Training inputs, shape (n_train_samples, n_features).
    - y_obs: Training targets, shape (n_train_samples,).
    - x_pred: Test input points, shape (n_test_samples, n_features).
    - length_scale: Length scale parameter of the RBF kernel.
    - sigma: Amplitude parameter of the RBF kernel (the kernel variance is sigma**2).
    - jitter: Jitter to ensure computational stability.
    Returns:
    - mean: Mean of the predictive distribution, shape (n_test_samples,).
    - variance: Variance of the predictive distribution, shape (n_test_samples,).
    """
    K = rbf_kernel(x_obs, x_obs, length_scale, sigma) + jitter * jnp.eye(len(x_obs))
    K_inv = jnp.linalg.inv(K)
    K_star = rbf_kernel(x_obs, x_pred, length_scale, sigma)
    K_star_star = rbf_kernel(x_pred, x_pred, length_scale, sigma)
    mean = jnp.dot(K_star.T, jnp.dot(K_inv, y_obs))
    # Only the diagonal of the predictive covariance is needed for the plot
    variance = jnp.diag(K_star_star) - jnp.sum(jnp.dot(K_star.T, K_inv) * K_star.T, axis=1)
    return mean, variance
# Example usage:
x_obs = jnp.array([[1], [2], [3]]) # Training inputs
y_obs = jnp.array([3, 4, 2]) # Training targets
x_pred = jnp.linspace(0, 5, 100).reshape(-1, 1) # Test input points
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred)
plot_gp(x_obs, y_obs, x_pred, mean, variance)
This seems to work well. Let’s take a look at some noisy data.
# Define the true pattern
def f(x):
    return (6 * x/5 - 2)**2 * jnp.sin(12 * x/5 - 4)
# Define the range of x values
x = jnp.linspace(0, 5, 100)
# Generate the true function values
f_true = f(x)
# Generate noise with a normal distribution
noise = 2*jax.random.normal(jax.random.PRNGKey(0), shape=f_true.shape)
# Add noise to the true function values to obtain the observed values (y)
y = f_true + noise
# Indices of observed data points
idx_obs = jnp.array([[10], [50], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)
plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x_obs, y_obs, color='orangered', label='Training Data')
plt.grid()
plt.legend()
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
Let’s add more points and repeat the procedure.
# Indices of observed data points
idx_obs = jnp.array([[10], [12],[20], [68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
# Indices of observed data points
idx_obs = jnp.array([[5], [10], [12],[20], [25], [41], [55],[68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
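As the number of observed data points increases, we observe two phenomena. Firstly, the uncertainty bounds shrink, as expected. Secondly, rather than trying to generalise, our model starts passing through every point. This happens because the function predict_gaussian_process uses the noise-free formula: apart from the tiny jitter, nothing is added to the diagonal of the kernel matrix, so the noisy observations are treated as exact values of \(f\).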
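A small numerical check of this behaviour, under the assumption that the cause is the missing noise term: the sketch below evaluates the posterior mean at the training locations once with only jitter on the diagonal and once with the noise variance added. The helper name posterior_mean_at_training_points and the choice noise_sd=2.0 (matching the scale of the simulated noise) are illustrative.

```python
import jax
import jax.numpy as jnp


def rbf(x1, x2, lengthscale=1.0, sigma=1.0):
    """RBF kernel matrix between the rows of x1 (n1, d) and x2 (n2, d)."""
    sq_dist = (
        jnp.sum(x1**2, axis=1)[:, None]
        + jnp.sum(x2**2, axis=1)[None, :]
        - 2.0 * x1 @ x2.T
    )
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def posterior_mean_at_training_points(x_n, y_n, lengthscale, sigma, noise_sd, jitter=1e-6):
    """Posterior mean of f at the training inputs: K (K + (noise_sd^2 + jitter) I)^{-1} y."""
    K = rbf(x_n, x_n, lengthscale, sigma)
    K_y = K + (noise_sd**2 + jitter) * jnp.eye(x_n.shape[0])
    return K @ jnp.linalg.solve(K_y, y_n)


# Same simulated data as in the text: true function plus noise with standard deviation 2
x = jnp.linspace(0, 5, 100)
f_true = (6 * x / 5 - 2) ** 2 * jnp.sin(12 * x / 5 - 4)
y = f_true + 2 * jax.random.normal(jax.random.PRNGKey(0), f_true.shape)
idx = jnp.array([5, 10, 12, 20, 25, 41, 55, 68, 73, 85, 90])
x_n, y_n = x[idx, None], y[idx]

# Largest gap between the posterior mean and the observations at the training inputs
mean_noise_free = posterior_mean_at_training_points(x_n, y_n, 0.5, 4.0, noise_sd=0.0)
mean_noisy = posterior_mean_at_training_points(x_n, y_n, 0.5, 4.0, noise_sd=2.0)
resid_noise_free = jnp.max(jnp.abs(mean_noise_free - y_n))
resid_noisy = jnp.max(jnp.abs(mean_noisy - y_n))
print(resid_noise_free, resid_noisy)
```

With noise_sd=0.0 the gap is at the level of numerical error, which is the interpolation seen in the plots. With the noise variance included, the mean no longer has to reproduce each observation exactly.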
Non-Gaussian likelihoods#
We have seen that in the Normal-Normal setting the posterior distribution is available analytically.
For non-conjugate likelihood models this is not the case. When the likelihood is non-Gaussian, the resulting posterior distribution is no longer Gaussian, and a closed-form expression for the predictive distribution is generally not available.
In such cases, computational methods such as Markov Chain Monte Carlo or Variational Inference are often used to approximate the predictive distribution. We will use NumPyro and its MCMC engine for this purpose.
Draw GP priors using Numpyro functionality#
In the previous chapter we saw how to draw GP priors numerically. While we don’t necessarily need Numpyro to do inference in the case of Gaussian likelihood, we will soon see that we will need Numpyro for inference in other cases.
Let us start building that code by drawing GP priors using Numpyro functionality.
def plot_gp_samples(x, samples, ttl="", num_samples=10):
    # Plot the first num_samples draws
    plt.figure(figsize=(6, 4))
    for i in range(num_samples):
        plt.plot(x, samples[i], label=f'Sample {i}')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title(ttl)
    plt.tight_layout()
    # Place the legend outside the axes so it does not cover the curves
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    plt.show()
def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5, noise=0.5):
    """
    Gaussian Process prior with Gaussian observation noise, as a NumPyro model.
    Args:
    - x (jax.numpy.ndarray): Input data points of shape (n, d), where n is the number of points and d is the number of dimensions.
    - y (jax.numpy.ndarray): Observed values of shape (n,). If None, y is sampled from the model.
    - kernel_func (function): Kernel function to use. Default is rbf_kernel.
    - lengthscale (float): Length scale parameter for the kernel function. Default is 0.2.
    - jitter (float): Small constant added to the diagonal of the kernel matrix for numerical stability. Default is 1e-5.
    - noise (float): Standard deviation of the observation noise. Default is 0.5.
    Returns:
    - None. The sample sites "f" (latent function values) and "y" (observations) are registered with NumPyro.
    """
    n = x.shape[0]
    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))
    numpyro.sample("y", dist.Normal(f, noise), obs=y)
# Define parameters
n_points = 50
num_samples = 1000
# Evenly spaced input locations
x = jnp.linspace(0, 1, n_points).reshape(-1, 1)
# With no observations passed, Predictive draws from the prior
prior_predictive = Predictive(model, num_samples=num_samples)
samples = prior_predictive(jax.random.PRNGKey(0), x=x)
f_samples = samples['f']
# Calculate mean and standard deviation
mean = jnp.mean(f_samples, axis=0)
std = jnp.std(f_samples, axis=0)
# Visualization
plt.figure(figsize=(8, 6))
for i in range(10):
    plt.plot(x, f_samples[i], color='teal', lw=0.3)
plt.plot(x, mean, '-', label='Mean', color='teal') # Point-wise mean of all samples
plt.fill_between(x.squeeze(), mean - 1.96 * std, mean + 1.96 * std, color='teal', alpha=0.3, label='Uncertainty Bounds') # Uncertainty bounds
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.title('GP samples obtained using Numpyro')
plt.grid()
plt.legend()
plt.show()
Inference with Numpyro: known noise#
For inference, we will use one of the draws simulated in the previous step as the observed data.
true_idx = 3
y_samples = samples['y']
f_true = f_samples[true_idx]
y_obs = y_samples[true_idx]
plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.grid()
plt.legend()
Let us see whether we can recover \(f\) from observed \(y\) values. At this stage we assume that the amount of noise is known.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# Print summary statistics of posterior
mcmc.print_summary()
# Get the posterior samples
posterior_samples = mcmc.get_samples()
f_posterior = posterior_samples['f']
# Calculate the posterior mean and the 95% highest posterior density interval (HPDI)
f_mean = jnp.mean(f_posterior, axis=0)
f_hpdi = hpdi(f_posterior, 0.95)
plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.plot(x, f_mean, color='teal', label='Mean Prediction')
plt.fill_between(x.squeeze(), f_hpdi[0], f_hpdi[1], color='teal', alpha=0.3, label='HPDI') # Uncertainty bounds
plt.grid()
plt.legend()
plt.show()
This looks good. In reality, however, we rarely know the true level of noise.
Inference with Numpyro: estimating noise#
def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5):
    """
    Gaussian Process regression with unknown observation noise, as a NumPyro model.
    Args:
    - x (jax.numpy.ndarray): Input data points of shape (n, d), where n is the number of points and d is the number of dimensions.
    - y (jax.numpy.ndarray): Observed values of shape (n,). If None, y is sampled from the model.
    - kernel_func (function): Kernel function to use. Default is rbf_kernel.
    - lengthscale (float): Length scale parameter for the kernel function. Default is 0.2.
    - jitter (float): Small constant added to the diagonal of the kernel matrix for numerical stability. Default is 1e-5.
    Returns:
    - None. The sample sites "f", "sigma" (noise standard deviation) and "y" are registered with NumPyro.
    """
    n = x.shape[0]
    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))
    # Weakly informative prior on the noise standard deviation
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# Print summary statistics of posterior
mcmc.print_summary()
# Get the posterior samples
posterior_samples = mcmc.get_samples()
mean std median 5.0% 95.0% n_eff r_hat
f[0] 0.97 0.26 0.97 0.54 1.38 2375.59 1.00
f[1] 1.09 0.22 1.09 0.74 1.46 2303.95 1.00
f[2] 1.23 0.19 1.22 0.90 1.53 2235.81 1.00
f[3] 1.37 0.18 1.37 1.08 1.65 2149.90 1.00
f[4] 1.53 0.17 1.53 1.26 1.80 2043.35 1.00
f[5] 1.70 0.16 1.70 1.42 1.95 1926.71 1.00
f[6] 1.87 0.16 1.87 1.59 2.13 1803.14 1.00
f[7] 2.05 0.16 2.05 1.79 2.31 1696.37 1.00
f[8] 2.22 0.16 2.22 1.94 2.47 1624.57 1.00
f[9] 2.37 0.16 2.37 2.10 2.63 1588.26 1.00
f[10] 2.51 0.16 2.51 2.25 2.78 1585.35 1.00
f[11] 2.63 0.16 2.63 2.36 2.89 1611.61 1.00
f[12] 2.71 0.16 2.71 2.44 2.97 1662.09 1.00
f[13] 2.76 0.16 2.76 2.51 3.04 1710.88 1.00
f[14] 2.77 0.16 2.77 2.51 3.03 1751.52 1.00
f[15] 2.74 0.15 2.74 2.49 3.00 1763.07 1.00
f[16] 2.66 0.15 2.66 2.40 2.91 1751.03 1.00
f[17] 2.54 0.15 2.55 2.30 2.80 1728.46 1.00
f[18] 2.39 0.15 2.39 2.14 2.64 1702.39 1.00
f[19] 2.20 0.15 2.20 1.96 2.45 1683.25 1.00
f[20] 1.98 0.15 1.98 1.74 2.23 1670.92 1.00
f[21] 1.73 0.15 1.73 1.48 1.98 1683.03 1.00
f[22] 1.47 0.15 1.47 1.22 1.71 1712.11 1.00
f[23] 1.21 0.15 1.21 0.96 1.46 1747.18 1.00
f[24] 0.94 0.15 0.94 0.70 1.20 1767.20 1.00
f[25] 0.68 0.15 0.68 0.44 0.93 1783.08 1.00
f[26] 0.43 0.15 0.43 0.18 0.68 1796.17 1.00
f[27] 0.20 0.15 0.20 -0.05 0.45 1809.50 1.00
f[28] -0.01 0.15 -0.01 -0.26 0.24 1825.29 1.00
f[29] -0.21 0.15 -0.20 -0.46 0.03 1843.03 1.00
f[30] -0.38 0.15 -0.38 -0.62 -0.13 1865.79 1.00
f[31] -0.53 0.15 -0.53 -0.77 -0.28 1888.21 1.00
f[32] -0.67 0.15 -0.67 -0.91 -0.42 1905.84 1.00
f[33] -0.79 0.15 -0.79 -1.05 -0.56 1913.41 1.00
f[34] -0.90 0.15 -0.90 -1.15 -0.66 1911.47 1.00
f[35] -0.99 0.15 -0.99 -1.25 -0.76 1900.04 1.00
f[36] -1.08 0.15 -1.08 -1.32 -0.83 1883.67 1.00
f[37] -1.16 0.15 -1.16 -1.41 -0.91 1869.40 1.00
f[38] -1.22 0.15 -1.22 -1.47 -0.97 1859.36 1.00
f[39] -1.27 0.15 -1.27 -1.53 -1.03 1858.70 1.00
f[40] -1.31 0.16 -1.31 -1.56 -1.05 1874.72 1.00
f[41] -1.33 0.16 -1.33 -1.59 -1.07 1896.21 1.00
f[42] -1.33 0.16 -1.33 -1.59 -1.07 1922.11 1.00
f[43] -1.32 0.16 -1.32 -1.59 -1.07 1950.52 1.00
f[44] -1.28 0.16 -1.28 -1.54 -1.01 1970.88 1.00
f[45] -1.21 0.17 -1.22 -1.49 -0.95 1993.58 1.00
f[46] -1.13 0.18 -1.13 -1.41 -0.83 2029.49 1.00
f[47] -1.03 0.20 -1.03 -1.36 -0.72 2096.23 1.00
f[48] -0.91 0.23 -0.91 -1.26 -0.52 2204.32 1.00
f[49] -0.77 0.26 -0.77 -1.23 -0.36 2320.96 1.00
sigma 0.49 0.06 0.48 0.39 0.57 3292.89 1.00
Number of divergences: 0
sigma_posterior = posterior_samples['sigma']
plt.figure(figsize=(6, 4))
sns.kdeplot(sigma_posterior, fill=True)
plt.axvline(x=0.5, color='r', linestyle='--', label='True value')
plt.xlabel(r'$\sigma$')
plt.ylabel('Density')
plt.title(r'Posterior of parameter $\sigma$')
plt.legend()
# Show the plot
plt.show()
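We have successfully recovered the noise standard deviation \(\sigma\): its posterior concentrates around the true value 0.5. Estimating the lengthscale, especially for less smooth kernels, is harder. A further issue is the non-identifiability of the lengthscale-variance pair: different combinations of the two can explain the data almost equally well. Hence, if both really need to be inferred, strong priors are beneficial.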