Bayesian workflow#
Applying Bayesian inference to real-world problems requires the modeller not only to be proficient in statistics, domain knowledge and programming, but also to have a deep understanding of the decision-making process involved in analysing data. Apart from inference, the workflow includes iterative model building, model checking, validation and troubleshooting of computational problems, model understanding, and model comparison.
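At first glance, Bayes' rule looks very simple. Writing \(\theta\) for the parameters and \(y\) for the data:
\[ p(\theta \mid y) = \frac{p(y \mid \theta)\, p(\theta)}{p(y)} \]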
What could possibly go wrong about it in practice?
A lot can go wrong! And in case things go wrong, decisions need to be made sequentially about model building and improvement. That is why we need the Bayesian workflow.
Principles of Bayesian workflow#
Workflows exist in a variety of disciplines where they define what is a ‘good practice’.
Box’s loop#
In the 1960s, the statistician Box formulated the notion of a loop to understand the nature of the scientific method. This loop is called Box’s loop by Blei et al. (2014):
Modern Bayesian workflow#
The steps of the modern Bayesian workflow are systematically reviewed in [Gelman et al., 2020]:
Prior predictive checks#
Prior predictive checking consists of simulating data from the priors. The simulations are usually visualized, which is especially helpful when transformations of parameters are involved. This shows the range of data compatible with the model and helps assess the adequacy of the chosen priors, since it is often easier to elicit expert knowledge on measurable quantities of interest than on abstract parameter values.
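As a minimal sketch of the idea, the snippet below simulates prior predictive data for a coin-tossing model using only the Python standard library. The Beta(2, 2) prior, the 100 trials and the helper name `simulate_prior_predictive` are illustrative assumptions, not part of any library.

```python
import random

def simulate_prior_predictive(n_trials, alpha, beta, num_sims, seed=0):
    """Simulate numbers of successes implied by a Beta(alpha, beta) prior on p."""
    rng = random.Random(seed)
    sims = []
    for _ in range(num_sims):
        # step 1: draw a parameter value from the prior
        p = rng.betavariate(alpha, beta)
        # step 2: draw data given that parameter (a Binomial as a sum of Bernoullis)
        successes = sum(rng.random() < p for _ in range(n_trials))
        sims.append(successes)
    return sims

sims = simulate_prior_predictive(n_trials=100, alpha=2, beta=2, num_sims=2000)
# the range and mean show which outcomes the prior considers plausible
print(min(sims), max(sims), sum(sims) / len(sims))
```

If the simulated outcomes cover values that a domain expert would rule out, the prior is a candidate for revision before any data are used.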
Iterative model building#
A possible realisation of the Bayesian workflow loop:
Understand the domain and problem,
Formulate the model mathematically,
Implement model, test, debug,
debug, debug, debug.
Perform prior predictive check,
Fit the model,
Assess convergence diagnostics,
Perform posterior predictive check,
Improve the model iteratively: from baseline to complex and computationally efficient models.
Examples#
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.diagnostics import hpdi
import jax
import jax.numpy as jnp
from jax import random
import arviz as az
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
import pandas as pd
Coin tossing#
Data#
n = 100 # number of trials
h = 61 # number of successes
alpha = 2 # hyperparameters
beta = 2
niter = 1000
Model#
def model(n, alpha=2, beta=2, h=None):
    # prior on the probability of success p
    p = numpyro.sample('p', dist.Beta(alpha, beta))
    # likelihood - notice the `obs=h` part
    # p is the probability of success,
    # n is the total number of experiments
    # h is the number of successes
    numpyro.sample('obs', dist.Binomial(n, p), obs=h)
Prior Predictive check#
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
# use the Predictive class to generate predictions.
# Notice that we are not passing observation `h` as data.
# Since we have set `h=None`, this allows the model to make predictions of `h`
# when data for it is not provided.
prior_predictive = Predictive(model, num_samples=1000)
prior_predictions = prior_predictive(rng_key_, n)
# we have generated samples for two variables
prior_predictions.keys()
dict_keys(['obs', 'p'])
# extract samples for variable 'p'
pred_obs = prior_predictions['p']
# compute its summary statistics for the samples of `p`
mean_prior_pred = jnp.mean(pred_obs, axis=0)
hpdi_prior_pred = hpdi(pred_obs, 0.89)
fig = plt.figure(dpi=100, figsize=(5, 3))
plt.hist(pred_obs, bins=15, density=True, alpha=0.5, color='teal')
x = jnp.linspace(0, 1, 3000)
kde = gaussian_kde(pred_obs)
plt.plot(x, kde(x), color='teal', lw=3, alpha=0.5)
plt.xlabel('p')
plt.title('Prior predictive distribution for $p$')
plt.xlim(0, 1)
plt.grid(0.3)
plt.show()
Inference#
Using the same routine as we did for prior predictive, we can perform inference by using the observed data.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
# specify inference algorithm
kernel = NUTS(model)
# define number of samples and number of chains
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, progress_bar=False)
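# run MCMC
# `h` is passed as a float: with some JAX/NumPyro version combinations an integer
# observation triggers "TypeError: Called multiply with a float0 array"
mcmc.run(rng_key_, n=n, h=float(h))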
# extract samples of parameter p
p_samples = mcmc.get_samples()
p_posterior_samples = p_samples['p']
fig = plt.figure(dpi=100, figsize=(5, 3))
plt.hist(pred_obs, bins=15, density=True, alpha=0.5, color='teal', label = "Prior distribution")
plt.hist(p_posterior_samples, bins=15, density=True, alpha=0.5, color='orangered', label = "Posterior distribution")
x = jnp.linspace(0, 1, 3000)
kde = gaussian_kde(pred_obs)
plt.plot(x, kde(x), color='teal', lw=3, alpha=0.5)
kde = gaussian_kde(p_posterior_samples)
plt.plot(x, kde(x), color='orangered', lw=3, alpha=0.5)
plt.xlabel('p')
plt.xlim(0, 1)
plt.grid(0.3)
plt.legend()
plt.show()
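Because the Beta prior is conjugate to the Binomial likelihood, this particular model also has a closed-form posterior, Beta(alpha + h, beta + n - h), which can serve as a sanity check on the MCMC samples. The sketch below is plain Python; the helper names are illustrative and not part of NumPyro.

```python
def beta_binomial_posterior(alpha, beta, n, h):
    """Parameters of the Beta posterior for a Beta(alpha, beta) prior and h successes in n trials."""
    return alpha + h, beta + (n - h)

def beta_mean_sd(a, b):
    """Mean and standard deviation of a Beta(a, b) distribution."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, var ** 0.5

# values from the coin-tossing example: Beta(2, 2) prior, 61 successes in 100 trials
a_post, b_post = beta_binomial_posterior(alpha=2, beta=2, n=100, h=61)
post_mean, post_sd = beta_mean_sd(a_post, b_post)
print(a_post, b_post, round(post_mean, 2), round(post_sd, 2))
# prints: 63 41 0.61 0.05
```

The resulting mean and standard deviation can be compared with the `mean` and `std` columns reported by `mcmc.print_summary()`.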
Check convergence#
We have now obtained samples from MCMC. How can we assess whether we can trust the results? Convergence diagnostics serve this purpose. Beyond \(\hat{R}\), we can also visually inspect traceplots. Traceplots are simply sample values plotted against the iteration number. We want the traceplots to look stationary, i.e. they should resemble a “hairy caterpillar”.
# inspect summary
# print summary and look at R-hat
# r_hat is a diagnostic comparing within-chain variation to between-chain variation.
# It is an important convergence diagnostic, and we want its value to be close to 1
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
p 0.61 0.05 0.61 0.53 0.68 3059.08 1.00
Number of divergences: 0
# plot posterior distribution and traceplots
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
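To make the within-chain versus between-chain comparison concrete, here is a minimal sketch of the classic (non-split) Gelman-Rubin statistic in plain Python. It is an illustration on synthetic chains, not the exact estimator used by NumPyro or ArviZ, which report refined variants; the function name `gelman_rubin_rhat` is hypothetical.

```python
import random
import statistics

def gelman_rubin_rhat(chains):
    """Classic (non-split) R-hat for a list of equal-length chains of scalar draws."""
    num_draws = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    # W: average within-chain variance
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    # B: variance of the chain means, scaled by the chain length
    between = num_draws * statistics.variance(chain_means)
    # pooled estimate of the posterior variance
    var_hat = (num_draws - 1) / num_draws * within + between / num_draws
    return (var_hat / within) ** 0.5

rng = random.Random(0)
# four synthetic chains that all sample the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# the same chains, except that the last one is stuck around a different value
stuck = mixed[:3] + [[draw + 3.0 for draw in mixed[3]]]

rhat_mixed = gelman_rubin_rhat(mixed)
rhat_stuck = gelman_rubin_rhat(stuck)
print(round(rhat_mixed, 3), round(rhat_stuck, 3))
```

Chains that agree give a value close to 1, while a chain exploring a different region inflates the between-chain term and pushes the statistic well above 1.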
Posterior predictive distribution#
The posterior predictive distribution combines the information from the observed data and the posterior distribution of the model parameters to generate predictions for new, unseen data.
We can use the samples obtained at the previous step to generate the posterior predictive distribution of the outcome.
# using the same 'Predictive' class,
# but now specifying also `p_samples`
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model, p_samples)
posterior_predictions = predictive(rng_key_, n=n)
# extract prediction and calculate summary statistics
post_obs = posterior_predictions['obs']
mean_post_pred = jnp.mean(post_obs, axis=0)
hpdi_post_pred = hpdi(post_obs, 0.9)
# what is the mean number of successes?
mean_post_pred
Array(60.508003, dtype=float32)
# what is the uncertainty around this mean?
hpdi_post_pred
array([48, 70], dtype=int32)
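For this conjugate model the posterior predictive distribution of the number of successes is Beta-Binomial, so its mean and spread can be cross-checked analytically. The sketch below assumes the Beta(2, 2) prior and the data of 61 successes in 100 trials used above; the helper name is illustrative.

```python
def beta_binomial_mean_sd(n_trials, a, b):
    """Mean and standard deviation of a Beta-Binomial(n_trials, a, b) distribution."""
    mean = n_trials * a / (a + b)
    var = n_trials * a * b * (a + b + n_trials) / ((a + b) ** 2 * (a + b + 1))
    return mean, var ** 0.5

# conjugate posterior for a Beta(2, 2) prior and 61 successes in 100 trials
a_post, b_post = 2 + 61, 2 + (100 - 61)
pp_mean, pp_sd = beta_binomial_mean_sd(100, a_post, b_post)
print(round(pp_mean, 2), round(pp_sd, 2))
```

The analytic mean should be close to the simulated `mean_post_pred`, and the standard deviation gives a rough sense of the width of the reported interval.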
Group Task
Change the hyperparameters of the model. How do they change the results?
Change the data to n=10, h=6. How does the result change now?
Bayesian Linear regression#
Now that we know how to use NumPyro, let us work with a larger dataset and build a Bayesian linear regression model. It is the same linear regression model you are familiar with, but here all of the parameters are estimated in the Bayesian way.
!wget -O Howell1.csv https://raw.githubusercontent.com/deep-learning-indaba/indaba-pracs-2023/main/data/Howell1.csv
df = pd.read_csv('Howell1.csv', sep=";")
df.head()
| | height | weight | age | male |
|---|---|---|---|---|
| 0 | 151.765 | 47.825606 | 63.0 | 1 |
| 1 | 139.700 | 36.485807 | 63.0 | 0 |
| 2 | 136.525 | 31.864838 | 65.0 | 0 |
| 3 | 156.845 | 53.041914 | 41.0 | 1 |
| 4 | 145.415 | 41.276872 | 51.0 | 0 |
# observed data
weight = df.weight.values
height = df.height.values
Let us define some test data for the variable weight. For these datapoints we will make predictions.
# data to make predictions for
weight_pred = jnp.array([45, 40, 65, 31, 53])
# plot the data
plt.figure(figsize=(6, 4))
plt.scatter(x='weight', y='height', data=df, color='teal')
plt.grid(0.3)
The linear regression model will have the form
\[ y \sim \mathcal{N}(b_0 + b_1 x, \sigma^2). \]
Here \(y\) is the data we want to predict, \(x\) is the predictor, \(b_0\) is the bias (intercept), \(b_1\) is the slope (the regression weight) and \(\sigma^2\) is the variance.
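Read as a generative recipe, this model says: compute the mean \(b_0 + b_1 x\) for each predictor value, then add normal noise with standard deviation \(\sigma\). A minimal sketch in plain Python follows; the parameter values are arbitrary illustrations and the function name is hypothetical.

```python
import random

def simulate_linear_model(x_values, b0, b1, sigma, seed=0):
    """Draw one outcome per predictor value from Normal(b0 + b1 * x, sigma)."""
    rng = random.Random(seed)
    return [rng.gauss(b0 + b1 * x, sigma) for x in x_values]

x_demo = [30.0, 40.0, 50.0, 60.0]
# noisy outcomes scattered around the regression line
y_demo = simulate_linear_model(x_demo, b0=100.0, b1=1.0, sigma=5.0)
# with sigma = 0 the outcomes fall exactly on the line b0 + b1 * x
y_noiseless = simulate_linear_model(x_demo, b0=100.0, b1=1.0, sigma=0.0)
print(y_demo)
print(y_noiseless)
```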
Group task
Discuss which priors would be reasonable for the parameters \(b_0\), \(b_1\), \(\sigma\).
# model
def model(weight=None, height=None):
    # priors
    b0 = numpyro.sample('b0', dist.Normal(120,50))
    b1 = numpyro.sample('b1', dist.Normal(0,1))
    sigma = numpyro.sample('sigma', dist.HalfNormal(10.))
    # deterministic transformation
    mu = b0 + b1 * weight
    # likelihood: notice `obs=height`
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=height)
# prior predictive
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
prior_predictive = Predictive(model, num_samples=100)
prior_predictions = prior_predictive(rng_key_, weight)
prior_predictions.keys()
dict_keys(['b0', 'b1', 'obs', 'sigma'])
pred_obs = prior_predictions['obs']
mean_prior_pred = jnp.mean(pred_obs, axis=0)
hpdi_prior_pred = hpdi(pred_obs, 0.89)
def plot_regression(x, y_mean, y_hpdi, height, ttl='Predictions with 89% CI'):
    # Sort values for plotting by x axis
    idx = jnp.argsort(x)
    weight = x[idx]
    mean = y_mean[idx]
    # named `hpdi_sorted` to avoid shadowing the imported `hpdi` function
    hpdi_sorted = y_hpdi[:, idx]
    ht = height[idx]
    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(weight, mean, color='teal')
    ax.plot(weight, ht, 'o', color='orangered')
    ax.fill_between(weight, hpdi_sorted[0], hpdi_sorted[1], alpha=0.3, interpolate=True, color='teal')
    ax.set(xlabel='weight', ylabel='height', title=ttl)
    return ax
ax = plot_regression(weight, mean_prior_pred, hpdi_prior_pred, height, ttl="Prior predictive")
Group Task
Change the prior for \(b_0\) to dist.Normal(120,50) and visualise the prior predictive distribution. Does it look satisfactory? Do you think we can go ahead with inference using this set of priors?
# Inference
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
# Run NUTS
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, progress_bar=False)
mcmc.run(rng_key_, weight=weight, height=height)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
mean std median 5.0% 95.0% n_eff r_hat
b0 75.49 1.07 75.48 73.72 77.26 2608.93 1.00
b1 1.76 0.03 1.76 1.72 1.81 2692.35 1.00
sigma 9.37 0.29 9.36 8.86 9.82 3778.49 1.00
Number of divergences: 0
# Check convergence
mcmc.print_summary()
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
mean std median 5.0% 95.0% n_eff r_hat
b0 75.49 1.07 75.48 73.72 77.26 2608.93 1.00
b1 1.76 0.03 1.76 1.72 1.81 2692.35 1.00
sigma 9.37 0.29 9.36 8.86 9.82 3778.49 1.00
Number of divergences: 0
# Posterior predictive
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model, samples_1)
posterior_predictions = predictive(rng_key_, weight=weight)
post_obs = posterior_predictions['obs']
mean_post_pred = jnp.mean(post_obs, axis=0)
hpdi_post_pred = hpdi(post_obs, 0.9)
ax = plot_regression(weight, mean_post_pred, hpdi_post_pred, height, ttl="Posterior predictive")
ax.set(xlabel='weight (scaled)', ylabel='height (scaled)');
# predict for new data
predictive = Predictive(model, samples_1)
predictions = predictive(rng_key_, weight=weight_pred)['obs']
mean_pred = jnp.mean(predictions, axis=0)
hpdi_pred = hpdi(predictions, 0.89)
d = {'weight_pred': weight_pred, 'mean_pred': mean_pred, 'lower': hpdi_pred[0,], 'upper': hpdi_pred[1,]}
df_res = pd.DataFrame(data=d)
df_res.head()
| | weight_pred | mean_pred | lower | upper |
|---|---|---|---|---|
| 0 | 45 | 154.772690 | 139.393387 | 169.614548 |
| 1 | 40 | 146.176010 | 131.200302 | 161.135223 |
| 2 | 65 | 190.045502 | 175.474350 | 204.991074 |
| 3 | 31 | 130.221954 | 115.344894 | 145.279709 |
| 4 | 53 | 168.721436 | 154.352890 | 184.339874 |
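A quick plausibility check on these predictions is to plug the rounded posterior means from the summary table (b0 ≈ 75.49, b1 ≈ 1.76) into \(b_0 + b_1 x\). This is only a rough sketch: it ignores posterior uncertainty and uses rounded values, so small differences from `mean_pred` are expected.

```python
# rounded posterior means copied from the MCMC summary above
b0_mean, b1_mean = 75.49, 1.76
weights_new = [45, 40, 65, 31, 53]

# plug-in predictions: evaluate the regression line at the new weights
plug_in_means = [b0_mean + b1_mean * w for w in weights_new]
print([round(m, 2) for m in plug_in_means])
```

The plug-in values land within a fraction of a centimetre of the `mean_pred` column, which is what we would expect when the posterior for the coefficients is narrow.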
Task 18
Modify the model so that it fits better.
Hint: apply a transformation to input data, e.g. a polynomial.
For this model,
plot prior predictive distribution,
perform inference,
plot posterior predictive distribution.
Why use Bayesian regression over frequentist?#
Uncertainty estimation: Bayesian regression provides a natural way to quantify uncertainty in the estimates of the model parameters. This can be useful for making probabilistic predictions and for decision-making under uncertainty. Frequentist methods typically provide point estimates and, in the best case, confidence intervals (CIs). Here we get Bayesian credible intervals (BCIs).
Handling small sample sizes: Bayesian methods can be particularly useful when dealing with small sample sizes, as they allow for more flexible modeling assumptions and can incorporate prior information to help stabilize estimates. Frequentist methods may struggle with small sample sizes, leading to unreliable estimates and inflated uncertainty.
Regularization: Bayesian regression naturally incorporates regularization techniques, such as priors that penalize large parameter values, which can help prevent overfitting and improve the generalization performance of the model. While frequentist methods also offer regularization techniques, such as Ridge regression and Lasso regression, the Bayesian framework provides a principled way to specify and interpret regularization.
Let’s unpack the regularization argument a bit more.
The linear regression model is given by:
\[ y_i = \beta_0 + \beta_1 x_i + \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, \sigma^2), \quad i = 1, \dots, n, \]
where \(y_i\) is the response variable, \(x_i\) is the predictor variable, \(\beta_0\) and \(\beta_1\) are the regression coefficients, and \(\epsilon_i\) are independent and normally distributed errors.
We assume normal priors for \(\beta_0\) and \(\beta_1\):
\[ \beta_0 \sim \mathcal{N}(\mu_0, \tau_0^2), \qquad \beta_1 \sim \mathcal{N}(\mu_1, \tau_1^2). \]
Here, \(\mu_0\), \(\mu_1\), \(\tau_0^2\), and \(\tau_1^2\) are hyperparameters that need to be specified. The choice of these hyperparameters can influence the strength of regularization.
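The likelihood function, assuming normal errors, is:
\[ p(y \mid x, \beta_0, \beta_1, \sigma^2) = \prod_{i=1}^{n} \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(y_i - \beta_0 - \beta_1 x_i)^2}{2\sigma^2}\right). \]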
And the priors:
\[ p(\beta_0) = \frac{1}{\sqrt{2\pi\tau_0^2}} \exp\left(-\frac{(\beta_0 - \mu_0)^2}{2\tau_0^2}\right) \]
\[ p(\beta_1) = \frac{1}{\sqrt{2\pi\tau_1^2}} \exp\left(-\frac{(\beta_1 - \mu_1)^2}{2\tau_1^2}\right) \]
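Applying Bayes’ theorem, we get the unnormalized posterior:
\[ p(\beta_0, \beta_1 \mid y, x) \propto p(y \mid x, \beta_0, \beta_1, \sigma^2)\, p(\beta_0)\, p(\beta_1). \]
Computing the log-posterior, we get
\[ \log p(\beta_0, \beta_1 \mid y, x) = -\frac{1}{2\sigma^2} \sum_{i=1}^{n} (y_i - \beta_0 - \beta_1 x_i)^2 - \frac{(\beta_0 - \mu_0)^2}{2\tau_0^2} - \frac{(\beta_1 - \mu_1)^2}{2\tau_1^2} + \text{const}. \]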
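Here the first term is, up to sign and scaling, the mean squared error (MSE), which usually serves as an objective in frequentist machine learning, and the remaining terms, which come from the priors, provide regularization.
Regularization is controlled through the choice of hyperparameters \(\tau_0^2\) and \(\tau_1^2\). Smaller values of \(\tau_0^2\) and \(\tau_1^2\) result in stronger regularization, pulling the estimated coefficients towards the prior means \(\mu_0\) and \(\mu_1\) (towards zero when the prior means are zero) and simplifying the model. Very large values make the priors nearly flat, so the estimates approach the ordinary least-squares solution.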
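The effect of the prior variance can be checked numerically. For known \(\sigma^2\), setting the gradient of the log-posterior to zero gives a 2×2 linear system for the MAP estimate, which the sketch below solves in plain Python. The toy data, the hyperparameter values and the function name `map_estimate` are made up for illustration.

```python
def map_estimate(x, y, sigma2, mu0, tau0_sq, mu1, tau1_sq):
    """MAP estimate of (beta0, beta1) under independent normal priors and known noise variance."""
    n = len(x)
    sx, sy = sum(x), sum(y)
    sxx = sum(xi * xi for xi in x)
    sxy = sum(xi * yi for xi, yi in zip(x, y))
    # normal equations with the prior precisions added on the diagonal
    a11 = n / sigma2 + 1.0 / tau0_sq
    a12 = sx / sigma2
    a22 = sxx / sigma2 + 1.0 / tau1_sq
    r1 = sy / sigma2 + mu0 / tau0_sq
    r2 = sxy / sigma2 + mu1 / tau1_sq
    # solve the 2x2 system with Cramer's rule
    det = a11 * a22 - a12 * a12
    beta0 = (r1 * a22 - a12 * r2) / det
    beta1 = (a11 * r2 - a12 * r1) / det
    return beta0, beta1

# toy data with a least-squares slope of 2
x_toy = [-2.0, -1.0, 0.0, 1.0, 2.0]
y_toy = [-4.1, -1.9, 0.2, 2.1, 3.9]

# wide prior on the slope: close to the least-squares fit
_, slope_wide = map_estimate(x_toy, y_toy, 1.0, 0.0, 100.0, 0.0, 100.0)
# narrow prior on the slope centred at zero: the slope is shrunk towards zero
_, slope_tight = map_estimate(x_toy, y_toy, 1.0, 0.0, 100.0, 0.0, 0.1)
print(round(slope_wide, 3), round(slope_tight, 3))
```

With the prior mean at zero this is the same estimate that ridge regression would give with a penalty weight of \(\sigma^2 / \tau_1^2\) on the slope.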
Andrew Gelman, Aki Vehtari, Daniel Simpson, Charles C Margossian, Bob Carpenter, Yuling Yao, Lauren Kennedy, Jonah Gabry, Paul-Christian Bürkner, and Martin Modrák. Bayesian workflow. arXiv preprint arXiv:2011.01808, 2020.