Pyro on NumPy/JAX with JIT compilation (deep probabilistic programming)
v0.2.4
New Features
- NumPyro can be used on Cloud TPUs.
- Deterministic primitive to record deterministic values in a model.
- Mask handler to mask out the log probability of a sample site using a mask array.
- Sample Adaptive MCMC, a non-gradient based sampler that has a high effective sample size per second.
- New normalizing flow: Block neural autoregressive transform.
- Additional auto guides: AutoLowRankMultivariateNormal and AutoBNAFNormal.
New Examples
- Predator-prey model example: uses MCMC to solve the inverse problem of an ODE system.
- Neural transport example: uses a normalizing flow to transform the posterior to a Gaussian-like one, thereby improving mixing rate for HMC/NUTS.
Deprecation / Breaking Changes
- Predictive's get_samples method is deprecated in favor of the __call__ method.
- MCMC's constrain_fn is renamed to postprocess_fn.
Enhancements and Bug Fixes
- Change the init scale of Auto*Normal guides from 1. to 0.1 - this is helpful for stability during the early training phase.
- Resolve overflow issue with the Poisson sampler.
v0.2.3
Patches 0.2.2 with the following changes:
- Restore compatibility with Python 3.7 for mcmc.
- Impose a cache size limit in MCMC utilities.
v0.2.2
Breaking changes
- Minor interface changes to MCMC utility functions. All experimental interfaces are marked as such in the documentation.
New Features
- A numpyro.factor primitive that adds an arbitrary log probability factor to a probabilistic model.
Enhancements and Bug Fixes
- Addressed a bug where multiple invocations of MCMC.run would wrongly use the previously cached arguments.
- MCMC reuses compiled model code whenever possible, e.g. when re-running with different but identically sized model arguments.
- Ability to reuse the same warmup state for subsequent MCMC runs using MCMC.warmup.
v0.2.1
neerajprad released this
Breaking changes
- Code reorganization: numpyro.mcmc is moved to numpyro.infer.mcmc, but all major classes are exposed in the numpyro.infer module.
- The rng argument to many classes and the seed handler has been more accurately renamed to rng_key.
- Deprecated functions that formed the old interface, such as mcmc and svi, have been removed.
New Features
- Improved turning condition for NUTS that results in much higher effective sample size for many models.
- A numpyro.plate context manager, which records conditional independence information in the trace and does automatic broadcasting, like in Pyro.
- Inclusion of AutoMultivariateNormal, AutoLaplaceApproximation to the autoguide module.
- More distributions like LowRankMultivariateNormal, LKJ, BetaBinomial, GammaPoisson, ZeroInflatedPoisson, and OrderedLogistic.
- More transforms: MultivariateAffineTransform, InvCholeskyTransform, OrderedTransform.
- A numpyro.compat module that supports the Pyro generic API for modeling and inference, which can dispatch to multiple Pyro backends.
- Inclusion of the Independent distribution and a Distribution.to_event method to convert independent batch dimensions to dependent event dimensions. See the Pyro tutorial on tensor shapes for more details.
- A Predictive utility for generating samples from prior models, predictions from models using SVI's guide, or posterior samples from MCMC.
- A log_likelihood utility function that can compute the log likelihood for observed data by conditioning latent sites to values from the posterior distribution.
- New ClippedAdam optimizer to prevent exploding gradients.
- New RenyiELBO loss for Renyi divergence variational inference and importance weighted variational inference.
Enhancements and Bug Fixes
- MCMC does not throw an error on models with no latent sites.
- numpyro.seed handler can be used as a context manager like:
with numpyro.seed(rng_seed=1): ...
- Utilities to enable validation checks for distributions, to set the host device count, and to set the platform.
- More efficient sampling from Binomial / Multinomial distributions.
- The evidence lower bound loss for SVI is now a class called ELBO.
- Add an energy field to HMCState, which is used to compute the Bayesian Fraction of Missing Information for diagnostics.
- Add an init_strategy arg to HMC/NUTS classes, which allows users to select various initialization strategies.
v0.2.0
Highlights
- Interface Changes to MCMC and SVI: The interface for inference algorithms has been simplified and is much closer to Pyro's. See MCMC and SVI.
- Multi-chain Sampling for MCMC: There are three options: parallel (default), sequential, and vectorized. Currently, the parallel method is the fastest of the three.
Breaking changes
- The primitives param and sample are moved to the primitives module. All primitives are exposed in the numpyro namespace.
New Features
MCMC
- In MCMC, there is now an option to collect fields other than the samples, such as the number of steps or the step size, using the collect_fields arg in MCMC.run. This can be useful when gathering diagnostic information during debugging.
- A diverging field is added to HMCState. This field is useful to detect divergent transitions.
- Support for improper priors through the param primitive, e.g.
```python
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints

def model(data):
    loc = numpyro.param('loc', 0.)
    scale = numpyro.param('scale', 0.5, constraint=constraints.positive)
    return numpyro.sample('obs', dist.Normal(loc, scale), obs=data)
```
Primitives / Effect Handlers
- A module primitive to support JAX-style neural networks. See the VAE example.
- condition handler for conditioning sample sites to observed data.
- scale handler for rescaling the log probability score.
Optimizers
JAX optimizers are wrapped in the numpyro.optim module, so that they can be passed directly to SVI.
Distributions
- New distributions: Delta, GaussianRandomWalk, InverseGamma, LKJCholesky (with both cvine and onion methods for sampling), MultivariateNormal.
- New transforms: CorrCholeskyTransform (which is vectorized), InverseAutoregressiveTransform, LowerCholeskyTransform, PermuteTransform, PowerTransform.
Utilities
- predictive utility for vectorized predictions from the posterior predictive distribution.
Autoguides
An experimental autoguide module, with more autoguides to come.
New Examples
- Sparse Linear Regression - fast Bayesian discovery of pairwise interactions in high dimensional data.
- Gaussian Process - sample from the posterior over the hyperparameters of a Gaussian process.
- HMC on Neal's Funnel - automatic reparameterization through transform distributions.
Enhancements and Bug Fixes
- Improve compiling time in MCMC.
- Better PRNG splitting mechanism in SVI (to avoid reusing PRNG keys).
- Correctly handle models with dynamically changing distribution constraints, e.g.

