Gaussian processes: inference#

In the previous chapter we saw how to draw samples from Gaussian process priors. What we usually want to do, however, is inference: given observed data, we would like to estimate model parameters.

In a typical setting, the GP enters a GLMM as a latent variable, and the model has the form

\[\begin{split} \begin{align*} f(x) &\sim \mathcal{GP}(\mu(x), k(x, x')), \\ y | f &\sim \Pi_i p(y_i \vert f(x_i)). \end{align*} \end{split}\]

Here \(\{(x_i, y_i)\}_{i=1}^n\) are pairs of observations \(y_i\), and locations of those observations \(x_i\). The role of \(f(x)\) now is to serve as a latent field capturing dependencies between locations \(x\). The expression \(\Pi_i p(y_i \vert f(x_i))\) provides a likelihood, allowing us to link observed data to the model and enabling parameter inference.

The task that we usually want to solve is twofold: we want to infer the parameters involved in the model, as well as make predictions at unobserved locations \(x_*.\)

Gaussian process regression#

The simplest case of the setting described above is Gaussian process regression, where the outcome variable is modelled as a GP with added noise \(\epsilon\). It assumes that the data consists of pairs \(\{(x_i, y_i)\}_{i=1}^n\) and that the likelihood is Gaussian with variance \(\sigma^2_\epsilon:\)

\[\begin{split} \begin{align*} f(x) &\sim \mathcal{GP}(\mu(x), k(x, x')),\\ y_i &= f(x_i) + \epsilon_i, \\ \epsilon_i &\sim \mathcal{N}(0, \sigma^2_\epsilon). \end{align*} \end{split}\]

We want to obtain estimates \(f_m\) at \(m\) unobserved locations \(x_*\). Recall the conditioning formula: if

\[\begin{split} \begin{align*} \begin{bmatrix} f_n \\ f_m \end{bmatrix} \sim \mathcal{N}\left(\begin{bmatrix} 0 \\ 0\end{bmatrix}, \begin{bmatrix} K_{n n} & K_{n m} \\ K_{m n} & K_{m m} \end{bmatrix} \right), \end{align*} \end{split}\]

then

\[\begin{split} \begin{align*} f_m | f_n &\sim \mathcal{N}(\mu_{m|n}, K_{m|n}),\\ \mu_{m|n} & = K_{mn} K_{nn}^{-1} f_n,\\ K_{m|n} & = K_{mm} - K_{mn} K_{nn}^{-1} K_{nm}. \end{align*} \end{split}\]

There is an issue with this formula: due to the noise we cannot observe \(f_n\); we observe \(y_n\) instead. The conditional in this case reads

\[\begin{split} \begin{align*} f_m | y_n &\sim \mathcal{N}(\mu_{m|n}, K_{m|n}),\\ \mu_{m|n} & = K_{mn} (K_{nn} + \sigma^2_\epsilon I )^{-1} y_n,\\ K_{m|n} & = K_{mm} - K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} K_{nm}. \end{align*} \end{split}\]

This is the predictive distribution of the Gaussian process.

Putting the two distributions side-by-side:

  • posterior distribution in the noise-free case

\[ f_m | f_n \sim \mathcal{N}( K_{mn} K_{nn}^{-1} f_n, K_{mm} - K_{mn} K_{nn} ^{-1} K_{nm}) \]
  • posterior distribution in the noisy case

\[ f_m | y_n \sim \mathcal{N}(K_{mn} (K_{nn} + I\sigma^2_\epsilon )^{-1} y_n, K_{mm} - K_{mn} (K_{nn} + I\sigma^2_\epsilon )^{-1} K_{nm}) \]

Parameters \(\theta\) of the model can be learnt by maximising the marginal log-likelihood, which in the noisy case takes the form

\[ \log p(y|\theta)= -\frac{n}{2} \log(2\pi) - \frac{1}{2}\log \vert K_{nn} + \sigma^2_\epsilon I \vert - \frac{1}{2}y^T(K_{nn} + \sigma^2_\epsilon I)^{-1}y \]
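To make this concrete, here is a minimal sketch of hyperparameter learning by gradient-based maximisation of this marginal log-likelihood in JAX. The kernel, the toy data and the optimiser settings are assumptions chosen purely for illustration; they are not part of the model used later in this chapter.

import jax
import jax.numpy as jnp

def rbf(x1, x2, lengthscale, sigma):
    # Squared-exponential kernel for inputs of shape (n, 1)
    sq_dist = (x1 - x2.T) ** 2
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)

def neg_log_marginal_likelihood(log_params, x, y):
    # Parameters are log-transformed to keep them positive
    lengthscale, sigma, sigma_eps = jnp.exp(log_params)
    n = x.shape[0]
    K = rbf(x, x, lengthscale, sigma) + sigma_eps**2 * jnp.eye(n)
    L = jnp.linalg.cholesky(K)                         # Cholesky factor of K_nn + sigma_eps^2 I
    alpha = jax.scipy.linalg.cho_solve((L, True), y)   # (K_nn + sigma_eps^2 I)^{-1} y
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return 0.5 * (y @ alpha + log_det + n * jnp.log(2.0 * jnp.pi))

# Toy data, assumed for illustration only
x_toy = jnp.linspace(0.0, 5.0, 30).reshape(-1, 1)
y_toy = jnp.sin(x_toy).squeeze() + 0.3 * jax.random.normal(jax.random.PRNGKey(0), (30,))

log_params = jnp.zeros(3)  # log lengthscale, log sigma, log sigma_eps
grad_fn = jax.jit(jax.grad(neg_log_marginal_likelihood))
for _ in range(500):       # plain gradient descent; any optimiser would do
    log_params = log_params - 0.01 * grad_fn(log_params, x_toy, y_toy)

print("lengthscale, sigma, sigma_eps:", jnp.exp(log_params))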

Computational complexity#

Note that the predictive distribution involves the inversion of an \(n \times n\) matrix. This computation has cubic complexity, \(O(n^3)\), in the number of points \(n\) and is the main computational bottleneck of GP inference.
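In practice the inverse is rarely formed explicitly: a Cholesky factorisation (still \(O(n^3)\), but numerically more stable) is the usual route. The following is a minimal sketch, assuming the kernel blocks \(K_{nn}\), \(K_{mn}\), \(K_{mm}\), the observations \(y_n\) and a noise level \(\sigma_\epsilon\) are given; the function name is ours, not part of any library.

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def predictive_mean_cov(K_nn, K_mn, K_mm, y_n, sigma_eps):
    # Noisy-case predictive distribution computed via a Cholesky factorisation
    n = K_nn.shape[0]
    c, low = cho_factor(K_nn + sigma_eps**2 * jnp.eye(n))    # O(n^3) factorisation
    alpha = cho_solve((c, low), y_n)                         # (K_nn + sigma_eps^2 I)^{-1} y_n
    mean = K_mn @ alpha
    cov = K_mm - K_mn @ cho_solve((c, low), K_mn.T)
    return mean, cov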

# imports for this chapter
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro.infer import init_to_median, Predictive, MCMC, NUTS
from numpyro.diagnostics import hpdi

import pickle

numpyro.set_host_device_count(4)  # Set the device count to enable parallel sampling

Let us implement the predictive distribution.

def rbf_kernel(x1, x2, lengthscale=1.0, sigma=1.0):
    """
    Compute the Radial Basis Function (RBF) kernel matrix between two sets of points.

    Args:
    - x1 (array): Array of shape (n1, d) representing the first set of points.
    - x2 (array): Array of shape (n2, d) representing the second set of points.
    - lengthscale (float): Length-scale parameter.
    - sigma (float): Amplitude parameter; the kernel variance is sigma**2.

    Returns:
    - K (array): Kernel matrix of shape (n1, n2).
    """
    # Pairwise squared Euclidean distances between the rows of x1 and x2
    sq_dist = jnp.sum(x1**2, axis=1).reshape(-1, 1) + jnp.sum(x2**2, axis=1) - 2 * jnp.dot(x1, x2.T)
    K = sigma**2 * jnp.exp(-0.5 / lengthscale**2 * sq_dist)
    return K
def plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true=False):
    """
    Plots the Gaussian process predictive distribution.

    Args:
    - x_obs: Training inputs, shape (n_train_samples, n_features).
    - y_obs: Training targets, shape (n_train_samples,).
    - x_pred: Test input points, shape (n_test_samples, n_features).
    - mean: Mean of the predictive distribution, shape (n_test_samples,).
    - variance: Variance of the predictive distribution, shape (n_test_samples,).
    - f_true: Optional true function values at x_pred, shape (n_test_samples,).
    """
    plt.figure(figsize=(8, 6))
    if f_true is not False:
        plt.plot(x_pred, f_true, label='True Function', color='purple')
    plt.scatter(x_obs, y_obs, c='orangered', label='Training Data')
    plt.plot(x_pred, mean, label='Mean Prediction', color='teal')
    plt.fill_between(x_pred.squeeze(), mean - jnp.sqrt(variance), mean + jnp.sqrt(variance), color='teal', alpha=0.3, label='Uncertainty')
    plt.title('Gaussian Process Predictive Distribution')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    plt.show()
def predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=1.0, sigma=1.0, jitter=1e-8):
    """
    Predicts the mean and variance of the (noise-free) Gaussian process at test points.

    Args:
    - x_obs: Training inputs, shape (n_train_samples, n_features).
    - y_obs: Training targets, shape (n_train_samples,).
    - x_pred: Test input points, shape (n_test_samples, n_features).
    - length_scale: Length-scale parameter of the kernel.
    - sigma: Amplitude parameter of the kernel.
    - jitter: Small value added to the diagonal for numerical stability.

    Returns:
    - mean: Mean of the predictive distribution, shape (n_test_samples,).
    - variance: Variance of the predictive distribution, shape (n_test_samples,).
    """
    # K_nn with jitter on the diagonal
    K = rbf_kernel(x_obs, x_obs, length_scale, sigma) + jitter * jnp.eye(len(x_obs))
    K_inv = jnp.linalg.inv(K)
    # K_nm and K_mm
    K_star = rbf_kernel(x_obs, x_pred, length_scale, sigma)
    K_star_star = rbf_kernel(x_pred, x_pred, length_scale, sigma)

    # Posterior mean and point-wise variance (noise-free conditioning formula)
    mean = jnp.dot(K_star.T, jnp.dot(K_inv, y_obs))
    variance = jnp.diag(K_star_star) - jnp.sum(jnp.dot(K_star.T, K_inv) * K_star.T, axis=1)

    return mean, variance
# Example usage:
x_obs = jnp.array([[1], [2], [3]])  # Training inputs
y_obs = jnp.array([3, 4, 2])         # Training targets
x_pred = jnp.linspace(0, 5, 100).reshape(-1, 1)  # Test input points

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred)
plot_gp(x_obs, y_obs, x_pred, mean, variance)
[Figure: Gaussian Process predictive distribution for the three noise-free training points]

This seems to work well. Let's take a look at some noisy data.

# Define the true pattern
def f(x):
    return (6 * x/5 - 2)**2 * jnp.sin(12 * x/5 - 4)

# Define the range of x values
x = jnp.linspace(0, 5, 100)

# Generate the true function values
f_true = f(x)

# Generate noise with a normal distribution
noise = 2*jax.random.normal(jax.random.PRNGKey(0), shape=f_true.shape)

# Add noise to the true function values to obtain the observed values (y)
y = f_true + noise

# Indices of observed data points
idx_obs = jnp.array([[10], [50], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x_obs, y_obs, color='orangered', label='Training Data')
plt.grid()
plt.legend()
[Figure: true function and the three observed noisy data points]
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
[Figure: GP predictive distribution fitted to the three noisy observations, with the true function]

Let’s add more points and repeat the procedure.

# Indices of observed data points
idx_obs = jnp.array([[10], [12],[20], [68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
[Figure: GP predictive distribution fitted to seven noisy observations]
# Indices of observed data points
idx_obs = jnp.array([[5], [10], [12],[20], [25], [41], [55],[68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
[Figure: GP predictive distribution fitted to eleven noisy observations]

As the number of observed data points increases, we observe two phenomena: the uncertainty bounds shrink (as expected), but rather than generalising, our model starts passing through every point. A variant of the predictor that accounts for observation noise is sketched below.
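The interpolation comes from using the noise-free formula: only a tiny jitter sits on the diagonal of \(K_{nn}\), so the posterior mean is forced through every observation. A small variant of predict_gaussian_process that adds a noise variance \(\sigma^2_\epsilon\) to the diagonal (the noisy-case formula above) smooths rather than interpolates. The noise level used below is an assumption; it matches the scale of the noise we used to simulate the data.

def predict_gaussian_process_noisy(x_obs, y_obs, x_pred, length_scale=1.0, sigma=1.0, sigma_eps=1.0):
    """Noisy-case predictive distribution: (K_nn + sigma_eps^2 I) replaces K_nn."""
    n = len(x_obs)
    K = rbf_kernel(x_obs, x_obs, length_scale, sigma) + sigma_eps**2 * jnp.eye(n)
    K_inv = jnp.linalg.inv(K)
    K_star = rbf_kernel(x_obs, x_pred, length_scale, sigma)
    K_star_star = rbf_kernel(x_pred, x_pred, length_scale, sigma)

    mean = K_star.T @ (K_inv @ y_obs)
    variance = jnp.diag(K_star_star) - jnp.sum((K_star.T @ K_inv) * K_star.T, axis=1)
    return mean, variance

# sigma_eps=2.0 is an assumed noise level, chosen to match the simulation above
mean, variance = predict_gaussian_process_noisy(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4, sigma_eps=2.0)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)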

Non-Gaussian likelihoods#

We have seen that in the Normal-Normal setting the posterior distribution was available analytically.

For non-conjugate likelihood models this is not the case. When the likelihood is non-Gaussian, the resulting posterior distribution is no longer Gaussian, and obtaining a closed-form expression for the predictive distribution may be challenging or impossible.

In such cases, computational methods such as Markov chain Monte Carlo (MCMC) or variational inference are often used to approximate the posterior and predictive distributions. We will use Numpyro and its MCMC machinery for this purpose.
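As a sketch of what such a model could look like in Numpyro, here is a GP with a Poisson likelihood and a log link, reusing the rbf_kernel defined earlier. The choice of a Poisson observation model is purely illustrative and is not used in the rest of this chapter.

def gp_poisson_model(x, y=None, lengthscale=0.2, jitter=1e-5):
    # Latent Gaussian process at the input locations
    n = x.shape[0]
    K = rbf_kernel(x, x, lengthscale) + jitter * jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    # Non-Gaussian likelihood: counts with rate exp(f); the posterior of f is no longer Gaussian
    numpyro.sample("y", dist.Poisson(rate=jnp.exp(f)), obs=y)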

Draw GP priors using Numpyro functionality#

In the previous chapter we saw how to draw GP priors numerically. While we don't strictly need Numpyro for inference in the Gaussian-likelihood case, we will soon see that we do need it for other likelihoods.

Let us start building that code by drawing GP priors using Numpyro functionality.

def plot_gp_samples(x, samples, ttl="", num_samples=10):
    """Plot the first `num_samples` draws from a GP against the inputs x."""
    plt.figure(figsize=(6, 4))
    for i in range(num_samples):
        plt.plot(x, samples[i], label=f'Sample {i}')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title(ttl)
    plt.tight_layout()
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    plt.show()
def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5, noise=0.5):
    """
    Gaussian Process prior with a Gaussian likelihood as a Numpyro model.

    Args:
    - x (jax.numpy.ndarray): Input data points of shape (n, d), where n is the number of points and d is the number of dimensions.
    - y (jax.numpy.ndarray, optional): Observed outcomes; if None, the model can be used to draw prior samples.
    - kernel_func (function): Kernel function to use. Default is rbf_kernel.
    - lengthscale (float): Length scale parameter for the kernel function. Default is 0.2.
    - jitter (float): Small constant added to the diagonal of the kernel matrix for numerical stability. Default is 1e-5.
    - noise (float): Known standard deviation of the observation noise. Default is 0.5.
    """

    n = x.shape[0]

    # Kernel matrix with jitter on the diagonal for numerical stability
    K = kernel_func(x, x, lengthscale) + jitter * jnp.eye(n)

    # Latent GP values at the input locations
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    # Gaussian likelihood with known noise level
    numpyro.sample("y", dist.Normal(f, noise), obs=y)
# Define parameters
n_points = 50
num_samples = 1000

# Generate random input data
x = jnp.linspace(0, 1, n_points).reshape(-1, 1)

mcmc_predictive = Predictive(model, num_samples=num_samples)
samples = mcmc_predictive(jax.random.PRNGKey(0), x=x)
f_samples = samples['f']

# Calculate mean and standard deviation
mean = jnp.mean(f_samples, axis=0)
std = jnp.std(f_samples, axis=0)

# Visualization
plt.figure(figsize=(8, 6))
for i in range(10):
    plt.plot(x, f_samples[i], color='teal', lw=0.3)
plt.plot(x, mean, color='teal', label='Mean')  # Point-wise mean of all samples
plt.fill_between(x.squeeze(), mean - 1.96 * std, mean + 1.96 * std, color='teal', alpha=0.3, label='Uncertainty Bounds')  # Uncertainty bounds
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.title('GP samples obtained using Numpyro')    
plt.grid()
plt.legend()
plt.show()
[Figure: GP prior samples drawn with Numpyro, with point-wise mean and uncertainty bounds]

Inference with Numpyro: known noise#

For inference, we will use the data simulated in the previous step as input.

true_idx = 3

y_samples = samples['y']
f_true = f_samples[true_idx]
y_obs =  y_samples[true_idx]

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.grid()
plt.legend()
[Figure: simulated true function and the noisy observations used for inference]

Let us see whether we can recover \(f\) from observed \(y\) values. At this stage we assume that the amount of noise is known.

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# Print summary statistics of posterior
mcmc.print_summary()

# Get the posterior samples
posterior_samples = mcmc.get_samples()
f_posterior = posterior_samples['f']
# Calculate the posterior mean and 95% HPDI of f
f_mean = jnp.mean(f_posterior, axis=0)
f_hpdi = hpdi(f_posterior, 0.95)

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.plot(x, f_mean, color='teal', label='Mean Prediction')
plt.fill_between(x.squeeze(), f_hpdi[0], f_hpdi[1], color='teal', alpha=0.3, label='HPDI')  # Uncertainty bounds
plt.grid()
plt.legend()
plt.show()
[Figure: posterior mean and 95% HPDI of f, with the true function and noisy data]

This looks good. But in reality we rarely know the true level of noise.

Inference with Numpyro: estimating noise#

def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5):
    """
    Gaussian Process prior with a Gaussian likelihood and unknown noise level as a Numpyro model.

    Args:
    - x (jax.numpy.ndarray): Input data points of shape (n, d), where n is the number of points and d is the number of dimensions.
    - y (jax.numpy.ndarray, optional): Observed outcomes; if None, the model can be used to draw prior samples.
    - kernel_func (function): Kernel function to use. Default is rbf_kernel.
    - lengthscale (float): Length scale parameter for the kernel function. Default is 0.2.
    - jitter (float): Small constant added to the diagonal of the kernel matrix for numerical stability. Default is 1e-5.
    """

    n = x.shape[0]

    # Kernel matrix with jitter on the diagonal for numerical stability
    K = kernel_func(x, x, lengthscale) + jitter * jnp.eye(n)

    # Latent GP values at the input locations
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    # Prior on the unknown observation-noise standard deviation
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))

    # Gaussian likelihood with inferred noise level
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# Print summary statistics of posterior
mcmc.print_summary()

# Get the posterior samples
posterior_samples = mcmc.get_samples()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      f[0]      0.97      0.26      0.97      0.54      1.38   2375.59      1.00
      f[1]      1.09      0.22      1.09      0.74      1.46   2303.95      1.00
      f[2]      1.23      0.19      1.22      0.90      1.53   2235.81      1.00
      f[3]      1.37      0.18      1.37      1.08      1.65   2149.90      1.00
      f[4]      1.53      0.17      1.53      1.26      1.80   2043.35      1.00
      f[5]      1.70      0.16      1.70      1.42      1.95   1926.71      1.00
      f[6]      1.87      0.16      1.87      1.59      2.13   1803.14      1.00
      f[7]      2.05      0.16      2.05      1.79      2.31   1696.37      1.00
      f[8]      2.22      0.16      2.22      1.94      2.47   1624.57      1.00
      f[9]      2.37      0.16      2.37      2.10      2.63   1588.26      1.00
     f[10]      2.51      0.16      2.51      2.25      2.78   1585.35      1.00
     f[11]      2.63      0.16      2.63      2.36      2.89   1611.61      1.00
     f[12]      2.71      0.16      2.71      2.44      2.97   1662.09      1.00
     f[13]      2.76      0.16      2.76      2.51      3.04   1710.88      1.00
     f[14]      2.77      0.16      2.77      2.51      3.03   1751.52      1.00
     f[15]      2.74      0.15      2.74      2.49      3.00   1763.07      1.00
     f[16]      2.66      0.15      2.66      2.40      2.91   1751.03      1.00
     f[17]      2.54      0.15      2.55      2.30      2.80   1728.46      1.00
     f[18]      2.39      0.15      2.39      2.14      2.64   1702.39      1.00
     f[19]      2.20      0.15      2.20      1.96      2.45   1683.25      1.00
     f[20]      1.98      0.15      1.98      1.74      2.23   1670.92      1.00
     f[21]      1.73      0.15      1.73      1.48      1.98   1683.03      1.00
     f[22]      1.47      0.15      1.47      1.22      1.71   1712.11      1.00
     f[23]      1.21      0.15      1.21      0.96      1.46   1747.18      1.00
     f[24]      0.94      0.15      0.94      0.70      1.20   1767.20      1.00
     f[25]      0.68      0.15      0.68      0.44      0.93   1783.08      1.00
     f[26]      0.43      0.15      0.43      0.18      0.68   1796.17      1.00
     f[27]      0.20      0.15      0.20     -0.05      0.45   1809.50      1.00
     f[28]     -0.01      0.15     -0.01     -0.26      0.24   1825.29      1.00
     f[29]     -0.21      0.15     -0.20     -0.46      0.03   1843.03      1.00
     f[30]     -0.38      0.15     -0.38     -0.62     -0.13   1865.79      1.00
     f[31]     -0.53      0.15     -0.53     -0.77     -0.28   1888.21      1.00
     f[32]     -0.67      0.15     -0.67     -0.91     -0.42   1905.84      1.00
     f[33]     -0.79      0.15     -0.79     -1.05     -0.56   1913.41      1.00
     f[34]     -0.90      0.15     -0.90     -1.15     -0.66   1911.47      1.00
     f[35]     -0.99      0.15     -0.99     -1.25     -0.76   1900.04      1.00
     f[36]     -1.08      0.15     -1.08     -1.32     -0.83   1883.67      1.00
     f[37]     -1.16      0.15     -1.16     -1.41     -0.91   1869.40      1.00
     f[38]     -1.22      0.15     -1.22     -1.47     -0.97   1859.36      1.00
     f[39]     -1.27      0.15     -1.27     -1.53     -1.03   1858.70      1.00
     f[40]     -1.31      0.16     -1.31     -1.56     -1.05   1874.72      1.00
     f[41]     -1.33      0.16     -1.33     -1.59     -1.07   1896.21      1.00
     f[42]     -1.33      0.16     -1.33     -1.59     -1.07   1922.11      1.00
     f[43]     -1.32      0.16     -1.32     -1.59     -1.07   1950.52      1.00
     f[44]     -1.28      0.16     -1.28     -1.54     -1.01   1970.88      1.00
     f[45]     -1.21      0.17     -1.22     -1.49     -0.95   1993.58      1.00
     f[46]     -1.13      0.18     -1.13     -1.41     -0.83   2029.49      1.00
     f[47]     -1.03      0.20     -1.03     -1.36     -0.72   2096.23      1.00
     f[48]     -0.91      0.23     -0.91     -1.26     -0.52   2204.32      1.00
     f[49]     -0.77      0.26     -0.77     -1.23     -0.36   2320.96      1.00
     sigma      0.49      0.06      0.48      0.39      0.57   3292.89      1.00

Number of divergences: 0
sigma_posterior = posterior_samples['sigma']

plt.figure(figsize=(6, 4))
sns.kdeplot(sigma_posterior, fill=True)
plt.axvline(x=0.5, color='r', linestyle='--', label='True value')
plt.xlabel(r'$\sigma$')
plt.ylabel('Density')
plt.title(r'Posterior of parameter $\sigma$')

# Show the plot
plt.show()
[Figure: posterior density of sigma with the true value 0.5 marked]

We have inferred the noise parameter \(\sigma\) successfully. Estimating the lengthscale, especially for less smooth kernels, is harder. A further issue is the non-identifiability of the lengthscale-variance pair; hence, if both really need to be inferred, strong priors are beneficial, as sketched below.
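A sketch of how such priors could be added to the model above follows; the InverseGamma and LogNormal choices (and their parameters) are assumptions for illustration, not recommendations.

def model_with_hyperpriors(x, y=None, kernel_func=rbf_kernel, jitter=1e-5):
    n = x.shape[0]

    # Fairly informative priors on the kernel hyperparameters; without them the
    # lengthscale-amplitude pair is only weakly identified
    lengthscale = numpyro.sample("lengthscale", dist.InverseGamma(5.0, 1.0))
    kernel_sigma = numpyro.sample("kernel_sigma", dist.LogNormal(0.0, 0.5))

    K = kernel_func(x, x, lengthscale, kernel_sigma) + jitter * jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    # Prior on the observation-noise standard deviation, as before
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)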