Gaussian processes: inference
In the previous chapter we saw how to generate Gaussian process priors. Usually, however, what we want to do is inference: given observed data, we would like to estimate model parameters and, potentially, make predictions at unobserved locations.
Marginalization and Conditioning
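Assume we have two sets of coordinates \(x_n = (x_n^1, x_n^2, \dots, x_n^n)\) and \(x_m = (x_m^1, x_m^2, \dots, x_m^m)\), and a GP \(f\) over them. Then the function values \(f_n = f(x_n)\) and \(f_m = f(x_m)\) are jointly Gaussian:
\[
\begin{pmatrix} f_n \\ f_m \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \mu_n \\ \mu_m \end{pmatrix}, \begin{pmatrix} K_{nn} & K_{nm} \\ K_{mn} & K_{mm} \end{pmatrix} \right),
\]
where \(\mu_n = \mu(x_n)\) and \(K_{nm} = k(x_n, x_m)\) denote the GP mean and covariance functions evaluated at the corresponding coordinates.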
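Marginalization allows us to extract partial information from a multivariate probability distribution. For a Gaussian it amounts to dropping the other components, and the result is again Gaussian:
\[
f_n \sim \mathcal{N}(\mu_n, K_{nn}), \qquad f_m \sim \mathcal{N}(\mu_m, K_{mm}).
\]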
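Conditioning allows us to determine the probability of one subset of variables given another subset. Like marginalization, this operation is closed: it yields another Gaussian distribution,
\[
f_m \mid f_n \sim \mathcal{N}\left(\mu_{m \mid n},\, K_{m \mid n}\right),
\]
where
\[
\mu_{m \mid n} = \mu_m + K_{mn} K_{nn}^{-1} (f_n - \mu_n), \qquad K_{m \mid n} = K_{mm} - K_{mn} K_{nn}^{-1} K_{nm}.
\]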
Group Task
Write down these formulas for the two-dimensional case (\(d = n + m = 2\)), i.e. when both \(f_n\) and \(f_m\) have only one component each.
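The marginalization and conditioning formulas can be sketched for an arbitrary split of a Gaussian vector in a few lines of JAX. The helpers `marginalize_gaussian` and `condition_gaussian` are hypothetical names introduced for illustration, not part of JAX or Numpyro; the conditioning step uses a linear solve with \(K_{nn}\) instead of an explicit inverse. The usage lines apply them to a two-dimensional example with the same parameters as the plot in this section.

```python
import jax.numpy as jnp

def marginalize_gaussian(mean, cov, idx):
    """Marginal of the components `idx`: keep the matching mean entries and covariance block."""
    idx = jnp.asarray(idx)
    return mean[idx], cov[jnp.ix_(idx, idx)]

def condition_gaussian(mean, cov, idx_m, idx_n, f_n):
    """Mean and covariance of f_m | f_n for a jointly Gaussian vector (hypothetical helper)."""
    idx_m, idx_n = jnp.asarray(idx_m), jnp.asarray(idx_n)
    K_mm = cov[jnp.ix_(idx_m, idx_m)]
    K_mn = cov[jnp.ix_(idx_m, idx_n)]
    K_nn = cov[jnp.ix_(idx_n, idx_n)]
    # A = K_mn K_nn^{-1}, obtained by solving K_nn A^T = K_nm (K_nn is symmetric)
    A = jnp.linalg.solve(K_nn, K_mn.T).T
    cond_mean = mean[idx_m] + A @ (f_n - mean[idx_n])
    cond_cov = K_mm - A @ K_mn.T
    return cond_mean, cond_cov

# usage: 2D Gaussian, condition the second component on the first being -0.4
mean = jnp.array([0.7, 1.6])
cov = jnp.array([[1.8**2, 0.529 * 1.8 * 1.7],
                 [0.529 * 1.8 * 1.7, 1.7**2]])
print(marginalize_gaussian(mean, cov, [1]))
print(condition_gaussian(mean, cov, [1], [0], jnp.array([-0.4])))
```

In the scalar case the conditional variance reduces to \(\sigma_2^2 (1 - \rho^2)\), which is a convenient check for the group task.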
Let us visualise such a conditional.
```python
# imports for this chapter
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import numpyro
import numpyro.distributions as dist
from numpyro.infer import init_to_median, Predictive, MCMC, NUTS
from numpyro.diagnostics import hpdi
import pickle

# set the device count to enable parallel sampling (must run before any JAX computation)
numpyro.set_host_device_count(4)
```
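```python
def Gaussian_conditional(mean, cov, x=None, y=None):
    """
    conditional of a bivariate Gaussian with the given mean and 2x2 covariance
    pass exactly one of:
    - x: observed value of the first component -> returns mean and std of the second component
    - y: observed value of the second component -> returns mean and std of the first component
    """
    assert (x is None) != (y is None), "pass exactly one of x and y"
    if x is not None:
        var = cov[1, 1] - cov[1, 0] * cov[0, 0] ** (-1) * cov[0, 1]
        mu = mean[1] + cov[1, 0] * cov[0, 0] ** (-1) * (x - mean[0])
    else:
        var = cov[0, 0] - cov[0, 1] * cov[1, 1] ** (-1) * cov[1, 0]
        mu = mean[0] + cov[0, 1] * cov[1, 1] ** (-1) * (y - mean[1])
    return mu, var**0.5
```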
```python
# parameters for the 2D Gaussian distribution
mu1 = 0.7
mu2 = 1.6
sigma1 = 1.8
sigma2 = 1.7
rho = 0.529
# values at which we condition: x1 = y1 and x2 = y2
y1 = -0.4
y2 = 2.2

# generate data points from the 2D Gaussian distribution
mu = jnp.array([mu1, mu2])
K = jnp.array([[sigma1**2, rho*sigma1*sigma2], [rho*sigma2*sigma1, sigma2**2]])
num_samples = 1000
data = np.random.multivariate_normal(mu, K, num_samples)

# calculate marginal distributions
y_values = jnp.linspace(-8, 8, 300)
normal1 = dist.Normal(loc=mu1, scale=sigma1)
normal2 = dist.Normal(loc=mu2, scale=sigma2)
density_x1 = jnp.exp(normal1.log_prob(y_values))
density_x2 = jnp.exp(normal2.log_prob(y_values))

# compute conditionals: x1 | x2 = y2 and x2 | x1 = y1
cond_mu_x1, cond_sigma_x1 = Gaussian_conditional(mu, K, x=None, y=y2)
cond_mu_x2, cond_sigma_x2 = Gaussian_conditional(mu, K, x=y1, y=None)
cond_density_x1 = jnp.exp(dist.Normal(loc=cond_mu_x1, scale=cond_sigma_x1).log_prob(y_values))
cond_density_x2 = jnp.exp(dist.Normal(loc=cond_mu_x2, scale=cond_sigma_x2).log_prob(y_values))
```
```python
fig = plt.figure(figsize=(10, 10))
gs = fig.add_gridspec(3, 3)

# main plot (2D Gaussian distribution)
ax_main = fig.add_subplot(gs[1:3, :2])
ax_main.scatter(data[:, 0], data[:, 1], alpha=0.5, label='2d Gaussian')
ax_main.set_xlabel('$x_1$')
ax_main.set_ylabel('$x_2$')
ax_main.axvline(y1, lw=2, c='C2', linestyle='--')  # x1 = y1
ax_main.axhline(y2, lw=2, c='C1', linestyle='--')  # x2 = y2
ax_main.legend()
ax_main.grid(True)

# marginal and conditional of x1 (top panel)
ax_marginal_x = fig.add_subplot(gs[0, :2], sharex=ax_main)
ax_marginal_x.plot(y_values, density_x1, label='Marginal $x_1$')
ax_marginal_x.plot(y_values, cond_density_x1, label='Conditional $x_1 | x_2 = y_2$', linestyle='--', c='C1')
ax_marginal_x.legend()
ax_marginal_x.grid(True)

# marginal and conditional of x2 (right panel)
ax_marginal_y = fig.add_subplot(gs[1:3, 2], sharey=ax_main)
ax_marginal_y.plot(density_x2, y_values, label='Marginal $x_2$')
ax_marginal_y.plot(cond_density_x2, y_values, label='Conditional $x_2 | x_1 = y_1$', linestyle='--', c='C2')
ax_marginal_y.legend()
ax_marginal_y.grid(True)

plt.tight_layout()
plt.show()
```
Inference in GLMMs
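In a typical setting, a GP enters a GLMM as a latent variable, and the model has the form
\[
\begin{aligned}
\theta &\sim p(\theta),\\
f \mid \theta &\sim \mathcal{GP}\left(0, k_\theta(\cdot, \cdot)\right),\\
y_i \mid f &\sim p\left(y_i \mid f(x_i)\right), \quad i = 1, \dots, n,
\end{aligned}
\]
so that the joint density is \(p(\theta)\, p(f \mid \theta) \prod_i p(y_i \mid f(x_i))\).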
Here \(\{(x_i, y_i)\}_{i=1}^n\) are pairs of observations \(y_i\) and the locations \(x_i\) at which they were recorded. The role of \(f(x)\) is now to serve as a latent field capturing dependencies between locations \(x\). The product \(\prod_i p(y_i \vert f(x_i))\) is the likelihood: it links the observed data to the model and enables parameter inference.
The task we usually want to solve is twofold: we want to infer the parameters involved in the model, and to make predictions at unobserved locations \(x_*\).
Gaussian process regression
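The simplest case of the setting described above is Gaussian process regression, where the outcome variable is modelled as a GP with added noise \(\epsilon\). It assumes that the data consist of pairs \(\{(x_i, y_i)\}_{i=1}^n\) and that the likelihood is Gaussian with variance \(\sigma^2_\epsilon\):
\[
y_i = f(x_i) + \epsilon_i, \qquad \epsilon_i \overset{\text{iid}}{\sim} \mathcal{N}(0, \sigma^2_\epsilon), \qquad \text{i.e.} \qquad p(y_i \mid f(x_i)) = \mathcal{N}\left(y_i \mid f(x_i), \sigma^2_\epsilon\right).
\]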
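We want to obtain estimates of \(f_m = f(x_m)\) at new locations \(x_m = (x_m^1, \dots, x_m^m)\), given the observations at \(x_n\). Recall the conditioning formula for a zero-mean GP prior: if
\[
\begin{pmatrix} f_n \\ f_m \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} 0 \\ 0 \end{pmatrix}, \begin{pmatrix} K_{nn} & K_{nm} \\ K_{mn} & K_{mm} \end{pmatrix} \right),
\]
then
\[
f_m \mid f_n \sim \mathcal{N}\left(K_{mn} K_{nn}^{-1} f_n,\; K_{mm} - K_{mn} K_{nn}^{-1} K_{nm}\right).
\]
There is an issue with this formula: because of the noise we cannot observe \(f_n\); we observe \(y_n = f_n + \epsilon_n\) instead. Since \(\epsilon_n\) is independent of \(f\), we have \(\operatorname{Cov}(y_n) = K_{nn} + \sigma^2_\epsilon I\) and \(\operatorname{Cov}(f_m, y_n) = K_{mn}\), so applying the same conditioning formula to the pair \((y_n, f_m)\) gives
\[
f_m \mid y_n \sim \mathcal{N}\left(K_{mn} \left(K_{nn} + \sigma^2_\epsilon I\right)^{-1} y_n,\; K_{mm} - K_{mn} \left(K_{nn} + \sigma^2_\epsilon I\right)^{-1} K_{nm}\right).
\]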
This is the predictive distribution of the Gaussian process.
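Putting the two distributions side by side:
\[
\begin{aligned}
&\text{posterior in the noise-free case:} && f_m \mid f_n \sim \mathcal{N}\left(K_{mn} K_{nn}^{-1} f_n,\; K_{mm} - K_{mn} K_{nn}^{-1} K_{nm}\right),\\
&\text{posterior in the noisy case:} && f_m \mid y_n \sim \mathcal{N}\left(K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} y_n,\; K_{mm} - K_{mn} (K_{nn} + \sigma^2_\epsilon I)^{-1} K_{nm}\right).
\end{aligned}
\]
The only difference is the term \(\sigma^2_\epsilon I\) added to \(K_{nn}\).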
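The parameters \(\theta\) of the model (the kernel hyperparameters and the noise variance \(\sigma^2_\epsilon\)) can be learnt from the marginal log-likelihood (for example, by maximising it with respect to \(\theta\)), which for the noisy model takes the form
\[
\log p(y_n \mid x_n, \theta) = -\frac{1}{2} y_n^\top \left(K_{nn} + \sigma^2_\epsilon I\right)^{-1} y_n - \frac{1}{2} \log \left|K_{nn} + \sigma^2_\epsilon I\right| - \frac{n}{2} \log 2\pi.
\]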
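The marginal log-likelihood is short to implement. Below is a minimal sketch, assuming the RBF kernel used in this chapter and \(\theta = (\ell, \sigma, \sigma_\epsilon)\) (lengthscale, amplitude, noise standard deviation); `log_marginal_likelihood` is a hypothetical name. The log-determinant is computed from a Cholesky factor \(L\) as \(2\sum_i \log L_{ii}\), and because the function is written in JAX, `jax.grad` provides the gradients a gradient-based optimiser would use to learn \(\theta\).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve
from jax.scipy.stats import multivariate_normal

def rbf_kernel(x1, x2, lengthscale=1.0, sigma=1.0):
    # squared-exponential kernel; sigma is the amplitude (kernel variance sigma**2)
    sq_dist = jnp.sum(x1**2, axis=1).reshape(-1, 1) + jnp.sum(x2**2, axis=1) - 2 * x1 @ x2.T
    return sigma**2 * jnp.exp(-0.5 / lengthscale**2 * sq_dist)

def log_marginal_likelihood(x, y, lengthscale, sigma, noise_std):
    """log p(y_n | x_n, theta) for zero-mean GP regression (hypothetical helper)."""
    n = x.shape[0]
    K_y = rbf_kernel(x, x, lengthscale, sigma) + noise_std**2 * jnp.eye(n)  # K_nn + sigma_eps^2 I
    L = jnp.linalg.cholesky(K_y)
    alpha = cho_solve((L, True), y)  # (K_nn + sigma_eps^2 I)^{-1} y_n
    # -1/2 y^T alpha - 1/2 log|K_y| - n/2 log(2 pi), with log|K_y| = 2 sum(log diag L)
    return -0.5 * y @ alpha - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * n * jnp.log(2 * jnp.pi)

# usage on a small synthetic data set (assumption: a noiseless sine curve)
x = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(6.0 * x[:, 0])
print(log_marginal_likelihood(x, y, 0.3, 1.0, 0.1))
# gradients with respect to (lengthscale, sigma, noise_std)
grads = jax.grad(log_marginal_likelihood, argnums=(2, 3, 4))(x, y, 0.3, 1.0, 0.1)
print(grads)
```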
Computational complexity
Note that the predictive distribution involves inverting an \(n \times n\) matrix, \(K_{nn} + \sigma^2_\epsilon I\). This computation has cubic complexity \(O(n^3)\) in the number of observations \(n\) and is the main computational bottleneck of GP inference.
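In practice the \(O(n^3)\) step is usually carried out as a Cholesky factorization \(K = LL^\top\) rather than an explicit inverse: the factorization still scales cubically, but it is typically cheaper by a constant factor and numerically more stable, and once \(L\) is available each additional solve costs only \(O(n^2)\). The sketch below uses a random symmetric positive-definite matrix as a stand-in (an assumption, not a GP kernel) for \(K_{nn} + \sigma^2_\epsilon I\) and checks that both routes give the same vector \((K_{nn} + \sigma^2_\epsilon I)^{-1} y_n\).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

n = 200
key_a, key_y = jax.random.split(jax.random.PRNGKey(0))
# random symmetric positive-definite stand-in for K_nn + sigma_eps^2 I
A = jax.random.normal(key_a, (n, n))
K = A @ A.T / n + 0.1 * jnp.eye(n)
y = jax.random.normal(key_y, (n,))

# route 1: explicit inverse, O(n^3)
alpha_inv = jnp.linalg.inv(K) @ y

# route 2: Cholesky factorization K = L L^T, O(n^3) once ...
c_and_lower = cho_factor(K, lower=True)
# ... then two triangular solves, O(n^2) per right-hand side
alpha_chol = cho_solve(c_and_lower, y)
```

The factor `c_and_lower` can be reused for the mean and for every column of \(K_{nm}\) in the variance, which is where the savings accumulate.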
Let us implement the predictive distribution.
```python
def rbf_kernel(x1, x2, lengthscale=1.0, sigma=1.0):
    """
    compute the Radial Basis Function (RBF) kernel matrix between two sets of points
    args:
    - x1 (array): array of shape (n1, d) representing the first set of points
    - x2 (array): array of shape (n2, d) representing the second set of points
    - lengthscale (float): length-scale parameter
    - sigma (float): amplitude (standard deviation); the kernel variance is sigma**2
    returns:
    - K (array): kernel matrix of shape (n1, n2)
    """
    # pairwise squared Euclidean distances ||x1_i - x2_j||^2
    sq_dist = jnp.sum(x1**2, axis=1).reshape(-1, 1) + jnp.sum(x2**2, axis=1) - 2 * jnp.dot(x1, x2.T)
    K = sigma**2 * jnp.exp(-0.5 / lengthscale**2 * sq_dist)
    return K
```
```python
def plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true=None):
    """
    plots the Gaussian process predictive distribution
    args:
    - x_obs: training inputs, shape (n_train_samples, n_features)
    - y_obs: training targets, shape (n_train_samples,)
    - x_pred: test input points, shape (n_test_samples, n_features)
    - mean: mean of the predictive distribution, shape (n_test_samples,)
    - variance: variance of the predictive distribution, shape (n_test_samples,)
    - f_true: (optional) true function values at x_pred, shape (n_test_samples,)
    """
    plt.figure(figsize=(8, 6))
    if f_true is not None:
        plt.plot(x_pred, f_true, label='True Function', color='purple')
    plt.scatter(x_obs, y_obs, c='orangered', label='Training Data')
    plt.plot(x_pred, mean, label='Mean Prediction', color='teal')
    # band of +/- one standard deviation; clip tiny negative variances caused by round-off
    std = jnp.sqrt(jnp.maximum(variance, 0.0))
    plt.fill_between(x_pred.squeeze(), mean - std, mean + std, color='teal', alpha=0.3, label='Uncertainty (±1 sd)')
    plt.title('Gaussian Process Predictive Distribution')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    plt.show()
```
```python
def predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=1.0, sigma=1.0, jitter=1e-8):
    """
    predicts the mean and variance of the Gaussian process at test points
    (noise-free posterior: only a small jitter is added to the diagonal of K)
    args:
    - x_obs: training inputs, shape (n_train_samples, n_features)
    - y_obs: training targets, shape (n_train_samples,)
    - x_pred: test input points, shape (n_test_samples, n_features)
    - length_scale: length-scale parameter
    - sigma: amplitude (standard deviation) of the kernel
    - jitter: small value added to the diagonal for numerical stability
    returns:
    - mean: mean of the predictive distribution, shape (n_test_samples,)
    - variance: variance of the predictive distribution, shape (n_test_samples,)
    """
    K = rbf_kernel(x_obs, x_obs, length_scale, sigma) + jitter * jnp.eye(len(x_obs))  # K_nn
    K_inv = jnp.linalg.inv(K)  # O(n^3)
    K_star = rbf_kernel(x_obs, x_pred, length_scale, sigma)  # K_nm
    K_star_star = rbf_kernel(x_pred, x_pred, length_scale, sigma)  # K_mm
    mean = jnp.dot(K_star.T, jnp.dot(K_inv, y_obs))
    # diagonal of K_mm - K_mn K_nn^{-1} K_nm
    variance = jnp.diag(K_star_star) - jnp.sum(jnp.dot(K_star.T, K_inv) * K_star.T, axis=1)
    return mean, variance
```
```python
# example
x_obs = jnp.array([[1.0], [2.0], [3.0]])  # training inputs
y_obs = jnp.array([3.0, 4.0, 2.0])  # training targets
x_pred = jnp.linspace(0, 5, 100).reshape(-1, 1)  # test inputs
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred)
plot_gp(x_obs, y_obs, x_pred, mean, variance)
```
This seems to work well. Let’s take a look at some noisy data.
```python
# the true pattern
def f(x):
    return (6 * x/5 - 2)**2 * jnp.sin(12 * x/5 - 4)

# generate true function values
x = jnp.linspace(0, 5, 100)
f_true = f(x)

# sample noise with a normal distribution (standard deviation 2)
noise = 2*jax.random.normal(jax.random.PRNGKey(0), shape=f_true.shape)

# add noise to the true function values to obtain the observed values (y)
y = f_true + noise

# indices of observed data points
idx_obs = jnp.array([[10], [50], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)
```
```python
plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x_obs, y_obs, color='orangered', label='Training Data')
plt.grid()
plt.legend()
plt.show()
```
```python
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
```
Let’s add more points and repeat the procedure.
```python
# indices of observed data points
idx_obs = jnp.array([[10], [12], [20], [68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
```
```python
# indices of observed data points
idx_obs = jnp.array([[5], [10], [12], [20], [25], [41], [55], [68], [73], [85], [90]])
x_obs = x[idx_obs]
y_obs = y[idx_obs].reshape(-1)
mean, variance = predict_gaussian_process(x_obs, y_obs, x_pred, length_scale=0.5, sigma=4)
plot_gp(x_obs, y_obs, x_pred, mean, variance, f_true)
```
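As the number of observed data points increases, we observe two phenomena. Firstly, the uncertainty bounds shrink (as expected). Secondly, rather than trying to generalise, our model starts passing through every point. This happens because `predict_gaussian_process` implements the noise-free posterior: the only term added to the diagonal of \(K_{nn}\) is the tiny `jitter`, so the posterior mean is forced to interpolate the noisy observations. Accounting for the noise requires the noisy-case formula, i.e. adding \(\sigma^2_\epsilon I\) to \(K_{nn}\) (here \(\sigma_\epsilon = 2\)).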
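A minimal sketch of the noisy-case predictive distribution, assuming the noise variance is known (\(\sigma^2_\epsilon = 4\), the value used to simulate the data above) and keeping the kernel hyperparameters of the plots (`lengthscale=0.5`, `sigma=4`). `predict_noisy_gp` is a hypothetical helper: compared with `predict_gaussian_process`, it adds \(\sigma^2_\epsilon I\) to \(K_{nn}\) and uses a Cholesky solve instead of an explicit inverse. With this term the posterior mean no longer has to pass through every noisy point.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular

def rbf_kernel(x1, x2, lengthscale=1.0, sigma=1.0):
    # squared-exponential kernel; sigma is the amplitude (kernel variance sigma**2)
    sq_dist = jnp.sum(x1**2, axis=1).reshape(-1, 1) + jnp.sum(x2**2, axis=1) - 2 * x1 @ x2.T
    return sigma**2 * jnp.exp(-0.5 / lengthscale**2 * sq_dist)

def predict_noisy_gp(x_obs, y_obs, x_pred, lengthscale, sigma, noise_var):
    """Mean and variance of f_m | y_n (hypothetical helper, noisy-case formula)."""
    n = x_obs.shape[0]
    K_y = rbf_kernel(x_obs, x_obs, lengthscale, sigma) + noise_var * jnp.eye(n)  # K_nn + sigma_eps^2 I
    K_nm = rbf_kernel(x_obs, x_pred, lengthscale, sigma)
    L = jnp.linalg.cholesky(K_y)               # K_y = L L^T
    alpha = cho_solve((L, True), y_obs)        # (K_nn + sigma_eps^2 I)^{-1} y_n
    mean = K_nm.T @ alpha
    v = solve_triangular(L, K_nm, lower=True)  # L^{-1} K_nm
    var = sigma**2 - jnp.sum(v**2, axis=0)     # diag(K_mm) = sigma**2 for the RBF kernel
    return mean, var

# same toy data as above: noise standard deviation 2, i.e. sigma_eps^2 = 4
def f(x):
    return (6 * x / 5 - 2)**2 * jnp.sin(12 * x / 5 - 4)

x = jnp.linspace(0, 5, 100)
y = f(x) + 2 * jax.random.normal(jax.random.PRNGKey(0), shape=x.shape)
idx_obs = jnp.array([5, 10, 12, 20, 25, 41, 55, 68, 73, 85, 90])
x_obs, y_obs = x[idx_obs].reshape(-1, 1), y[idx_obs]
mean, var = predict_noisy_gp(x_obs, y_obs, x.reshape(-1, 1), lengthscale=0.5, sigma=4.0, noise_var=4.0)
```

The result can be passed to `plot_gp` from above; setting `noise_var` close to zero recovers the interpolating behaviour.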
Non-Gaussian likelihoods
We have seen that in the Normal-Normal setting the posterior distribution is available in closed form.
For non-conjugate likelihood models this is not the case. When the likelihood is non-Gaussian, the resulting posterior distribution is no longer Gaussian, and obtaining a closed-form expression for the predictive distribution may be challenging or impossible.
In such cases, computational methods such as Markov chain Monte Carlo (MCMC) or variational inference are often used to approximate the posterior and, through it, the predictive distribution. We will use Numpyro and its MCMC engine for this purpose.
Draw GP priors using Numpyro functionality
In the previous chapter we saw how to draw GP priors numerically. Although Numpyro is not strictly needed for inference with a Gaussian likelihood, we will need it for inference with other likelihoods.
Let us start building that code by drawing GP priors with Numpyro.
```python
def plot_gp_samples(x, samples, ttl="", num_samples=10):
    plt.figure(figsize=(6, 4))
    for i in range(num_samples):
        plt.plot(x, samples[i], label=f'Sample {i}')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title(ttl)
    # place the legend outside the axes, then adjust the layout
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    plt.tight_layout()
    plt.show()
```
```python
def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5, noise=0.5):
    """
    Gaussian Process prior with a Gaussian likelihood as a Numpyro model
    args:
    - x (jax.numpy.ndarray): input data points of shape (n, d), where n is the number of points and d is the number of dimensions
    - y (jax.numpy.ndarray or None): observed values of shape (n,); if None, y is sampled
    - kernel_func (function): kernel function to use
    - lengthscale (float): length-scale parameter
    - jitter (float): jitter for numerical stability
    - noise (float): standard deviation of the Gaussian observation noise
    sample sites:
    - "f": latent function values at x, drawn from the GP prior
    - "y": noisy observations of f
    """
    n = x.shape[0]
    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))
    numpyro.sample("y", dist.Normal(f, noise), obs=y)
```
```python
n_points = 50
num_samples = 1000

# input locations
x = jnp.linspace(0, 1, n_points).reshape(-1, 1)

# draw from the prior (no observations are passed, so both f and y are sampled)
prior_predictive = Predictive(model, num_samples=num_samples)
samples = prior_predictive(jax.random.PRNGKey(0), x=x)
f_samples = samples['f']

# calculate mean and standard deviation
mean = jnp.mean(f_samples, axis=0)
std = jnp.std(f_samples, axis=0)

plt.figure(figsize=(8, 6))
for i in range(10):
    plt.plot(x, f_samples[i], color='teal', lw=0.3)
plt.plot(x, mean, label='Mean', color='teal')  # point-wise mean of all samples
plt.fill_between(x.squeeze(), mean - 1.96 * std, mean + 1.96 * std, color='teal', alpha=0.3, label='Uncertainty Bounds')  # approx. 95% bounds
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.title('GP samples obtained using Numpyro')
plt.grid()
plt.legend()
plt.show()
```
Inference with Numpyro: known noise
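For inference, we will use the data simulated in the previous step: one of the prior draws serves as the true function \(f\), and the corresponding noisy draw as the observations \(y\).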
```python
# use one of the prior draws as the ground truth
true_idx = 3
y_samples = samples['y']
f_true = f_samples[true_idx]
y_obs = y_samples[true_idx]

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.grid()
plt.legend()
plt.show()
```
Let us see whether we can recover \(f\) from observed \(y\) values. At this stage we assume that the amount of noise is known.
```python
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)

# print summary statistics of posterior
mcmc.print_summary()

# get the posterior samples
posterior_samples = mcmc.get_samples()
f_posterior = posterior_samples['f']
```
mean std median 5.0% 95.0% n_eff r_hat
f[0] 0.98 0.26 0.97 0.56 1.42 2120.13 1.00
f[1] 1.10 0.22 1.09 0.73 1.47 1894.31 1.00
f[2] 1.23 0.20 1.23 0.90 1.54 1740.22 1.00
f[3] 1.38 0.18 1.38 1.09 1.67 1692.13 1.00
f[4] 1.54 0.17 1.54 1.26 1.81 1698.71 1.00
f[5] 1.70 0.16 1.71 1.43 1.97 1697.40 1.00
f[6] 1.88 0.16 1.88 1.61 2.13 1766.68 1.00
f[7] 2.05 0.16 2.05 1.78 2.31 1743.53 1.00
f[8] 2.21 0.16 2.21 1.95 2.47 1721.54 1.00
f[9] 2.37 0.16 2.37 2.11 2.63 1690.00 1.00
f[10] 2.51 0.16 2.50 2.25 2.76 1657.94 1.00
f[11] 2.62 0.16 2.62 2.36 2.87 1632.71 1.00
f[12] 2.70 0.16 2.70 2.44 2.95 1617.95 1.00
f[13] 2.75 0.16 2.75 2.48 3.00 1609.96 1.00
f[14] 2.76 0.16 2.76 2.51 3.04 1606.50 1.00
f[15] 2.73 0.16 2.73 2.47 3.00 1608.51 1.00
f[16] 2.65 0.16 2.65 2.40 2.93 1609.70 1.00
f[17] 2.54 0.16 2.54 2.29 2.81 1608.39 1.00
f[18] 2.38 0.16 2.38 2.14 2.66 1606.05 1.00
f[19] 2.19 0.16 2.19 1.93 2.45 1600.37 1.00
f[20] 1.98 0.16 1.97 1.71 2.23 1590.41 1.00
f[21] 1.73 0.16 1.73 1.47 1.99 1577.27 1.00
f[22] 1.47 0.16 1.48 1.22 1.73 1565.24 1.00
f[23] 1.21 0.16 1.21 0.95 1.46 1541.40 1.00
f[24] 0.94 0.16 0.94 0.68 1.20 1479.06 1.00
f[25] 0.68 0.16 0.68 0.42 0.94 1424.89 1.00
f[26] 0.43 0.16 0.43 0.14 0.67 1393.53 1.00
f[27] 0.20 0.16 0.20 -0.08 0.44 1391.66 1.00
f[28] -0.02 0.16 -0.02 -0.28 0.23 1411.48 1.00
f[29] -0.22 0.16 -0.22 -0.48 0.04 1447.38 1.00
f[30] -0.39 0.16 -0.39 -0.65 -0.14 1472.72 1.00
f[31] -0.55 0.16 -0.55 -0.81 -0.30 1487.86 1.00
f[32] -0.68 0.16 -0.68 -0.94 -0.43 1493.16 1.00
f[33] -0.81 0.16 -0.81 -1.07 -0.56 1497.70 1.00
f[34] -0.91 0.16 -0.91 -1.16 -0.64 1501.22 1.00
f[35] -1.01 0.16 -1.01 -1.26 -0.75 1500.57 1.00
f[36] -1.09 0.16 -1.09 -1.36 -0.84 1496.60 1.00
f[37] -1.16 0.16 -1.16 -1.44 -0.91 1498.29 1.00
f[38] -1.22 0.16 -1.22 -1.49 -0.96 1507.57 1.00
f[39] -1.27 0.16 -1.27 -1.53 -1.00 1522.21 1.00
f[40] -1.31 0.16 -1.31 -1.57 -1.03 1543.47 1.00
f[41] -1.33 0.16 -1.33 -1.61 -1.07 1565.48 1.00
f[42] -1.33 0.16 -1.33 -1.60 -1.07 1593.17 1.00
f[43] -1.31 0.16 -1.31 -1.57 -1.04 1618.24 1.00
f[44] -1.27 0.16 -1.27 -1.53 -1.00 1653.18 1.00
f[45] -1.21 0.17 -1.21 -1.48 -0.94 1696.87 1.00
f[46] -1.13 0.18 -1.13 -1.43 -0.85 1769.31 1.00
f[47] -1.03 0.20 -1.02 -1.35 -0.71 1903.47 1.00
f[48] -0.92 0.22 -0.91 -1.29 -0.56 2086.05 1.00
f[49] -0.79 0.26 -0.79 -1.24 -0.39 2301.05 1.00
Number of divergences: 0
```python
# calculate the posterior mean and the 95% highest posterior density interval (HPDI)
f_mean = jnp.mean(f_posterior, axis=0)
f_hpdi = hpdi(f_posterior, 0.95)

plt.figure(figsize=(8, 6))
plt.plot(x, f_true, label='True Function', color='purple')
plt.scatter(x, y_obs, color='orangered', label='Noisy Data')
plt.plot(x, f_mean, color='teal', label='Mean Prediction')
plt.fill_between(x.squeeze(), f_hpdi[0], f_hpdi[1], color='teal', alpha=0.3, label='95% HPDI')  # uncertainty bounds
plt.grid()
plt.legend()
plt.show()
```
This looks good. In reality, however, we rarely know the true level of noise.
Inference with Numpyro: estimating noise
```python
def model(x, y=None, kernel_func=rbf_kernel, lengthscale=0.2, jitter=1e-5):
    """
    Gaussian Process prior with a Gaussian likelihood of unknown noise level
    args:
    - x (jax.numpy.ndarray): input data points of shape (n, d), where n is the number of points and d is the number of dimensions
    - y (jax.numpy.ndarray or None): observed values of shape (n,); if None, y is sampled
    - kernel_func (function): kernel function
    - lengthscale (float): length-scale parameter
    - jitter (float): jitter for numerical stability
    sample sites:
    - "f": latent function values at x
    - "sigma": noise standard deviation, with a HalfCauchy(1) prior
    - "y": noisy observations of f
    """
    n = x.shape[0]
    K = kernel_func(x, x, lengthscale) + jitter*jnp.eye(n)
    f = numpyro.sample("f", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=K))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)
```
```python
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=2000, num_chains=2, chain_method='parallel', progress_bar=False)
mcmc.run(jax.random.PRNGKey(42), x, y_obs)

# print summary statistics of posterior
mcmc.print_summary()

# get the posterior samples
posterior_samples = mcmc.get_samples()
```
```python
sigma_posterior = posterior_samples['sigma']
plt.figure(figsize=(6, 4))
sns.kdeplot(sigma_posterior, fill=True)
plt.axvline(x=0.5, color='r', linestyle='--', label='True value')
plt.xlabel(r'$\sigma$')
plt.ylabel('Density')
plt.title(r'Posterior of parameter $\sigma$')
plt.legend()
# show the plot
plt.show()
```
We have inferred the noise parameter \(\sigma\) successfully. Estimating the lengthscale is harder, especially for less smooth kernels. Another issue is the non-identifiability of the lengthscale-variance pair: the data often constrain only a combination of the two, so quite different pairs explain the data almost equally well. Hence, if both really need to be inferred, strong priors would be beneficial.