```python
import numpyro
import numpyro.distributions as dist

def model():
    x = numpyro.sample('x', dist.Uniform(0., 2.))
    y = numpyro.sample('y', dist.Uniform(0., x))  # y's support is not static
```
- Fixes step_size becoming NaN in MCMC when it gets extremely small.
v0.1.0
Refer to the README for details.
fehiepsi released this on Jul 27, 2020
Breaking Changes
- … find_heuristic_step_size=True.
- … the reparam handler. See the eight schools example for the new usage pattern.
- numpyro.contrib.autoguide is moved to the main inference module numpyro.infer.autoguide.
- The mask_array arg is renamed to mask.
- The scale_factor arg is renamed to scale.
- param_map is renamed to data.
- The MultivariateAffineTransform transform is renamed to LowerCholeskyAffine.
- The init_to_prior strategy is renamed to init_to_sample.
New Features
- … for loop, consider using scan instead to improve compiling time.
- New Distribution properties is_discrete and has_enumerate_support, and new methods shape, enumerate_support, expand, expand_by, mask. In addition, Distribution has been registered as a JAX Pytree class, with corresponding methods tree_flatten and tree_unflatten.
- A batch_ndims arg to Predictive and log_likelihood to allow using those utilities with an arbitrary number of batch dimensions.
New Examples
- …
Enhancements and Bug Fixes
- … plate statements. #555
- … the scale handler and the deterministic primitive. #577
- … numpyro.optim classes. #603
- … numpyro.plate. #616
- … the scan primitive and the usage of Predictive for forecasting. #608 #657

Thanks Nikolaos @daydreamt, Daniel Sheldon @dsheldon, Lukas Prediger @lumip, Freddy Boulton @freddyaboulton, Wouter van Amsterdam @vanAmsterdam, and many others for their contributions and helpful feedback!