NumPyro Documentation
Uber AI Labs
Sep 26, 2021

CONTENTS

1 Getting Started with NumPyro
2 Pyro Primitives
3 Distributions
4 Inference
5 Effect Handlers
6 Contributed Code
7 Bayesian Regression Using NumPyro
8 Bayesian Hierarchical Linear Regression
9 Example: Baseball Batting Average
10 Example: Variational Autoencoder
11 Example: Neal’s Funnel
12 Example: Stochastic Volatility
13 Example: ProdLDA with Flax and Haiku
14 Automatic rendering of NumPyro models
15 Bad posterior geometry and how to deal with it
16 Example: Bayesian Models of Annotation
17 Example: Enumerate Hidden Markov Model
18 Example: CJS Capture-Recapture Model for Ecological Data
19 Example: Nested Sampling for Gaussian Shells
20 Bayesian Imputation for Missing Values in Discrete Covariates
21 Time Series Forecasting
22 Ordinal Regression
23 Bayesian Imputation
24 Example: Gaussian Process
25 Example: Bayesian Neural Network
26 Example: AutoDAIS
27 Example: Sparse Regression
28 Example: Horseshoe Regression
29 Example: Proportion Test
30 Example: Generalized Linear Mixed Models
31 Example: Hamiltonian Monte Carlo with Energy Conserving Subsampling
32 Example: Hidden Markov Model
33 Example: Hilbert space approximation for Gaussian processes
34 Example: Predator-Prey Model
35 Example: Neural Transport
36 Example: MCMC Methods for Tall Data
37 Example: Thompson sampling for Bayesian Optimization with GPs
38 Indices and tables
Python Module Index
Index

CHAPTER ONE

GETTING STARTED WITH NUMPYRO

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.

Docs and Examples | Forum

1.1 What is NumPyro?

NumPyro is a small probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.

NumPyro is designed to be lightweight and focuses on providing a flexible substrate that users can build on:

• Pyro Primitives: NumPyro programs can contain regular Python and NumPy code, in addition to Pyro primitives like sample and param. The model code should look very similar to Pyro except for some minor differences between PyTorch and NumPy's API. See the example below.

• Inference algorithms: NumPyro currently supports Hamiltonian Monte Carlo, including an implementation of the No-U-Turn Sampler (NUTS). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the Verlet integrator, which includes multiple gradient computations. With JAX, we can compose jit and grad to compile the entire integration step into an XLA-optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree-building stage in NUTS (this is possible with Iterative NUTS). There is also a basic Variational Inference implementation for reparameterized distributions, together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI).

• Distributions: The numpyro.distributions module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's functional pseudo-random number generator. The design of the distributions module largely follows from PyTorch. A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in torch.distributions. In addition to distributions, constraints and transforms are very useful when operating on distribution classes with bounded support.

• Effect handlers: Like Pyro, primitives like sample and param can be given nonstandard interpretations using effect handlers from the numpyro.handlers module, and these can be easily extended to implement custom inference algorithms and inference utilities. See the short sketch after this list.
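For instance, composing the seed and trace handlers runs a model with a fixed PRNG key and records every sample site. The snippet below is a minimal sketch rather than part of the tutorial proper: the two-site toy model and its site names are made up for illustration, while handlers.seed and handlers.trace are the actual NumPyro handlers.

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers

>>> def toy_model():
...     x = numpyro.sample('x', dist.Normal(0., 1.))
...     numpyro.sample('y', dist.Normal(x, 1.))

>>> # seed binds a PRNG key to the sample statements; trace records each site
>>> exec_trace = handlers.trace(handlers.seed(toy_model, rng_seed=0)).get_trace()
>>> list(exec_trace.keys())
['x', 'y']

Each entry of the resulting trace stores the site's distribution and sampled value, which is the basic hook used to build custom inference utilities.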
1.2 A Simple Example - 8 Schools

Let us explore NumPyro using a simple example. We will use the eight schools example from Gelman et al., Bayesian Data Analysis: Sec. 5.5, 2003, which studies the effect of coaching on SAT performance in eight schools.

The data is given by:

>>> import numpy as np

>>> J = 8
>>> y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
>>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

where y are the treatment effects and sigma the standard errors.

We build a hierarchical model for the study, where we assume that the group-level parameters theta for each school are sampled from a Normal distribution with unknown mean mu and standard deviation tau, while the observed data are in turn generated from a Normal distribution with mean and standard deviation given by theta (true effect) and sigma, respectively. This allows us to estimate the population-level parameters mu and tau by pooling from all the observations, while still allowing for individual variation amongst the schools using the group-level theta parameters.

>>> import numpyro
>>> import numpyro.distributions as dist

>>> # Eight Schools example
... def eight_schools(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         theta = numpyro.sample('theta', dist.Normal(mu, tau))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler (NUTS). Note the usage of the extra_fields argument in MCMC.run. By default, we only collect samples from the target (posterior) distribution when we run inference using MCMC. However, collecting additional fields like potential energy or the acceptance probability of a sample can easily be achieved by using the extra_fields argument. For a list of possible fields that can be collected, see the HMCState object. In this example, we will additionally collect the potential_energy for each sample.

>>> from jax import random
>>> from numpyro.infer import MCMC, NUTS

>>> nuts_kernel = NUTS(eight_schools)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
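Note that mcmc.run does not return the draws; they are stored on the MCMC object. As a quick aside (a sketch, not part of the original tutorial), MCMC.get_samples returns them as a dictionary mapping each sample-site name to an array of posterior draws, which is what downstream analysis or plotting would consume:

>>> # posterior draws, keyed by sample-site name
>>> samples = mcmc.get_samples()
>>> samples['mu'].shape, samples['theta'].shape
((1000,), (1000, 8))
>>> posterior_mu = np.mean(samples['mu'])  # Monte Carlo estimate of the posterior mean of mu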
We can print the summary of the MCMC run and examine whether we observed any divergences during inference. Additionally, since we collected the potential energy for each of the samples, we can easily compute the expected log joint density.

>>> mcmc.print_summary()

                mean    std  median    5.0%   95.0%   n_eff  r_hat
      mu        4.14   3.18    3.87   -0.76    9.50  115.42   1.01
     tau        4.12   3.58    3.12    0.51    8.56   90.64   1.02
theta[0]        6.40   6.22    5.36   -2.54   15.27  176.75   1.00
theta[1]        4.96   5.04    4.49   -1.98   14.22  217.12   1.00
theta[2]        3.65   5.41    3.31   -3.47   13.77  247.64   1.00
theta[3]        4.47   5.29    4.00   -3.22   12.92  213.36   1.01
theta[4]        3.22   4.61    3.28   -3.72   10.93  242.14   1.01
theta[5]        3.89   4.99    3.71   -3.39   12.54  206.27   1.00
theta[6]        6.55   5.72    5.66   -1.43   15.78  124.57   1.00
theta[7]        4.81   5.95    4.19   -3.90   13.40  299.66   1.00

Number of divergences: 19

>>> pe = mcmc.get_extra_fields()['potential_energy']
>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))
Expected log joint density: -54.55

The values above 1 for the split Gelman-Rubin diagnostic (r_hat) indicate that the chain has not fully converged. The low values for the effective sample size (n_eff), particularly for tau, and the number of divergent transitions look problematic. Fortunately, this is a common pathology that can be rectified by using a non-centered parameterization for tau in our model. This is straightforward to do in NumPyro by using a TransformedDistribution instance together with a reparameterization effect handler. Let us rewrite the same model, but instead of sampling theta from a Normal(mu, tau), we will sample it from a base Normal(0, 1) distribution that is transformed using an AffineTransform. Note that by doing so, NumPyro runs HMC by generating samples theta_base for the base Normal(0, 1) distribution instead. We see that the resulting chain does not suffer from the same pathology: the Gelman-Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good!

>>> from numpyro.infer.reparam import TransformReparam

>>> # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         with numpyro.handlers.reparam(config={'theta': TransformReparam()}):
...             theta = numpyro.sample(
...                 'theta',
...                 dist.TransformedDistribution(dist.Normal(0., 1.),
...                                              dist.transforms.AffineTransform(mu, tau)))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

>>> nuts_kernel = NUTS(eight_schools_noncentered)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
>>> mcmc.print_summary(exclude_deterministic=False)
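Spelling out the TransformedDistribution is instructive, but numpyro.infer.reparam also provides LocScaleReparam, which performs the same decentering of a location-scale distribution without rewriting the sample statement. The sketch below is an assumed equivalent of the model above: the function name is ours for illustration, and it assumes LocScaleReparam(centered=0) yields the fully non-centered form, as in recent NumPyro releases.

>>> from numpyro.infer.reparam import LocScaleReparam

>>> # hypothetical rewrite of the non-centered model using LocScaleReparam
... def eight_schools_decentered(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         # centered=0 requests the fully non-centered parameterization: HMC
...         # samples an auxiliary Normal(0, 1) site and theta is recovered
...         # deterministically as mu + tau * theta_decentered
...         with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
...             theta = numpyro.sample('theta', dist.Normal(mu, tau))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)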