Bayesian workflow#

Applying Bayesian inference to real-world problems requires the modeller not only to be proficient in statistics and to have domain expertise and programming skills, but also to understand deeply the decisions made while analysing data. Apart from inference, the workflow includes iterative model building, model checking, validation and troubleshooting of computational problems, model understanding, and model comparison.

At first glance, Bayes' rule looks very simple:

\[\underbrace{p(\theta|y)}_\text{posterior} \propto \underbrace{p(y | \theta)}_{\text{likelihood}} \underbrace{p(\theta)}_{\text{prior}}\]

What could possibly go wrong in practice?

A lot can go wrong! And in case things go wrong, decisions need to be made sequentially about model building and improvement. That is why we need the Bayesian workflow.
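As a rough illustration of how the rule is used numerically, the sketch below evaluates likelihood times prior on a grid of values for a single success probability and normalises the result. The Beta(2, 2) prior and the data (61 successes in 100 trials) are borrowed from the coin-tossing example further below; the function name `grid_posterior` and the grid size are illustrative choices, not part of any library.

```python
import math


def grid_posterior(n, h, alpha, beta, num_points=1000):
    """Approximate the posterior of a success probability p on a grid."""
    # Midpoints of equal-width cells, so p is never exactly 0 or 1
    grid = [(i + 0.5) / num_points for i in range(num_points)]

    log_post = []
    for p in grid:
        # Binomial log-likelihood (constant binomial coefficient dropped)
        log_lik = h * math.log(p) + (n - h) * math.log(1 - p)
        # Beta(alpha, beta) log-prior (normalising constant dropped)
        log_prior = (alpha - 1) * math.log(p) + (beta - 1) * math.log(1 - p)
        log_post.append(log_lik + log_prior)

    # Subtract the maximum before exponentiating for numerical stability,
    # then normalise so the grid probabilities sum to one
    max_lp = max(log_post)
    weights = [math.exp(lp - max_lp) for lp in log_post]
    total = sum(weights)
    probs = [w / total for w in weights]
    return grid, probs


grid, probs = grid_posterior(n=100, h=61, alpha=2, beta=2)
post_mean = sum(p * w for p, w in zip(grid, probs))
print(f"posterior mean of p on the grid: {post_mean:.3f}")
```

A grid works for one parameter, but its cost grows exponentially with the number of parameters, which is one reason sampling methods such as MCMC are used in practice.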

Principles of Bayesian workflow#

Workflows exist in a variety of disciplines where they define what is a ‘good practice’.

Box’s loop#

In the 1960s, the statistician George Box formulated the notion of a loop to describe the nature of the scientific method. Blei et al. (2014) call it Box’s loop:

Modern Bayesian workflow#

A systematic review of the steps within the modern Bayesian workflow is given in [Gelman et al., 2020].

Prior predictive checks#

Prior predictive checking consists of simulating data from the priors. The simulations are then usually visualized (especially when transformations of parameters are involved). This shows the range of data compatible with the model and helps assess the adequacy of the chosen priors, since it is often easier to elicit expert knowledge on measurable quantities of interest than on abstract parameter values.
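The idea can be sketched without any probabilistic programming library. The example below assumes a Beta prior on a success probability and a Binomial outcome; the function name `prior_predictive_draws` and the numbers are hypothetical and chosen only for illustration.

```python
import random


def prior_predictive_draws(n, alpha, beta, num_draws, seed=0):
    """Simulate outcomes by first drawing the parameter from its prior."""
    rng = random.Random(seed)
    draws = []
    for _ in range(num_draws):
        # Step 1: draw the success probability from the Beta prior
        p = rng.betavariate(alpha, beta)
        # Step 2: simulate n Bernoulli trials with that probability
        h = sum(rng.random() < p for _ in range(n))
        draws.append(h)
    return draws


draws = prior_predictive_draws(n=100, alpha=2, beta=2, num_draws=2000)
print("range of simulated successes:", min(draws), "to", max(draws))
print("mean of simulated successes:", sum(draws) / len(draws))
```

If the simulated outcomes cover values a domain expert would consider impossible, or miss values that are plausible, the prior is worth revisiting before any data are used.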

Iterative model building#

A possible realisation of the Bayesian workflow loop:

  • Understand the domain and problem,

  • Formulate the model mathematically,

  • Implement the model, test, debug,

  • debug, debug, debug,

  • Perform prior predictive check,

  • Fit the model,

  • Assess convergence diagnostics,

  • Perform posterior predictive check,

  • Improve the model iteratively: from baseline to complex and computationally efficient models.

Examples#

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.diagnostics import hpdi

import jax
import jax.numpy as jnp
from jax import random

import arviz as az
from scipy.stats import gaussian_kde

import matplotlib.pyplot as plt
import pandas as pd

Coin tossing#

Data#

n = 100    # number of trials
h = 61     # number of successes
alpha = 2  # hyperparameters
beta = 2

niter = 1000

Model#

def model(n, alpha=2, beta=2, h=None):

    # prior on the probability of success p
    p = numpyro.sample('p', dist.Beta(alpha, beta))

    # likelihood - notice the `obs=h` part
    # p is the probability of success,
    # n is the total number of experiments
    # h is the number of successes
    numpyro.sample('obs', dist.Binomial(n, p), obs=h)

Prior Predictive check#

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# use the Predictive class to generate predictions.
# Notice that we are not passing observation `h` as data.
# Since we have set `h=None`, this allows the model to make predictions of `h`
# when data for it is not provided.
prior_predictive = Predictive(model, num_samples=1000)
prior_predictions = prior_predictive(rng_key_, n)
# we have generated samples for two variables
prior_predictions.keys()
dict_keys(['obs', 'p'])
# extract samples for variable 'p'
pred_obs = prior_predictions['p']

# compute its summary statistics for the samples of `p`
mean_prior_pred = jnp.mean(pred_obs, axis=0)
hpdi_prior_pred = hpdi(pred_obs, 0.89)
fig = plt.figure(dpi=100, figsize=(5, 3))
plt.hist(pred_obs, bins=15, density=True, alpha=0.5, color='teal')
x = jnp.linspace(0, 1, 3000)
kde = gaussian_kde(pred_obs)
plt.plot(x, kde(x), color='teal', lw=3, alpha=0.5)
plt.xlabel('p')
plt.title('Prior distribution for $p$')
plt.xlim(0, 1)
plt.grid(alpha=0.3)
plt.show()
_images/bd20d2da6aa2b5817073e913a197079eea53f8ae103f5d82d1080111072794cd.png

Inference#

Using the same routine as we did for prior predictive, we can perform inference by using the observed data.

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# specify inference algorithm
kernel = NUTS(model)

# define the number of samples and the number of chains
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, progress_bar=False)
#run MCMC
mcmc.run(rng_key_, n=n, h=h)
# extract samples of parameter p
p_samples = mcmc.get_samples()
p_posterior_samples = p_samples['p']
fig = plt.figure(dpi=100, figsize=(5, 3))
plt.hist(pred_obs, bins=15, density=True, alpha=0.5, color='teal', label = "Prior distribution")
plt.hist(p_posterior_samples, bins=15, density=True, alpha=0.5, color='orangered', label = "Posterior distribution")
x = jnp.linspace(0, 1, 3000)
kde = gaussian_kde(pred_obs)
plt.plot(x, kde(x), color='teal', lw=3, alpha=0.5)
kde = gaussian_kde(p_posterior_samples)
plt.plot(x, kde(x), color='orangered', lw=3, alpha=0.5)
plt.xlabel('p')
plt.xlim(0, 1)
plt.grid(alpha=0.3)
plt.legend()
plt.show()
_images/22c973d7d9a247015aeac450b537213e7d7ba9522f4a819d3dd6827bfdff46a0.png

Check convergence#

We have now obtained samples from MCMC. How can we assess whether we can trust the results? Convergence diagnostics serve this purpose. Beyond \(\hat{R}\), we can also visually inspect traceplots. Traceplots are simply sample values plotted against the iteration number. We want those traceplots to be stationary, i.e. they should look like a “hairy caterpillar”.
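To make \(\hat{R}\) less of a black box, here is a minimal sketch of the classic (non-split) Gelman-Rubin statistic, which compares within-chain to between-chain variation. Library implementations typically use refined variants such as split \(\hat{R}\), so the values need not match a library summary exactly. The function name `gelman_rubin_rhat` and the synthetic chains are assumptions made for illustration.

```python
import math
import random
import statistics


def gelman_rubin_rhat(chains):
    """Classic potential scale reduction factor for equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    # W: average of the within-chain sample variances
    w = statistics.fmean([statistics.variance(c) for c in chains])
    # B: between-chain variance, scaled by the chain length
    b = n * statistics.variance(chain_means)
    # Pooled estimate of the posterior variance
    var_plus = (n - 1) / n * w + b / n
    return math.sqrt(var_plus / w)


rng = random.Random(0)
# Four synthetic chains exploring the same distribution
mixed = [[rng.gauss(0.6, 0.05) for _ in range(500)] for _ in range(4)]
# Four synthetic chains stuck around different values
stuck = [[rng.gauss(0.6 + 0.1 * k, 0.05) for _ in range(500)] for k in range(4)]

rhat_mixed = gelman_rubin_rhat(mixed)
rhat_stuck = gelman_rubin_rhat(stuck)
print(f"well-mixed chains: r_hat = {rhat_mixed:.3f}")
print(f"stuck chains:      r_hat = {rhat_stuck:.3f}")
```

When all chains explore the same distribution the between-chain term is negligible and the statistic is close to 1; chains centred on different values inflate it well above 1.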

# inspect the summary
# print the summary and look at r_hat
# r_hat is a diagnostic comparing within-chain variation to between-chain variation.
# It is an important convergence diagnostic, and we want its value to be close to 1
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         p      0.61      0.05      0.61      0.53      0.68   3059.08      1.00

Number of divergences: 0
# plot posterior distribution and traceplots
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
_images/8a2dad89eeb01c7f336aedb02b00adab0d3214e6d6c472ac441dcbf1e152171f.png
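For this particular model there is also an independent check on the sampler: a Beta prior is conjugate to the Binomial likelihood, so the posterior is available in closed form as Beta(alpha + h, beta + n - h). The sketch below computes its mean and standard deviation for comparison with the MCMC summary; the helper names are hypothetical.

```python
import math


def beta_binomial_posterior(alpha, beta, n, h):
    """Posterior Beta parameters after h successes in n trials."""
    return alpha + h, beta + n - h


def beta_mean_std(a, b):
    """Mean and standard deviation of a Beta(a, b) distribution."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, math.sqrt(var)


a_post, b_post = beta_binomial_posterior(alpha=2, beta=2, n=100, h=61)
mean, std = beta_mean_std(a_post, b_post)
print(a_post, b_post)                  # 63 41
print(round(mean, 2), round(std, 2))   # 0.61 0.05
```

These values agree with the mean and standard deviation reported in the summary table above, which is what one would expect from a converged sampler.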

Posterior predictive distribution#

The posterior predictive distribution is the distribution of new, unseen data given the observed data: it averages the likelihood of new data over the posterior distribution of the model parameters.

We can use the samples obtained at the previous step to generate the posterior predictive distribution of the outcome.
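Conceptually, this amounts to simulating one new dataset per posterior draw, so that uncertainty about the parameter carries over into the predictions. A library-free sketch follows; the posterior draws are stand-ins generated from a Beta(63, 41) distribution (the closed-form posterior for this conjugate model) rather than actual MCMC output, and the function name is hypothetical.

```python
import random


def posterior_predictive_draws(p_samples, n, seed=0):
    """Simulate one new experiment of n trials for every posterior draw of p."""
    rng = random.Random(seed)
    return [sum(rng.random() < p for _ in range(n)) for p in p_samples]


# Stand-in for MCMC output: draws from the closed-form posterior Beta(63, 41)
rng = random.Random(1)
p_samples = [rng.betavariate(63, 41) for _ in range(2000)]

h_new = posterior_predictive_draws(p_samples, n=100)
print("mean predicted number of successes:", sum(h_new) / len(h_new))
```

The spread of `h_new` is wider than that of a Binomial with a fixed probability, because each simulated experiment uses a different plausible value of p.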

# using the same 'Predictive' class,
# but now specifying also `p_samples`
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model, p_samples)
posterior_predictions = predictive(rng_key_, n=n)
# extract prediction and calculate summary statistics
post_obs = posterior_predictions['obs']
mean_post_pred = jnp.mean(post_obs, axis=0)
hpdi_post_pred = hpdi(post_obs, 0.9)
# what is the mean number of successes?
mean_post_pred
Array(60.508003, dtype=float32)
# what is the uncertainty around this mean?
hpdi_post_pred
array([48, 70], dtype=int32)

Group Task

  • Change the hyperparameters of the model. How do they change the results?

  • Change the data to n=10, h=6. How does the result change now?

Bayesian Linear regression#

Now that we know how to use NumPyro, let us build a Bayesian linear regression model on a larger dataset. It is the same linear regression model you are familiar with, but here all of the parameters are estimated in the Bayesian way.

!wget -O Howell1.csv https://raw.githubusercontent.com/deep-learning-indaba/indaba-pracs-2023/main/data/Howell1.csv

df = pd.read_csv('Howell1.csv', sep=";")
df.head()
    height     weight   age  male
0  151.765  47.825606  63.0     1
1  139.700  36.485807  63.0     0
2  136.525  31.864838  65.0     0
3  156.845  53.041914  41.0     1
4  145.415  41.276872  51.0     0
# observed data
weight = df.weight.values
height = df.height.values

Let us define some test data for the variable weight. For these datapoints we will make predictions.

# data to make predictions for
weight_pred = jnp.array([45, 40, 65, 31, 53])
# plot the data
plt.figure(figsize=(6, 4))
plt.scatter(x='weight', y='height', data=df, color='teal')
plt.grid(alpha=0.3)
_images/b92e887b33ad4fcf7c70298c3d44a86321c7dc40859ac69f817c1b32b9159e57.png

The linear regression model will have the form

\[\begin{split}y \sim N(\mu, \sigma^2),\\ \mu = b_0 + b_1 x.\end{split}\]

Here \(y\) is the outcome we want to predict, \(x\) is the predictor, \(b_0\) is the bias (intercept), \(b_1\) is the slope (weight) and \(\sigma^2\) is the variance.

Group task

Discuss which priors would be reasonable for the parameters \(b_0\), \(b_1\), \(\sigma\).

# model
def model(weight=None, height=None):
    # priors
    b0 = numpyro.sample('b0', dist.Normal(120,50))
    b1 = numpyro.sample('b1', dist.Normal(0,1))
    sigma = numpyro.sample('sigma', dist.HalfNormal(10.))

    # deterministic transformation
    mu = b0 + b1 * weight

    # likelihood: notice `obs=height`
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=height)
# prior predictive
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
prior_predictive = Predictive(model, num_samples=100)
prior_predictions = prior_predictive(rng_key_, weight)
prior_predictions.keys()
dict_keys(['b0', 'b1', 'obs', 'sigma'])
pred_obs = prior_predictions['obs']
mean_prior_pred = jnp.mean(pred_obs, axis=0)
hpdi_prior_pred = hpdi(pred_obs, 0.89)
def plot_regression(x, y_mean, y_hpdi, height, ttl='Predictions with 89% CI'):
    # Sort values for plotting by x axis
    idx = jnp.argsort(x)
    weight = x[idx]
    mean = y_mean[idx]
    hpdi = y_hpdi[:, idx]
    ht = height[idx]

    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(weight, mean, color='teal')
    ax.plot(weight, ht, 'o', color='orangered')
    ax.fill_between(weight, hpdi[0], hpdi[1], alpha=0.3, interpolate=True, color='teal')
    ax.set(xlabel='weight', ylabel='height', title=ttl);
    return ax
ax = plot_regression(weight, mean_prior_pred, hpdi_prior_pred, height, ttl="Prior predictive")
_images/af3a28a578dcb372dbdd8f3c46d68f0d2499713877dc2161cb164ac990c0db2b.png

Group Task

Change the prior for \(b_0\) to dist.Normal(120,50) and visualise the prior predictive distribution. Does it look satisfactory? Do you think we can go ahead with inference using this set of priors?

# Inference
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, progress_bar=False)
mcmc.run(rng_key_, weight=weight, height=height)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        b0     75.49      1.07     75.48     73.72     77.26   2608.93      1.00
        b1      1.76      0.03      1.76      1.72      1.81   2692.35      1.00
     sigma      9.37      0.29      9.36      8.86      9.82   3778.49      1.00

Number of divergences: 0
# Check convergence
mcmc.print_summary()
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        b0     75.49      1.07     75.48     73.72     77.26   2608.93      1.00
        b1      1.76      0.03      1.76      1.72      1.81   2692.35      1.00
     sigma      9.37      0.29      9.36      8.86      9.82   3778.49      1.00

Number of divergences: 0
_images/11cbf087417a9fecf3a35744f7b5a0882121dc1d15ddde6e06f7bc65c9218d27.png
# Posterior predictive
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model, samples_1)
posterior_predictions = predictive(rng_key_, weight=weight)
post_obs = posterior_predictions['obs']

mean_post_pred = jnp.mean(post_obs, axis=0)
hpdi_post_pred = hpdi(post_obs, 0.9)

ax = plot_regression(weight, mean_post_pred, hpdi_post_pred, height, ttl="Posterior predictive")
ax.set(xlabel='weight', ylabel='height');
_images/2b3290bab32b266595f68e93a4ec55e7348709a52453c8a7290d8e5a6b28bd00.png
# predict for new data
predictive = Predictive(model, samples_1)
predictions = predictive(rng_key_, weight=weight_pred)['obs']

mean_pred = jnp.mean(predictions, axis=0)
hpdi_pred = hpdi(predictions, 0.89)

d = {'weight_pred': weight_pred, 'mean_pred': mean_pred, 'lower': hpdi_pred[0,], 'upper': hpdi_pred[1,]}
df_res = pd.DataFrame(data=d)
df_res.head()
   weight_pred   mean_pred       lower       upper
0           45  154.772690  139.393387  169.614548
1           40  146.176010  131.200302  161.135223
2           65  190.045502  175.474350  204.991074
3           31  130.221954  115.344894  145.279709
4           53  168.721436  154.352890  184.339874

Task 18

Modify the model so that it fits better.

Hint: apply a transformation to input data, e.g. a polynomial.

For this model,

  • plot prior predictive distribution,

  • perform inference,

  • plot posterior predictive distribution.

Why use Bayesian regression over frequentist?#

  • Uncertainty estimation: Bayesian regression provides a natural way to quantify uncertainty in the estimates of the model parameters. This can be useful for making probabilistic predictions and for decision-making under uncertainty. Frequentist methods typically provide point estimates and, in the best case, confidence intervals (CIs). Here we get Bayesian credible intervals (BCIs).

  • Handling small sample sizes: Bayesian methods can be particularly useful when dealing with small sample sizes, as they allow for more flexible modeling assumptions and can incorporate prior information to help stabilize estimates. Frequentist methods may struggle with small sample sizes, leading to unreliable estimates and inflated uncertainty.

  • Regularization: Bayesian regression naturally incorporates regularization techniques, such as priors that penalize large parameter values, which can help prevent overfitting and improve the generalization performance of the model. While frequentist methods also offer regularization techniques, such as Ridge regression and Lasso regression, the Bayesian framework provides a principled way to specify and interpret regularization.

Let’s unpack the regularization argument a bit more.

The linear regression model is given by:

\[ y_i = \beta_0 + \beta_1 x_i + \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, \sigma^2) \]

where \(y_i\) is the response variable, \(x_i\) is the predictor variable, \(\beta_0\) and \(\beta_1\) are the regression coefficients, and \(\epsilon_i\) are independent and normally distributed errors.

We assume normal priors for \(\beta_0\) and \(\beta_1\):

\[\begin{split} \beta_0 \sim \mathcal{N}(\mu_0, \tau_0^2),\\ \beta_1 \sim \mathcal{N}(\mu_1, \tau_1^2) \end{split}\]

Here, \(\mu_0\), \(\mu_1\), \(\tau_0^2\), and \(\tau_1^2\) are hyperparameters that need to be specified. The choice of these hyperparameters can influence the strength of regularization.

The likelihood function, assuming normal errors, is:

\[ p(y |\beta_0, \beta_1, \sigma^2, x) = \prod_{i=1}^{n} \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(y_i - \beta_0 - \beta_1 x_i)^2}{2\sigma^2}\right) \]

And the priors:

\[ p(\beta_0) = \frac{1}{\sqrt{2\pi\tau_0^2}} \exp\left(-\frac{(\beta_0 - \mu_0)^2}{2\tau_0^2}\right) \]

\[ p(\beta_1) = \frac{1}{\sqrt{2\pi\tau_1^2}} \exp\left(-\frac{(\beta_1 - \mu_1)^2}{2\tau_1^2}\right) \]

Applying Bayes’ theorem, we get the unnormalized posterior:

\[ \begin{align*} p(\beta_0, \beta_1 | y, x) &\propto p(y | \beta_0, \beta_1, \sigma^2, x) \times p(\beta_0) \times p(\beta_1) \\ &\propto \prod_{i=1}^{n} \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(y_i - \beta_0 - \beta_1 x_i)^2}{2\sigma^2}\right) \times \frac{1}{\sqrt{2\pi\tau_0^2}} \exp\left(-\frac{(\beta_0 - \mu_0)^2}{2\tau_0^2}\right) \times \frac{1}{\sqrt{2\pi\tau_1^2}} \exp\left(-\frac{(\beta_1 - \mu_1)^2}{2\tau_1^2}\right) \end{align*} \]

Computing the log-posterior, we get

\[ \begin{align*} \log p(\beta_0, \beta_1 | y, x) &= \underbrace{-\sum_{i=1}^{n} \frac{(y_i - \beta_0 - \beta_1 x_i)^2}{2\sigma^2}}_{\text{scaled negative squared error}} - \underbrace{\left(\frac{(\beta_0 - \mu_0)^2}{2\tau_0^2} + \frac{(\beta_1 - \mu_1)^2}{2\tau_1^2}\right)}_{\text{regularization}} + \text{constants} \end{align*} \]

Here the first term is the negative sum of squared errors scaled by \(1/(2\sigma^2)\). It is proportional to the mean squared error (MSE), which usually serves as the objective in frequentist machine learning. The second term provides regularization.

Regularization is controlled through the choice of hyperparameters \(\tau_0^2\) and \(\tau_1^2\). Smaller values of \(\tau_0^2\) and \(\tau_1^2\) (tighter priors) result in stronger regularization, pulling the estimated coefficients towards the prior means \(\mu_0\) and \(\mu_1\) (towards zero when the prior means are zero) and simplifying the model.
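The effect of the prior variances can be checked numerically. Setting the gradient of the log-posterior to zero gives a 2x2 linear system for the maximum a posteriori (MAP) estimate, which the sketch below solves in closed form. It treats \(\sigma\) as known, and the function name `map_estimate` and the toy data are assumptions made for illustration.

```python
def map_estimate(x, y, sigma, mu0, tau0, mu1, tau1):
    """MAP estimate of (beta0, beta1) under independent normal priors."""
    n = len(x)
    s2 = sigma ** 2
    sx, sy = sum(x), sum(y)
    sxx = sum(xi * xi for xi in x)
    sxy = sum(xi * yi for xi, yi in zip(x, y))

    # Normal equations of the penalised least-squares problem
    a11 = n / s2 + 1 / tau0 ** 2
    a12 = sx / s2
    a22 = sxx / s2 + 1 / tau1 ** 2
    r1 = sy / s2 + mu0 / tau0 ** 2
    r2 = sxy / s2 + mu1 / tau1 ** 2

    # Solve the 2x2 system with Cramer's rule
    det = a11 * a22 - a12 * a12
    beta0 = (r1 * a22 - a12 * r2) / det
    beta1 = (a11 * r2 - a12 * r1) / det
    return beta0, beta1


# Toy data lying close to the line y = 1 + 2x
x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.1, 2.9, 5.2, 6.8, 9.1]

# Very wide priors: the estimate is essentially ordinary least squares
b0_weak, b1_weak = map_estimate(x, y, sigma=1.0, mu0=0.0, tau0=1e3, mu1=0.0, tau1=1e3)
# Tight zero-centred prior on the slope: the slope is shrunk towards zero
b0_tight, b1_tight = map_estimate(x, y, sigma=1.0, mu0=0.0, tau0=1e3, mu1=0.0, tau1=0.01)

print(f"wide prior on slope:  beta1 = {b1_weak:.3f}")
print(f"tight prior on slope: beta1 = {b1_tight:.3f}")
```

With zero prior means this MAP problem has the same form as ridge regression, with the penalty weight on each coefficient given by the ratio \(\sigma^2 / \tau^2\).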

[GVS+20]

Andrew Gelman, Aki Vehtari, Daniel Simpson, Charles C Margossian, Bob Carpenter, Yuling Yao, Lauren Kennedy, Jonah Gabry, Paul-Christian Bürkner, and Martin Modrák. Bayesian workflow. arXiv preprint arXiv:2011.01808, 2020.