Logistic and other regressions#

In this section we will implement some of the GLMs: logistic, Poisson, binomial and negative binomial regression.

Task 21

Notice that we are importing one new item from numpyro.infer this time: init_to_median. Research what it does. What are the available alternatives?

import time
import os

import numpy as np

import jax
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_median

import matplotlib.pyplot as plt
import arviz as az

numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)

rng_key = random.PRNGKey(67)
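For orientation, here is a minimal sketch (not part of the original notebook) of how an initialisation strategy is attached to a NUTS kernel; toy_model is a throwaway illustrative model, and the run_mcmc helper defined further below uses init_to_median in exactly the same way.

# a throwaway model used only to illustrate the init_strategy argument
def toy_model():
    numpyro.sample("theta", dist.Normal(0.0, 1.0))

# the chosen init_strategy controls where each chain starts before warmup
toy_kernel = NUTS(toy_model, init_strategy=init_to_median(num_samples=10))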

Logistic regression: one-dimensional version#

Let us simulate some data for logistic regression and define a logistic regression model using Numpyro. We will need to choose priors for the intercept alpha and for the coefficients beta. We will then use the NUTS sampler to obtain posterior samples of alpha and beta from the Bayesian model. Finally, we will print the posterior means of the parameters.

# Generate synthetic data
np.random.seed(42)
X = np.random.randn(100, 1)
true_beta = jnp.array([ -2.0])
true_alpha = 0.5
logits = jnp.dot(X, true_beta) + true_alpha
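# squash the linear predictor through the inverse-logit (sigmoid) to obtain probabilities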
probs = 1.0 / (1.0 + jnp.exp(-logits))
y = np.random.binomial(1, probs)

plt.scatter(X, y)
[Figure: scatter plot of the simulated data, X against y]
# Define the logistic regression model
def logistic_regression_model(X, y=None):

    # dimensionality of X, i.e. the number of features
    num_features = X.shape[1]

    # number of data points
    num_data = X.shape[0]

    # Priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))

    # precompute logits, i.e. the linear predictor
    logits = alpha + jnp.dot(X, beta)

    # likelihood. Remember how to use plates?
    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)
# Define the number of MCMC samples and the number of warmup steps
num_samples = 1000
num_warmup = 500

# Run NUTS sampler
nuts_kernel = NUTS(logistic_regression_model)
mcmc = MCMC(nuts_kernel, num_samples=num_samples, num_warmup=num_warmup, progress_bar=False)

mcmc.run(rng_key, X=X, y=y)
# Get posterior samples
samples = mcmc.get_samples()

# Print posterior statistics
print("Posterior mean of alpha:", jnp.mean(samples['alpha']))
#print("Posterior mean of beta:", jnp.mean(samples['beta'], axis=0))

# mean is not enough
mcmc.print_summary()

# plot posterior distribution and traceplots
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
plt.tight_layout()
Posterior mean of alpha: 0.73502517

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha      0.74      0.26      0.73      0.33      1.17    557.81      1.01
   beta[0]     -2.09      0.40     -2.07     -2.70     -1.38    558.07      1.00

Number of divergences: 0
[Figure: trace plots and posterior densities for alpha and beta]
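As a quick visual check (a sketch that is not part of the original notebook; the names alpha_hat, beta_hat, x_grid and p_grid are ours), we could overlay the logistic curve implied by the posterior means on the simulated data:

alpha_hat = jnp.mean(samples['alpha'])
beta_hat = jnp.mean(samples['beta'][:, 0])

# evaluate the posterior-mean logistic curve on a grid of x values
x_grid = jnp.linspace(X.min(), X.max(), 200)
p_grid = 1.0 / (1.0 + jnp.exp(-(alpha_hat + beta_hat * x_grid)))

plt.scatter(X, y, alpha=0.5)
plt.plot(x_grid, p_grid, color="black")
plt.xlabel("x")
plt.ylabel("P(y = 1)");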

Writing a general function for MCMC inference flow#

Note that the Numpyro model we wrote is generic with respect to the dimensionality of X (well done us!).

However, we have already repeated the same code several times. Let us wrap the inference flow into a function, and then apply it to a case with two features and weights.

def run_mcmc(rng_key,       # random key
             model,         # Numpyro model
             args,          # Dictionary of arguments
             verbose=True   # boolean for verbose MCMC
            ):
    
    init_strategy = init_to_median(num_samples=10)
    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        kernel,
        num_warmup=args["num_warmup"],
        num_samples=args["num_mcmc_samples"],
        num_chains=args["num_chains"],
        thinning=args["thinning"],
        progress_bar=False
    )
    start = time.time()
    mcmc.run(rng_key, args)
    t_elapsed = time.time() - start
    if verbose:
        mcmc.print_summary(exclude_deterministic=False)
    else:
        mcmc.print_summary()

    print("\nMCMC elapsed time:", round(t_elapsed), "s")

    # plot posterior distribution and traceplots
    data = az.from_numpyro(mcmc)
    az.plot_trace(data, compact=True)
    plt.tight_layout()


    return mcmc, mcmc.get_samples(), t_elapsed

As input, rather than providing X and y specifically, we will provide a dictionary args containing the data as well as the other parameters for MCMC.

Logistic regression: two-dimensional version#

# Define the logistic regression model
def logistic_regression_model(args): # notice the `args`!

    X = args["X"]
    y = args["y"]

    # dimensionality of X, i.e. the number of features
    num_features = X.shape[1]

    # number of data points
    num_data = X.shape[0]

    # Priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))

    # precompute logits, i.e. the linear predictor
    logits = alpha + jnp.dot(X, beta)

    # likelihood. Remember how to use plates?
    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)
# Generate synthetic data
np.random.seed(42)
X = np.random.randn(100, 2)
true_beta = jnp.array([1.0, -2.0])
true_alpha = 0.5
logits = jnp.dot(X, true_beta) + true_alpha
probs = 1.0 / (1.0 + jnp.exp(-logits))
y = np.random.binomial(1, probs)
args = {'X': X, 
        'y':y,
        'num_mcmc_samples': 1000,
        'num_warmup': 500,
        'num_chains': 4, 
        'thinning': 1,
}

run_mcmc(rng_key, logistic_regression_model, args)
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha      0.65      0.28      0.65      0.17      1.08   3011.16      1.00
   beta[0]      0.86      0.31      0.86      0.36      1.37   3146.67      1.00
   beta[1]     -2.05      0.39     -2.03     -2.67     -1.42   3029.42      1.00

Number of divergences: 0

MCMC elapsed time: 2 s
(<numpyro.infer.mcmc.MCMC at 0x7f3bdbe15fd0>,
 {'alpha': Array([1.1911314 , 0.48238117, 1.0410734 , ..., 0.875891  , 0.76256436,
         0.742151  ], dtype=float32),
  'beta': Array([[ 1.004704  , -2.6107154 ],
         [ 0.34020698, -1.7390003 ],
         [ 1.6303924 , -2.870154  ],
         ...,
         [ 0.5478067 , -1.8294735 ],
         [ 0.57123476, -1.874356  ],
         [ 0.9720275 , -2.3105533 ]], dtype=float32)},
 1.8369851112365723)
[Figure: trace plots and posterior densities for alpha and beta (two-feature model)]
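To work with the returned posterior samples, one could also capture the outputs of run_mcmc; the following is a minimal sketch (not part of the original notebook; mcmc_2d, samples_2d, post_logits and post_probs are illustrative names) computing posterior-mean predicted probabilities at the observed inputs:

mcmc_2d, samples_2d, _ = run_mcmc(rng_key, logistic_regression_model, args)

# posterior draws of the linear predictor for every data point: shape (num_draws, num_data)
post_logits = samples_2d['alpha'][:, None] + samples_2d['beta'] @ args["X"].T

# average the predicted probabilities over the posterior draws
post_probs = jax.nn.sigmoid(post_logits).mean(axis=0)
print(post_probs[:5])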

Poisson regression#

# Generate synthetic data
np.random.seed(42)
X = np.random.randn(1000, 2)
true_beta = jnp.array([0.5, -1.5])
true_alpha = 1.0

# Poisson rate via the log link: lambda = exp(alpha + X @ beta)
true_lambda = jnp.exp(true_alpha + jnp.dot(X, true_beta))
y = np.random.poisson(true_lambda)
# Define the Poisson regression model
def poisson_regression_model(args):

    X = args["X"]
    y = args["y"]

    # dimensionality of X, i.e. the number of features
    num_features = X.shape[1]

    # number of data points
    num_data = X.shape[0]

    # priors
    alpha = numpyro.sample('alpha', dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))
    
    # Poisson regression
    lambda_ = jnp.exp(alpha + jnp.dot(X, beta))
    
    # likelihood
    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Poisson(lambda_), obs=y)
args = {'X': X, 
        'y':y,
        'num_mcmc_samples': 1000,
        'num_warmup': 500,
        'num_chains': 2, 
        'thinning': 1,
}

run_mcmc(rng_key, poisson_regression_model, args)
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha      0.98      0.02      0.98      0.95      1.02    557.44      1.00
   beta[0]      0.49      0.01      0.49      0.47      0.51   1216.08      1.00
   beta[1]     -1.52      0.01     -1.52     -1.54     -1.50    655.06      1.00

Number of divergences: 0

MCMC elapsed time: 2 s
(<numpyro.infer.mcmc.MCMC at 0x7f3bdb82fb50>,
 {'alpha': Array([1.002402  , 0.9481383 , 0.9553949 , ..., 0.9836497 , 0.9768003 ,
         0.97007227], dtype=float32),
  'beta': Array([[ 0.49739638, -1.5155102 ],
         [ 0.49268013, -1.5444417 ],
         [ 0.49289092, -1.5341926 ],
         ...,
         [ 0.517036  , -1.5129181 ],
         [ 0.49180886, -1.5197647 ],
         [ 0.48897225, -1.5250332 ]], dtype=float32)},
 1.7879128456115723)
[Figure: trace plots and posterior densities for the Poisson regression parameters]

Binomial regression#

# Generate synthetic data
np.random.seed(42)
num_samples = 100
X = np.random.randn(num_samples, 2)
true_beta = np.array([1.0, -2.0])
true_alpha = 1.0
logits = true_alpha + X.dot(true_beta)
num_trials = np.random.randint(1, 10, size=num_samples)  # Vector of different numbers of trials
y = np.random.binomial(num_trials, p=1 / (1 + np.exp(-logits)))
# Define the binomial regression model
def binomial_regression_model(args):

    X = args["X"]
    y = args["y"]
    num_trials = args["num_trials"]

    # number of data points and number of features
    num_data, num_features = X.shape

    # Priors for the intercept and coefficients
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)))

    # Linear predictor
    eta = alpha + jnp.dot(X, beta)

    # Likelihood
    with numpyro.plate('data', num_data):
        numpyro.sample('obs', dist.Binomial(total_count=num_trials, logits=eta), obs=y)
args = {'X': X, 
        'y':y,
        'num_trials': num_trials,
        'num_mcmc_samples': 1000,
        'num_warmup': 500,
        'num_chains': 2, 
        'thinning': 1,
}

run_mcmc(rng_key, binomial_regression_model, args)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[15], line 10
---> 10 run_mcmc(rng_key, binomial_regression_model, args)

Cell In[6], line 18, in run_mcmc(rng_key, model, args, verbose)
---> 18 mcmc.run(rng_key, args)

    [... long traceback through numpyro and jax internals elided; the failure occurs in BinomialLogits.log_prob via jax.scipy.special.xlog1py ...]

TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.

Negative Binomial regression#

So far, we have seen only one distribution for describing counts - the Poisson distribution. While very appealing, it has the drawback that its mean and variance are equal: \(E[y] = \text{Var}[y] = \lambda\). This is not always a good modelling choice!

The negative binomial distribution can serve as an alternative. It is also appropriate for modelling counts; however, it allows the mean and variance to differ. More than one parametrisation of the negative binomial distribution exists. The most intuitive one is the \(\mathcal{NegBin2}\) parametrisation, since it connects the mean and variance directly:

\[\begin{split} \begin{align*} E[y] &= \mu,\\ \text{Var}[y] &= \mu + \frac{\mu^2}{\phi}. \end{align*} \end{split}\]

In the limit \(\phi \to \infty\) one recovers the Poisson distribution; \(\frac{\mu^2}{\phi}\) is the additional variance of the negative binomial above that of the Poisson. The parameter \(\frac{1}{\phi}\) is called the overdispersion.
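As a quick illustration of this mean-variance relationship (a sketch that is not part of the original notebook), Numpyro's NegativeBinomial2 distribution is parametrised directly by the mean and the concentration \(\phi\), so we can check the formula empirically; the values of mu and phi below are arbitrary:

mu, phi = 5.0, 2.0
nb2 = dist.NegativeBinomial2(mean=mu, concentration=phi)

# draw many samples and compare empirical moments with mu and mu + mu**2 / phi
draws = nb2.sample(random.PRNGKey(0), (100_000,))
print("empirical mean:    ", jnp.mean(draws))   # should be close to mu = 5.0
print("empirical variance:", jnp.var(draws))    # should be close to mu + mu**2 / phi = 17.5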

Group Task

Demonstrate negative binomial regression by following these steps:

  • simulate a feature matrix X of dimensionality (100, 5)

  • set true values of the intercept alpha, coefficients beta and parameter phi

  • simulate realisations y with these values

  • construct a Numpyro model to estimate alpha, beta and phi from data X and y

  • fit a Poisson model to the same data. How do the two fits compare?