Logistic and other regressions#
In this section we will implement several generalized linear models (GLMs).
Task 21

Notice that we are importing one new item from `numpyro.infer` this time: `init_to_median`. Research what it does. What are the available alternatives?
import time
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_median
import matplotlib.pyplot as plt
import arviz as az
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
rng_key = random.PRNGKey(67)
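For orientation, here is a minimal sketch of how an initialization strategy is passed to a NUTS kernel. Besides `init_to_median`, recent versions of `numpyro.infer` also expose `init_to_mean`, `init_to_sample`, `init_to_uniform`, `init_to_feasible` and `init_to_value`; the toy model below is only for illustration.

```python
from numpyro.infer import NUTS, init_to_sample, init_to_uniform, init_to_value

# a throwaway model, just to have something to initialize
def toy_model():
    numpyro.sample("theta", dist.Normal(0, 1))

# start at the median of 10 draws from the prior
kernel = NUTS(toy_model, init_strategy=init_to_median(num_samples=10))
# or start uniformly within (-2, 2) on the unconstrained scale (the NUTS default)
kernel = NUTS(toy_model, init_strategy=init_to_uniform(radius=2))
# or start from user-supplied values
kernel = NUTS(toy_model, init_strategy=init_to_value(values={"theta": 0.0}))
```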
Logistic regression: one dimensional version#
Let us simulate some data for logistic regression and define a logistic regression model using Numpyro. We will need to choose priors for the intercept `alpha` and for the coefficients `beta`. We will then use the NUTS sampler to obtain posterior samples of `alpha` and `beta` from the Bayesian model. Finally, we will print the posterior means of the parameters.
# Generate synthetic data
np.random.seed(42)
X = np.random.randn(100, 1)
true_beta = jnp.array([-2.0])
true_alpha = 0.5
logits = jnp.dot(X, true_beta) + true_alpha
probs = 1.0 / (1.0 + jnp.exp(-logits))
y = np.random.binomial(1, probs)
plt.scatter(X, y);
# Define the logistic regression model
def logistic_regression_model(X, y=None):
    # dimensionality of X, i.e. the number of features
    num_features = X.shape[1]
    # number of data points
num_data = X.shape[0]
# Priors
alpha = numpyro.sample('alpha', dist.Normal(0, 1))
beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))
# precompute logits, i.e. the linear predictor
logits = alpha + jnp.dot(X, beta)
# likelihood. Remember how to use plates?
with numpyro.plate('data', num_data):
numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)
# Define the number of MCMC samples and the number of warmup steps
num_samples = 1000
num_warmup = 500
# Run NUTS sampler
nuts_kernel = NUTS(logistic_regression_model)
mcmc = MCMC(nuts_kernel, num_samples=num_samples, num_warmup=num_warmup, progress_bar=False)
mcmc.run(rng_key, X=X, y=y)
# Get posterior samples
samples = mcmc.get_samples()
# Print posterior statistics
print("Posterior mean of alpha:", jnp.mean(samples['alpha']))
#print("Posterior mean of beta:", jnp.mean(samples['beta'], axis=0))
# mean is not enough
mcmc.print_summary()
# plot posterior distribution and traceplots
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
plt.tight_layout()
Posterior mean of alpha: 0.73502517
mean std median 5.0% 95.0% n_eff r_hat
alpha 0.74 0.26 0.73 0.33 1.17 557.81 1.01
beta[0] -2.09 0.40 -2.07 -2.70 -1.38 558.07 1.00
Number of divergences: 0
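One way to go beyond the summary table is to push the posterior samples back through the model with `numpyro.infer.Predictive`. A minimal sketch, reusing the `samples` obtained above:

```python
from numpyro.infer import Predictive

# draw posterior predictive samples of y at the observed X (y=None, so 'obs' is sampled)
predictive = Predictive(logistic_regression_model, posterior_samples=samples)
post_pred = predictive(random.PRNGKey(1), X=X)
p_hat = post_pred["obs"].mean(axis=0)  # estimated P(y = 1) for each data point

plt.scatter(X, y, label="data")
plt.scatter(X, p_hat, s=10, label="posterior predictive mean")
plt.legend();
```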
Writing a general function for MCMC inference flow#
Note that the Numpyro model we wrote is generic with respect to the dimensionality of `X` (well done us!). However, we have already repeated the same code several times. Let us wrap the inference flow into a function and then apply it to a case with two features and weights.
def run_mcmc(rng_key, # random key
model, # Numpyro model
args, # Dictionary of arguments
verbose=True # boolean for verbose MCMC
):
init_strategy = init_to_median(num_samples=10)
kernel = NUTS(model, init_strategy=init_strategy)
mcmc = MCMC(
kernel,
num_warmup=args["num_warmup"],
num_samples=args["num_mcmc_samples"],
num_chains=args["num_chains"],
thinning=args["thinning"],
progress_bar=False
)
start = time.time()
mcmc.run(rng_key, args)
t_elapsed = time.time() - start
if verbose:
mcmc.print_summary(exclude_deterministic=False)
else:
mcmc.print_summary()
print("\nMCMC elapsed time:", round(t_elapsed), "s")
# plot posterior distribution and traceplots
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True)
plt.tight_layout()
return mcmc, mcmc.get_samples(), t_elapsed
As an input, rather than specifically providing `X` and `y`, we will provide a dictionary `args` with the data, as well as other parameters for MCMC.
Logistic regression: two-dimensional version#
# Define the logistic regression model
def logistic_regression_model(args): # notice the `args`!
X = args["X"]
y = args["y"]
    # dimensionality of X, i.e. the number of features
    num_features = X.shape[1]
    # number of data points
num_data = X.shape[0]
# Priors
alpha = numpyro.sample('alpha', dist.Normal(0, 1))
beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))
# precompute logits, i.e. the linear predictor
logits = alpha + jnp.dot(X, beta)
# likelihood. Remember how to use plates?
with numpyro.plate('data', num_data):
numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)
# Generate synthetic data
np.random.seed(42)
X = np.random.randn(100, 2)
true_beta = jnp.array([1.0, -2.0])
true_alpha = 0.5
logits = jnp.dot(X, true_beta) + true_alpha
probs = 1.0 / (1.0 + jnp.exp(-logits))
y = np.random.binomial(1, probs)
args = {'X': X,
'y':y,
'num_mcmc_samples': 1000,
'num_warmup': 500,
'num_chains': 4,
'thinning': 1,
}
mcmc, samples, t_elapsed = run_mcmc(rng_key, logistic_regression_model, args)
mean std median 5.0% 95.0% n_eff r_hat
alpha 0.65 0.28 0.65 0.17 1.08 3011.16 1.00
beta[0] 0.86 0.31 0.86 0.36 1.37 3146.67 1.00
beta[1] -2.05 0.39 -2.03 -2.67 -1.42 3029.42 1.00
Number of divergences: 0
MCMC elapsed time: 2 s
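The function returns the `MCMC` object, the dictionary of posterior samples, and the elapsed time. With 4 chains of 1000 draws each (and no thinning), each parameter comes back with 4000 posterior draws:

```python
print(samples["alpha"].shape)  # (4000,)
print(samples["beta"].shape)   # (4000, 2)
```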
Poisson regression#
# Generate synthetic data
np.random.seed(42)
X = np.random.randn(1000, 2)
true_beta = jnp.array([0.5, -1.5])
true_alpha = 1.0
true_lambda = jnp.exp(true_alpha + jnp.dot(X, true_beta))
y = np.random.poisson(true_lambda)
# Define the Poisson regression model
def poisson_regression_model(args):
X = args["X"]
y = args["y"]
    # dimensionality of X, i.e. the number of features
    num_features = X.shape[1]
    # number of data points
num_data = X.shape[0]
# priors
alpha = numpyro.sample('alpha', dist.Normal(0, 1))
beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))
# Poisson regression
lambda_ = jnp.exp(alpha + jnp.dot(X, beta))
# likelihood
with numpyro.plate('data', num_data):
numpyro.sample('obs', dist.Poisson(lambda_), obs=y)
args = {'X': X,
'y':y,
'num_mcmc_samples': 1000,
'num_warmup': 500,
'num_chains': 2,
'thinning': 1,
}
mcmc, samples, t_elapsed = run_mcmc(rng_key, poisson_regression_model, args)
mean std median 5.0% 95.0% n_eff r_hat
alpha 0.98 0.02 0.98 0.95 1.02 557.44 1.00
beta[0] 0.49 0.01 0.49 0.47 0.51 1216.08 1.00
beta[1] -1.52 0.01 -1.52 -1.54 -1.50 655.06 1.00
Number of divergences: 0
MCMC elapsed time: 2 s
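As a quick sanity check of the Poisson fit, we can compare the observed counts to the rates implied by the posterior means of the parameters (a rough point summary rather than a full posterior predictive):

```python
alpha_hat = samples["alpha"].mean()
beta_hat = samples["beta"].mean(axis=0)
lambda_hat = jnp.exp(alpha_hat + jnp.dot(X, beta_hat))

plt.scatter(lambda_hat, y, s=8)
plt.xlabel("posterior mean rate")
plt.ylabel("observed count");
```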
Binomial regression#
# Generate synthetic data
np.random.seed(42)
num_samples = 100
X = np.random.randn(num_samples, 2)
true_beta = np.array([1.0, -2.0])
true_alpha = 1.0
logits = true_alpha + X.dot(true_beta)
num_trials = np.random.randint(1, 10, size=num_samples) # Vector of different numbers of trials
y = np.random.binomial(num_trials, p=1 / (1 + np.exp(-logits)))
# Define the binomial regression model
def binomial_regression_model(args):  # reading the inputs from `args` again
    X = args["X"]
    y = args["y"]
    # cast the trial counts to float: with some JAX/NumPyro versions an integer
    # `total_count` can trigger a float0 gradient error (see the note below)
    num_trials = jnp.asarray(args["num_trials"], dtype=jnp.float32)
    # number of data points and dimensionality of X, i.e. the number of features
    num_data, num_features = X.shape
    # Priors for the intercept and the coefficients
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))
    # Linear predictor
    eta = alpha + jnp.dot(X, beta)
    # Likelihood
    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Binomial(total_count=num_trials, logits=eta), obs=y)
args = {'X': X,
'y':y,
'num_trials': num_trials,
'num_mcmc_samples': 1000,
'num_warmup': 500,
'num_chains': 2,
'thinning': 1,
}
mcmc, samples, t_elapsed = run_mcmc(rng_key, binomial_regression_model, args)
Note: with some JAX/NumPyro versions, passing the trial counts as an integer array makes this cell fail during initialization with

`TypeError: Called multiply with a float0 array. ... If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.`

raised from the gradient of `Binomial.log_prob`. Casting `num_trials` to a float array inside the model, as done above, avoids the problem.
Negative Binomial regression#
So far, we have seen only one distribution able to describe counts - the Poisson distribution. While very appealing, this distribution has the drawback that its mean and variance are equal: \(E[y] = \text{Var}[y] = \lambda\). This is not always a good modelling choice!

The negative binomial distribution can serve as an alternative. It is also appropriate for modelling counts, but it allows the mean and variance to differ. More than one parametrisation of the negative binomial distribution exists. The most intuitive one is the \(\mathcal{NegBin2}\) parametrisation, since it connects the mean and variance directly:

\[
y \sim \mathcal{NegBin2}(\mu, \phi), \qquad E[y] = \mu, \qquad \text{Var}[y] = \mu + \frac{\mu^2}{\phi}.
\]

In the limit \(\phi \to \infty\) one recovers the Poisson distribution; \(\frac{\mu^2}{\phi}\) is the additional variance of the negative binomial above that of the Poisson. The parameter \(\frac{1}{\phi}\) is the overdispersion.
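NumPyro exposes this parametrisation as `dist.NegativeBinomial2(mean, concentration)`, with `concentration` playing the role of \(\phi\). A quick numerical check of the moment formulas above:

```python
mu, phi = 5.0, 2.0
nb2 = dist.NegativeBinomial2(mean=mu, concentration=phi)
draws = nb2.sample(random.PRNGKey(0), (100_000,))
print(jnp.mean(draws))  # should be close to mu = 5
print(jnp.var(draws))   # should be close to mu + mu**2 / phi = 17.5
```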
Group Task

Demonstrate negative binomial regression by following these steps:

- simulate a feature matrix `X` of dimensionality (100, 5),
- set true values of the intercept `alpha`, coefficients `beta` and parameter `phi`,
- simulate realisations `y` with these values (one possible simulation is sketched after this list),
- construct a Numpyro model to estimate `alpha`, `beta` and `phi` from the data `X` and `y`,
- fit a Poisson model to the same data. How do the two fits compare?
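One possible way to simulate such data (covering the first three steps, with illustrative true values; the modelling steps are left for you):

```python
np.random.seed(0)
X = np.random.randn(100, 5)
true_alpha = 1.0
true_beta = jnp.array([0.5, -1.0, 0.25, 0.0, 1.0])
true_phi = 3.0

mu = jnp.exp(true_alpha + jnp.dot(X, true_beta))
y = dist.NegativeBinomial2(mean=mu, concentration=true_phi).sample(random.PRNGKey(1))
```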