Gaussian processes: inference#

In the previous chapter we saw how to generate Gaussian process priors. What we usually want to do, however, is inference: given observed data, we would like to estimate the model parameters.

In a typical setting, a GP enters a GLMM as a latent variable, and the model has the form

\[\begin{split} \begin{align*} f(x) &\sim \mathcal{GP}(\mu(x), k(x, x')), \\ y | f &\sim \Pi_i p(y_i \vert f(x_i)). \end{align*} \end{split}\]

Here \(\{(x_i, y_i)\}_{i=1}^n\) are pairs of observations \(y_i\), and locations of those observations \(x_i\). The role of \(f(x)\) now is to serve as a latent field capturing dependencies between locations \(x\). The expression \(\Pi_i p(y_i \vert f(x_i))\) provides a likelihood, allowing us to link observed data to the model and enabling parameter inference.

The task we usually want to solve is twofold: we want to infer the parameters involved in the model, as well as make predictions at unobserved locations \(x_*.\)

Gaussian process regression#

The simplest case of the setting described above is the Gaussian process regression where the outcome variable is modelled as a GP with added noise \(\epsilon\). It assumes that the data consists of pairs \(\{(x_i, y_i)\}_{i=1}^n\) and the likelihood is Gaussian with variance \(\sigma^2_\epsilon:\)

\[\begin{split} \begin{align*} f(x) &\sim \mathcal{GP}(\mu(x), k(x, x')),\\ y_i &= f(x_i) + \epsilon_i, \\ \epsilon_i &\sim \mathcal{N}(0, \sigma^2_\epsilon). \end{align*} \end{split}\]

We want to obtain estimates \(f_m\) at \(m\) unobserved locations \(x_*\). Recall the conditioning formula: if

\[\begin{split} \begin{align*} \begin{bmatrix} f_n \\ f_m \end{bmatrix} \sim \mathcal{N}\left(\begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} K_{nn} & K_{nm} \\ K_{mn} & K_{mm} \end{bmatrix} \right), \end{align*} \end{split}\]

then

\[\begin{split} \begin{align*} f_m | f_n &\sim \mathcal{N}(\mu_{m|n}, K_{m|n}),\\ \mu_{m|n} & = K_{mn} K_{nn}^{-1} f_n,\\ K_{m|n} & = K_{mm} - K_{mn} K_{nn}^{-1} K_{nm}. \end{align*} \end{split}\]

There is an issue with this formula: due to the noise we cannot observe \(f_n\) directly; we observe \(y_n\) instead. The conditional in this case reads

\[\begin{split} \begin{align*} f_m | y_n &\sim \mathcal{N}(\mu_{m|n}, K_{m|n}),\\ \mu_{m|n} & = K_{mn} (K_{nn} + \sigma^2_\epsilon I )^{-1} y_n,\\ K_{m|n} & = K_{mm} - K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} K_{nm}. \end{align*} \end{split}\]

This is the predictive distribution of the Gaussian process.

Putting the two distributions side-by-side:

  • posterior distribution in the noise-free case

\[ f_m | f_n \sim \mathcal{N}( K_{mn} K_{nn}^{-1} f_n, K_{mm} - K_{mn} K_{nn} ^{-1} K_{nm}) \]
  • posterior distribution in the noisy case

\[ f_m | y_n \sim \mathcal{N}(K_{mn} (K_{nn} + I\sigma^2_\epsilon )^{-1} y_n, K_{mm} - K_{mn} (K_{nn} + I\sigma^2_\epsilon )^{-1} K_{nm}) \]

Parameters \(\theta\) of the model can be learnt by maximising the marginal log-likelihood, which in the noisy case takes the form

\[ \log p(y|\theta)= -\frac{n}{2} \log(2\pi) - \frac{1}{2}\log \vert K_{nn} + \sigma^2_\epsilon I \vert - \frac{1}{2}y^T(K_{nn} + \sigma^2_\epsilon I)^{-1}y \]
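As a quick illustration, here is a minimal sketch of this expression in JAX. It assumes the kernel matrix \(K_{nn}\) has already been evaluated at the observed inputs; the helper name `log_marginal_likelihood` is ours and not part of this chapter's code.

import jax.numpy as jnp

def log_marginal_likelihood(K_nn, y, noise_var):
    """Gaussian log marginal likelihood log p(y | theta) for GP regression (sketch)."""
    n = y.shape[0]
    A = K_nn + noise_var * jnp.eye(n)              # K_nn + sigma_eps^2 I
    _, log_det = jnp.linalg.slogdet(A)             # log |K_nn + sigma_eps^2 I|
    quad = y @ jnp.linalg.solve(A, y)              # y^T (K_nn + sigma_eps^2 I)^{-1} y
    return -0.5 * (n * jnp.log(2.0 * jnp.pi) + log_det + quad)

Maximising this quantity with respect to \(\theta\) (for instance with `jax.grad` and any gradient-based optimiser) is the standard way to fit the lengthscale, signal variance and noise variance in GP regression.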

Computational complexity#

Note that the predictive distribution involves the inversion of an \(n \times n\) matrix. This computation has cubic complexity, \(O(n^3)\), in the number of points \(n\) and creates a computational bottleneck for GP inference.
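In practice the inverse is rarely formed explicitly; instead one factorises \(K_{nn} + \sigma^2_\epsilon I\) once (for example with a Cholesky decomposition, still \(O(n^3)\) but numerically more stable) and reuses the factor for the mean, the covariance and the log-determinant. A minimal sketch, with names of our own choosing:

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def predictive_mean_chol(K_nn, K_mn, y, noise_var):
    """Compute K_mn (K_nn + sigma_eps^2 I)^{-1} y without an explicit matrix inverse."""
    A = K_nn + noise_var * jnp.eye(K_nn.shape[0])
    factor = cho_factor(A)            # one O(n^3) factorisation
    alpha = cho_solve(factor, y)      # cheap triangular solves afterwards
    return K_mn @ alpha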

# imports for this chapter
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro.infer import init_to_median, Predictive, MCMC, NUTS
from numpyro.diagnostics import hpdi

import pickle

numpyro.set_host_device_count(4)  # Set the device count to enable parallel sampling

Let us implement the (noise-free) predictive distribution.

def rbf_kernel(x1, x2, lengthscale=1.0, sigma=1.0):
    """
    Compute the Radial Basis Function (RBF) kernel matrix between two sets of points.

    Args:
    - x1 (array): Array of shape (n1, d) representing the first set of points.
    - x2 (array): Array of shape (n2, d) representing the second set of points.
    - lengthscale (float): Length-scale parameter.
    - sigma (float): Marginal standard deviation, so the kernel variance is sigma**2.

    Returns:
    - K (array): Kernel matrix of shape (n1, n2).
    """
    # Pairwise squared Euclidean distances between the rows of x1 and x2
    sq_dist = jnp.sum(x1**2, axis=1).reshape(-1, 1) + jnp.sum(x2**2, axis=1) - 2 * jnp.dot(x1, x2.T)
    K = sigma**2 * jnp.exp(-0.5 / lengthscale**2 * sq_dist)
    return K
def plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true=False):
    """
    Plots the Gaussian process predictive distribution.

    Args:
    - x_obs: Training inputs, shape (n_train_samples, n_features).
    - y_obs: Training targets, shape (n_train_samples,).
    - x_pred: Test input points, shape (n_test_samples, n_features).
    - mean: Mean of the predictive distribution, shape (n_test_samples,).
    - variance: Variance of the predictive distribution, shape (n_test_samples,).
    - f_true: Optional true function values at x_pred, shape (n_test_samples,).
    """
    plt.figure(figsize=(8, 6))
    if f_true is not False:
        plt.plot(x_pred, f_true, label='True Function', color='purple')
    plt.scatter(x_obs, y_obs, c='orangered', label='Training Data')
    plt.plot(x_pred, mean, label='Mean Prediction', color='teal')
    plt.fill_between(x_pred.squeeze(), mean - jnp.sqrt(variance), mean + jnp.sqrt(variance), color='teal', alpha=0.3, label='Uncertainty')
    plt.title('Gaussian Process Predictive Distribution')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    plt.show()
def predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=1.0, sigma=1.0, jitter=1e-8):
    """
    Predicts the mean and variance of the (noise-free) Gaussian process at test points.

    Args:
    - x_obs: Training inputs, shape (n_train_samples, n_features).
    - y_obs: Training targets, shape (n_train_samples,).
    - x_pred: Test input points, shape (n_test_samples, n_features).
    - length_scale: Length-scale parameter of the RBF kernel.
    - sigma: Marginal standard deviation of the RBF kernel.
    - jitter: Small positive value added to the diagonal for numerical stability.

    Returns:
    - mean: Mean of the predictive distribution, shape (n_test_samples,).
    - variance: Variance of the predictive distribution, shape (n_test_samples,).
    """
    # K_nn with jitter on the diagonal (noise-free conditioning)
    K = rbf_kernel(x_obs, x_obs, length_scale, sigma) + jitter * jnp.eye(len(x_obs))
    K_inv = jnp.linalg.inv(K)
    # K_nm and K_mm
    K_star = rbf_kernel(x_obs, x_pred, length_scale, sigma)
    K_star_star = rbf_kernel(x_pred, x_pred, length_scale, sigma)

    # Predictive mean K_mn K_nn^{-1} y and marginal variances diag(K_mm - K_mn K_nn^{-1} K_nm)
    mean = jnp.dot(K_star.T, jnp.dot(K_inv, y_obs))
    variance = jnp.diag(K_star_star) - jnp.sum(jnp.dot(K_star.T, K_inv) * K_star.T, axis=1)

    return mean, variance
# Example usage:
x_obs = jnp.array([[1], [2], [3]])  # Training inputs
y_obs = jnp.array([3, 4, 2])         # Training targets
x_pred = jnp.linspace(0, 5, 100).reshape(-1, 1)  # Test input points

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred)
plot_gp(x_obs, y_obs, x_pred, mean, variance)
[Figure: Gaussian process predictive distribution for the three noise-free training points]

This seems to work well. Let’s take a look at some noisy data.

# Define the true pattern
def f(x):
    return (6 * x/5 - 2)**2 * jnp.sin(12 * x/5 - 4)

# Define the range of x values
x = jnp.linspace(0, 5, 100)

# Generate the true function values
f_true = f(x)

# Generate noise with a normal distribution
noise = 2*jax.random.normal(jax.random.PRNGKey(0), shape=f_true.shape)

# Add noise to the true function values to obtain the observed values (y)
y = f_true + noise

# Indices of observed data points
idx_obs = jnp.array([[10], [50], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x_obs, y_obs, color='orangered', label='Training Data')
plt.grid()
plt.legend()
[Figure: the true function and the three noisy training points]
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
[Figure: GP predictive distribution fitted to the three noisy observations]

Let’s add more points and repeat the procedure.

# Indices of observed data points
idx_obs = jnp.array([[10], [12],[20], [68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
[Figure: GP predictive distribution with seven noisy observations]
# Indices of observed data points
idx_obs = jnp.array([[5], [10], [12],[20], [25], [41], [55],[68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)

mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
[Figure: GP predictive distribution with eleven noisy observations]

As the number of observed data points increases, we observe two phenomena. First, the uncertainty bounds shrink (as expected); however, rather than generalising, our model starts passing through every point. This is because `predict_gaussian_process` implements the noise-free predictive formula, with only a tiny jitter on the diagonal, so the observations are treated as exact; the noisy formula fixes this, as shown in the sketch below.
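Here is a minimal sketch of the noisy predictive distribution, reusing `rbf_kernel` and `plot_gp` from above. The function name `predict_gp_noisy` and its `noise_var` argument are our own, and `noise_var=4.0` in the example call reflects the noise standard deviation of 2 used to simulate the data.

def predict_gp_noisy(x_obs, y_obs, x_pred, lengthscale=1.0, sigma=1.0, noise_var=1.0):
    """Noisy GP predictive: condition on y, with the noise variance added to the diagonal of K_nn."""
    K = rbf_kernel(x_obs, x_obs, lengthscale, sigma) + noise_var * jnp.eye(len(x_obs))
    K_star = rbf_kernel(x_obs, x_pred, lengthscale, sigma)        # K_nm
    K_star_star = rbf_kernel(x_pred, x_pred, lengthscale, sigma)  # K_mm

    alpha = jnp.linalg.solve(K, y_obs)               # (K_nn + sigma_eps^2 I)^{-1} y
    mean = K_star.T @ alpha
    v = jnp.linalg.solve(K, K_star)                  # (K_nn + sigma_eps^2 I)^{-1} K_nm
    variance = jnp.diag(K_star_star) - jnp.sum(K_star * v, axis=0)
    return mean, variance

mean, variance = predict_gp_noisy(x_obs, y_obs, x_pred, lengthscale=0.5, sigma=4, noise_var=4.0)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)

With the noise term included, the predictive mean no longer has to pass through every observation.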

Non-Gaussian likelihoods#

We have seen that in the Normal-Normal setting the posterior distribution is available in closed form.

For non-conjugate likelihood models this is not the case. When the likelihood is non-Gaussian, the resulting posterior distribution is no longer Gaussian, and obtaining a closed-form expression for the predictive distribution may be challenging or impossible.

In such cases, computational methods such as Markov Chain Monte Carlo or Variational Inference are often used to approximate the predictive distribution. We will use Numpyro and its MCMC engine for this purpose.
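To make this concrete, here is a hedged sketch of what such a model could look like in Numpyro for count data, with a Poisson likelihood and a log link. The model below is purely illustrative (we do not fit it in this chapter) and reuses the `rbf_kernel` defined above.

def gp_poisson_model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5):
    """Latent GP with a Poisson observation model (log link) -- an illustrative sketch."""
    n = x.shape[0]
    K = kernel_func(x, x, lengthscale) + jitter * jnp.eye(n)

    # Latent field f ~ GP(0, K)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    # Non-Gaussian likelihood: y_i ~ Poisson(exp(f(x_i)))
    numpyro.sample("y", dist.Poisson(jnp.exp(f)), obs=y)

Only the likelihood line differs from the Gaussian models below; the MCMC machinery (NUTS) is used in exactly the same way.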

Draw GP priors using Numpyro functionality#

In the previous chapter we saw how to draw GP priors numerically. While we don’t strictly need Numpyro for inference with a Gaussian likelihood, we will soon need it for other likelihoods.

Let us start building that code by drawing GP priors using Numpyro functionality.

def plot_gp_samples(x, samples, ttl="", num_samples=10):
    """Plot a handful of GP prior samples."""
    plt.figure(figsize=(6, 4))
    for i in range(num_samples):
        plt.plot(x, samples[i], label=f'Sample {i}')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title(ttl)
    plt.tight_layout()
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    plt.show()
def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5, noise=0.5):    
    """
    Gaussian Process regression model with known noise, written as a Numpyro model.

    Args:
    - x (jax.numpy.ndarray): Input data points of shape (n, d), where n is the number of points and d is the number of dimensions.
    - y (jax.numpy.ndarray, optional): Observed targets of shape (n,); if None, the model can be used to draw prior samples.
    - kernel_func (function): Kernel function to use. Default is rbf_kernel.
    - lengthscale (float): Length scale parameter for the kernel function. Default is 0.2.
    - jitter (float): Small constant added to the diagonal of the kernel matrix for numerical stability. Default is 1e-5.
    - noise (float): Known observation noise standard deviation. Default is 0.5.
    """

    n = x.shape[0]

    # Kernel matrix of the latent GP, with jitter for numerical stability
    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)

    # Latent function values f ~ GP(0, K)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    # Gaussian likelihood with known noise standard deviation
    numpyro.sample("y", dist.Normal(f, noise), obs=y)
# Define parameters
n_points = 50
num_samples = 1000

# Generate evenly spaced input locations
x = jnp.linspace(0, 1, n_points).reshape(-1, 1)

mcmc_predictive = Predictive(model, num_samples=num_samples)
samples = mcmc_predictive(jax.random.PRNGKey(0), x=x)
f_samples = samples['f']

# Calculate mean and standard deviation
mean = jnp.mean(f_samples, axis=0)
std = jnp.std(f_samples, axis=0)

# Visualization
plt.figure(figsize=(8, 6))
for i in range(10):
    plt.plot(x, f_samples[i], color='teal', lw=0.3)
plt.plot(x, mean, color='teal', label='Mean')  # Point-wise mean of all samples
plt.fill_between(x.squeeze(), mean - 1.96 * std, mean + 1.96 * std, color='teal', alpha=0.3, label='Uncertainty Bounds')  # Uncertainty bounds
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.title('GP samples obtained using Numpyro')    
plt.grid()
plt.legend()
plt.show()
[Figure: GP prior samples drawn with Numpyro, with the point-wise mean and uncertainty bounds]

Inference with Numpyro: known noise#

For inference, we will use the data simulated in the previous step as input.

true_idx = 3

y_samples = samples['y']
f_true = f_samples[true_idx]
y_obs =  y_samples[true_idx]

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.grid()
plt.legend()
[Figure: the simulated latent function and the noisy observations used for inference]

Let us see whether we can recover \(f\) from observed \(y\) values. At this stage we assume that the amount of noise is known.

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# Print summary statistics of posterior
mcmc.print_summary()

# Get the posterior samples
posterior_samples = mcmc.get_samples()
f_posterior = posterior_samples['f']
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      f[0]      0.98      0.26      0.97      0.56      1.42   2118.70      1.00
      f[1]      1.10      0.22      1.09      0.73      1.47   1892.73      1.00
      f[2]      1.23      0.20      1.23      0.90      1.54   1738.63      1.00
      f[3]      1.38      0.18      1.38      1.09      1.67   1690.52      1.00
      f[4]      1.54      0.17      1.54      1.26      1.81   1697.34      1.00
      f[5]      1.70      0.16      1.71      1.43      1.97   1695.92      1.00
      f[6]      1.88      0.16      1.88      1.61      2.13   1765.56      1.00
      f[7]      2.05      0.16      2.05      1.78      2.31   1742.35      1.00
      f[8]      2.21      0.16      2.21      1.95      2.47   1720.36      1.00
      f[9]      2.37      0.16      2.37      2.11      2.63   1688.84      1.00
     f[10]      2.51      0.16      2.50      2.25      2.76   1656.78      1.00
     f[11]      2.62      0.16      2.62      2.36      2.87   1631.58      1.00
     f[12]      2.70      0.16      2.70      2.44      2.95   1616.84      1.00
     f[13]      2.75      0.16      2.75      2.48      3.00   1608.86      1.00
     f[14]      2.76      0.16      2.76      2.51      3.04   1605.41      1.00
     f[15]      2.73      0.16      2.73      2.47      3.00   1607.43      1.00
     f[16]      2.65      0.16      2.65      2.40      2.93   1608.63      1.00
     f[17]      2.54      0.16      2.54      2.29      2.81   1607.32      1.00
     f[18]      2.38      0.16      2.38      2.14      2.66   1604.98      1.00
     f[19]      2.19      0.16      2.19      1.93      2.45   1599.29      1.00
     f[20]      1.98      0.16      1.97      1.71      2.23   1589.33      1.00
     f[21]      1.73      0.16      1.73      1.47      1.99   1576.19      1.00
     f[22]      1.47      0.16      1.48      1.22      1.73   1564.14      1.00
     f[23]      1.21      0.16      1.21      0.95      1.46   1540.29      1.00
     f[24]      0.94      0.16      0.94      0.68      1.20   1477.81      1.00
     f[25]      0.68      0.16      0.68      0.42      0.94   1423.51      1.00
     f[26]      0.43      0.16      0.43      0.14      0.67   1392.02      1.00
     f[27]      0.20      0.16      0.20     -0.08      0.44   1390.13      1.00
     f[28]     -0.02      0.16     -0.02     -0.28      0.23   1410.04      1.00
     f[29]     -0.22      0.16     -0.22     -0.48      0.04   1446.05      1.00
     f[30]     -0.39      0.16     -0.39     -0.65     -0.14   1471.46      1.00
     f[31]     -0.55      0.16     -0.55     -0.81     -0.30   1486.64      1.00
     f[32]     -0.68      0.16     -0.68     -0.94     -0.43   1491.98      1.00
     f[33]     -0.81      0.16     -0.81     -1.07     -0.56   1496.58      1.00
     f[34]     -0.91      0.16     -0.91     -1.16     -0.64   1500.14      1.00
     f[35]     -1.01      0.16     -1.01     -1.26     -0.75   1499.52      1.00
     f[36]     -1.09      0.16     -1.09     -1.36     -0.84   1495.56      1.00
     f[37]     -1.16      0.16     -1.16     -1.44     -0.91   1497.25      1.00
     f[38]     -1.22      0.16     -1.22     -1.49     -0.96   1506.54      1.00
     f[39]     -1.27      0.16     -1.27     -1.53     -1.00   1521.17      1.00
     f[40]     -1.31      0.16     -1.31     -1.57     -1.03   1542.43      1.00
     f[41]     -1.33      0.16     -1.33     -1.61     -1.07   1564.42      1.00
     f[42]     -1.33      0.16     -1.33     -1.60     -1.07   1592.09      1.00
     f[43]     -1.31      0.16     -1.31     -1.57     -1.04   1617.14      1.00
     f[44]     -1.27      0.16     -1.27     -1.53     -1.00   1652.06      1.00
     f[45]     -1.21      0.17     -1.21     -1.48     -0.94   1695.73      1.00
     f[46]     -1.13      0.18     -1.13     -1.43     -0.85   1768.17      1.00
     f[47]     -1.03      0.20     -1.02     -1.35     -0.71   1902.37      1.00
     f[48]     -0.92      0.22     -0.91     -1.29     -0.56   2084.98      1.00
     f[49]     -0.79      0.26     -0.79     -1.24     -0.39   2299.99      1.00

Number of divergences: 0
# Calculate mean and standard deviation
f_mean = jnp.mean(f_posterior, axis=0)
f_hpdi = hpdi(f_posterior, 0.95)


plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.plot(x, f_mean, color='teal', label='Mean Prediction')
plt.fill_between(x.squeeze(), f_hpdi[0], f_hpdi[1], color='teal', alpha=0.3, label='HPDI')  # Uncertainty bounds
plt.grid()
plt.legend()
plt.show()
[Figure: posterior mean and 95% HPDI of f, together with the true function and the noisy data]

This looks good. But in reality we hardly ever know the true level of noise.

Inference with Numpyro: estimating noise#

def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5):    
    """
    Gaussian Process regression model with unknown noise, written as a Numpyro model.

    Args:
    - x (jax.numpy.ndarray): Input data points of shape (n, d), where n is the number of points and d is the number of dimensions.
    - y (jax.numpy.ndarray, optional): Observed targets of shape (n,); if None, the model can be used to draw prior samples.
    - kernel_func (function): Kernel function to use. Default is rbf_kernel.
    - lengthscale (float): Length scale parameter for the kernel function. Default is 0.2.
    - jitter (float): Small constant added to the diagonal of the kernel matrix for numerical stability. Default is 1e-5.
    """

    n = x.shape[0]

    # Kernel matrix of the latent GP, with jitter for numerical stability
    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)

    # Latent function values f ~ GP(0, K)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))

    # Unknown observation noise standard deviation with a HalfCauchy prior
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
    
    # Gaussian likelihood
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)
# Print summary statistics of posterior
mcmc.print_summary()

# Get the posterior samples
posterior_samples = mcmc.get_samples()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      f[0]      0.97      0.26      0.98      0.55      1.39   2647.40      1.00
      f[1]      1.09      0.22      1.09      0.74      1.46   2467.76      1.00
      f[2]      1.23      0.19      1.22      0.91      1.54   2292.04      1.00
      f[3]      1.37      0.17      1.37      1.09      1.65   2170.52      1.00
      f[4]      1.53      0.16      1.53      1.25      1.79   2105.90      1.00
      f[5]      1.70      0.16      1.70      1.44      1.97   2062.95      1.00
      f[6]      1.87      0.16      1.87      1.62      2.13   2007.87      1.00
      f[7]      2.05      0.16      2.04      1.79      2.31   1955.28      1.00
      f[8]      2.21      0.16      2.21      1.95      2.47   1913.67      1.00
      f[9]      2.37      0.16      2.37      2.13      2.65   1880.37      1.00
     f[10]      2.51      0.16      2.51      2.26      2.77   1853.88      1.00
     f[11]      2.63      0.16      2.63      2.38      2.90   1839.64      1.00
     f[12]      2.71      0.16      2.71      2.47      2.98   1839.25      1.00
     f[13]      2.76      0.16      2.77      2.50      3.01   1850.16      1.00
     f[14]      2.77      0.16      2.78      2.52      3.03   1876.81      1.00
     f[15]      2.74      0.16      2.74      2.50      3.01   1908.50      1.00
     f[16]      2.67      0.16      2.67      2.43      2.93   1940.64      1.00
     f[17]      2.55      0.15      2.55      2.30      2.80   1972.30      1.00
     f[18]      2.39      0.15      2.39      2.15      2.65   1992.71      1.00
     f[19]      2.20      0.15      2.20      1.94      2.44   1996.15      1.00
     f[20]      1.98      0.15      1.98      1.74      2.23   1981.78      1.00
     f[21]      1.73      0.15      1.73      1.48      1.97   1944.90      1.00
     f[22]      1.47      0.15      1.47      1.23      1.72   1891.49      1.00
     f[23]      1.20      0.15      1.20      0.97      1.46   1831.85      1.00
     f[24]      0.94      0.15      0.93      0.69      1.18   1776.07      1.00
     f[25]      0.67      0.15      0.67      0.44      0.93   1739.13      1.00
     f[26]      0.43      0.15      0.42      0.18      0.68   1717.66      1.00
     f[27]      0.19      0.15      0.19     -0.06      0.43   1714.01      1.00
     f[28]     -0.02      0.15     -0.03     -0.27      0.23   1693.82      1.00
     f[29]     -0.21      0.15     -0.22     -0.47      0.04   1701.09      1.00
     f[30]     -0.39      0.15     -0.39     -0.64     -0.13   1715.85      1.00
     f[31]     -0.54      0.15     -0.54     -0.80     -0.29   1710.70      1.00
     f[32]     -0.68      0.16     -0.68     -0.94     -0.43   1709.09      1.00
     f[33]     -0.80      0.16     -0.80     -1.05     -0.55   1723.38      1.00
     f[34]     -0.90      0.15     -0.91     -1.15     -0.65   1755.56      1.00
     f[35]     -1.00      0.15     -1.00     -1.25     -0.75   1765.64      1.00
     f[36]     -1.09      0.15     -1.09     -1.34     -0.83   1758.27      1.00
     f[37]     -1.16      0.15     -1.16     -1.42     -0.91   1717.20      1.00
     f[38]     -1.23      0.16     -1.23     -1.48     -0.97   1689.26      1.00
     f[39]     -1.28      0.16     -1.28     -1.54     -1.03   1893.12      1.00
     f[40]     -1.32      0.16     -1.32     -1.57     -1.05   1988.40      1.00
     f[41]     -1.34      0.16     -1.34     -1.59     -1.07   2056.68      1.00
     f[42]     -1.34      0.16     -1.34     -1.61     -1.09   2135.72      1.00
     f[43]     -1.33      0.16     -1.33     -1.60     -1.08   2207.05      1.00
     f[44]     -1.29      0.16     -1.29     -1.54     -1.01   2240.76      1.00
     f[45]     -1.23      0.16     -1.23     -1.49     -0.96   2262.08      1.00
     f[46]     -1.14      0.17     -1.14     -1.42     -0.86   2300.85      1.00
     f[47]     -1.04      0.19     -1.04     -1.35     -0.73   2395.23      1.00
     f[48]     -0.92      0.22     -0.92     -1.27     -0.57   2516.00      1.00
     f[49]     -0.78      0.26     -0.78     -1.18     -0.35   2675.33      1.00
     sigma      0.48      0.05      0.48      0.40      0.58   3694.62      1.00

Number of divergences: 0
sigma_posterior = posterior_samples['sigma']

plt.figure(figsize=(6, 4))
sns.kdeplot(sigma_posterior, fill=True)
plt.axvline(x=0.5, color='r', linestyle='--', label='True value')
plt.xlabel(r'$\sigma$')
plt.ylabel('Density')
plt.title(r'Posterior of parameter $\sigma$')
plt.legend()

# Show the plot
plt.show()
[Figure: posterior density of sigma, with the true value 0.5 marked]

We have successfully inferred the noise parameter \(\sigma\). Estimating the lengthscale, especially for less smooth kernels, is harder. A further issue is the non-identifiability of the lengthscale-variance pair; hence, if both really need to be inferred, strong priors are beneficial.
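One way to encode such prior knowledge is to place (reasonably informative) priors on the lengthscale and the kernel variance directly in the Numpyro model. Below is a minimal sketch; the choice of LogNormal priors and their parameters is ours, for illustration only, and should be adapted to the scales of the data.

def gp_model_with_hyperpriors(x, y=None, kernel_func=rbf_kernel, jitter=1e-5):
    """GP regression with priors on the kernel hyperparameters -- an illustrative sketch."""
    n = x.shape[0]

    # Priors on the kernel hyperparameters; make these tighter if lengthscale and
    # variance are weakly identified by the data.
    lengthscale = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    kernel_sigma = numpyro.sample("kernel_sigma", dist.LogNormal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1.0))

    K = kernel_func(x, x, lengthscale, kernel_sigma) + jitter * jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)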