From 9017ab2881218e09b1f96799693b18349d780b54 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Wed, 28 Feb 2024 15:32:11 +0100 Subject: [PATCH] Improve documentation (#26) --- README.md | 6 +++++- docs/index.rst | 7 ++++++- docs/sbijax.nn.rst | 3 +++ docs/sbijax.rst | 7 +++++++ ...an_fmpe.py => bivariate_gaussian_sfmpe.py} | 4 ++-- sbijax/__init__.py | 4 ++-- sbijax/_src/fmpe.py | 21 ++++++++++++------- sbijax/_src/snp.py | 4 ++-- 8 files changed, 40 insertions(+), 16 deletions(-) rename examples/{bivariate_gaussian_fmpe.py => bivariate_gaussian_sfmpe.py} (96%) diff --git a/README.md b/README.md index e72514a..7aac6bb 100644 --- a/README.md +++ b/README.md @@ -12,16 +12,20 @@ [JAX](https://github.com/google/jax) using [Haiku](https://github.com/deepmind/dm-haiku), [Distrax](https://github.com/deepmind/distrax) and [BlackJAX](https://github.com/blackjax-devs/blackjax). Specifically, `sbijax` implements -- [Sequential Monte Carlo ABC](https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728) (`SMCABC`), +- [Sequential Monte Carlo ABC](https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728) (`SMCABC`) - [Neural Likelihood Estimation](https://arxiv.org/abs/1805.07226) (`SNL`) - [Surjective Neural Likelihood Estimation](https://arxiv.org/abs/2308.01054) (`SSNL`) - [Neural Posterior Estimation C](https://arxiv.org/abs/1905.07488) (short `SNP`) - [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`) - [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`) - [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`) +- [Flow matching posterior estimation](https://openreview.net/forum?id=jjzJ768iV1) (`SFMPE`) where the acronyms in parentheses denote the names of the methods in `sbijax`. +> [!CAUTION] +> ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them. + ## Examples You can find several self-contained examples on how to use the algorithms in [examples](https://github.com/dirmeier/sbijax/tree/main/examples). diff --git a/docs/index.rst b/docs/index.rst index e80c3c1..4f3b482 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,13 +13,18 @@ `JAX `_ using `Haiku `_, `Distrax `_ and `BlackJAX `_. Specifically, :code:`sbijax` implements -- `Sequential Monte Carlo ABC `_ (:code:`SMCABC`), +- `Sequential Monte Carlo ABC `_ (:code:`SMCABC`) - `Neural Likelihood Estimation `_ (:code:`SNL`) - `Surjective Neural Likelihood Estimation `_ (:code:`SSNL`) - `Neural Posterior Estimation C `_ (short :code:`SNP`) - `Contrastive Neural Ratio Estimation `_ (short :code:`SNR`) - `Neural Approximate Sufficient Statistics `_ (:code:`SNASS`) - `Neural Approximate Slice Sufficient Statistics `_ (:code:`SNASSS`) +- `Flow matching posterior estimation `_ (:code:`SFMPE`) + +.. caution:: + + ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them. Installation ------------ diff --git a/docs/sbijax.nn.rst b/docs/sbijax.nn.rst index 362ec1d..8717d51 100644 --- a/docs/sbijax.nn.rst +++ b/docs/sbijax.nn.rst @@ -12,6 +12,7 @@ networks and normalizing flows. make_affine_maf make_surjective_affine_maf make_resnet + make_ccnf make_snass_net make_snasss_net @@ -22,6 +23,8 @@ networks and normalizing flows. .. autofunction:: make_resnet +.. autofunction:: make_ccnf + .. autofunction:: make_snass_net .. autofunction:: make_snasss_net diff --git a/docs/sbijax.rst b/docs/sbijax.rst index f6d07ae..332fa0c 100644 --- a/docs/sbijax.rst +++ b/docs/sbijax.rst @@ -15,6 +15,7 @@ Methods SNL SNP SNR + SFMPE SNASS SNASSS @@ -42,6 +43,12 @@ SNR .. autoclass:: SNR :members: fit, simulate_data_and_possibly_append, sample_posterior +SFMPE +~~~~~ + +.. autoclass:: SFMPE + :members: fit, simulate_data_and_possibly_append, sample_posterior + SNASS ~~~~~ diff --git a/examples/bivariate_gaussian_fmpe.py b/examples/bivariate_gaussian_sfmpe.py similarity index 96% rename from examples/bivariate_gaussian_fmpe.py rename to examples/bivariate_gaussian_sfmpe.py index 668b239..bdd4f7b 100644 --- a/examples/bivariate_gaussian_fmpe.py +++ b/examples/bivariate_gaussian_sfmpe.py @@ -10,7 +10,7 @@ from jax import numpy as jnp from jax import random as jr -from sbijax import FMPE +from sbijax import SFMPE from sbijax.nn import CCNF @@ -45,7 +45,7 @@ def run(): prior_simulator_fn, prior_logdensity_fn = prior_model_fns() fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn - estim = FMPE(fns, make_model(2)) + estim = SFMPE(fns, make_model(2)) optimizer = optax.adam(1e-3) data, params = None, {} diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 833dd2f..646b951 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,11 +2,11 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.8" +__version__ = "0.1.9" from sbijax._src.abc.smc_abc import SMCABC -from sbijax._src.fmpe import FMPE +from sbijax._src.fmpe import SFMPE from sbijax._src.snass import SNASS from sbijax._src.snasss import SNASSS from sbijax._src.snl import SNL diff --git a/sbijax/_src/fmpe.py b/sbijax/_src/fmpe.py index f108e80..d426c68 100644 --- a/sbijax/_src/fmpe.py +++ b/sbijax/_src/fmpe.py @@ -9,7 +9,6 @@ from tqdm import tqdm from sbijax._src._sne_base import SNE -from sbijax._src.nn.continuous_normalizing_flow import CCNF from sbijax._src.util.early_stopping import EarlyStopping @@ -59,8 +58,14 @@ def _cfm_loss( # pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation -class FMPE(SNE): - """Flow matching posterior estimation. +class SFMPE(SNE): + r"""Sequential flow matching posterior estimation. + + Implements a sequential version of the FMPE algorithm introduced in [1]_. + For all rounds $r > 1$ parameter samples + :math:`\theta \sim \hat{p}^r(\theta)` are drawn from + the approximate posterior instead of the prior when computing the flow + matching loss. Args: model_fns: a tuple of tuples. The first element is a tuple that @@ -71,15 +76,15 @@ class FMPE(SNE): Examples: >>> import distrax - >>> from sbijax import SNP - >>> from sbijax.nn import make_affine_maf + >>> from sbijax import SFMPE + >>> from sbijax.nn import make_ccnf >>> >>> prior = distrax.Normal(0.0, 1.0) >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed) >>> fns = (prior.sample, prior.log_prob), s - >>> flow = make_affine_maf() + >>> flow = make_ccnf(1) >>> - >>> snr = SNP(fns, flow) + >>> estim = SFMPE(fns, flow) References: .. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable @@ -87,7 +92,7 @@ class FMPE(SNE): Processing Systems, 2024. """ - def __init__(self, model_fns, density_estimator: CCNF): + def __init__(self, model_fns, density_estimator): """Construct a FMPE object. Args: diff --git a/sbijax/_src/snp.py b/sbijax/_src/snp.py index 154dbb7..f895a27 100644 --- a/sbijax/_src/snp.py +++ b/sbijax/_src/snp.py @@ -34,9 +34,9 @@ class SNP(SNE): >>> prior = distrax.Normal(0.0, 1.0) >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed) >>> fns = (prior.sample, prior.log_prob), s - >>> flow = make_affine_maf() + >>> flow = make_affine_maf(1) >>> - >>> snr = SNP(fns, flow) + >>> estim = SNP(fns, flow) References: .. [1] Greenberg, David, et al. "Automatic posterior transformation for