Improve documentation (#26)
dirmeier committed Feb 28, 2024
1 parent 78e8097 commit 9017ab2
Showing 8 changed files with 40 additions and 16 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -12,16 +12,20 @@
[JAX](https://github.com/google/jax) using [Haiku](https://github.com/deepmind/dm-haiku),
[Distrax](https://github.com/deepmind/distrax) and [BlackJAX](https://github.com/blackjax-devs/blackjax). Specifically, `sbijax` implements

- [Sequential Monte Carlo ABC](https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728) (`SMCABC`),
- [Sequential Monte Carlo ABC](https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728) (`SMCABC`)
- [Neural Likelihood Estimation](https://arxiv.org/abs/1805.07226) (`SNL`)
- [Surjective Neural Likelihood Estimation](https://arxiv.org/abs/2308.01054) (`SSNL`)
- [Neural Posterior Estimation C](https://arxiv.org/abs/1905.07488) (short `SNP`)
- [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`)
- [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`)
- [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`)
- [Flow matching posterior estimation](https://arxiv.org/abs/2305.17161) (`SFMPE`)

where the acronyms in parentheses denote the names of the methods in `sbijax`.

> [!CAUTION]
> ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

## Examples

You can find several self-contained examples on how to use the algorithms in [examples](https://github.com/dirmeier/sbijax/tree/main/examples).
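For orientation, here is a minimal construction sketch for one of the methods listed above, assembled from the SFMPE docstring example added later in this commit; the toy prior, the simulator, and the argument passed to `make_ccnf` are illustrative assumptions rather than part of the README diff.

```python
import distrax

from sbijax import SFMPE
from sbijax.nn import make_ccnf

# Toy model: standard normal prior and a Gaussian simulator (illustrative only).
prior = distrax.Normal(0.0, 1.0)
simulator_fn = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)

fns = (prior.sample, prior.log_prob), simulator_fn
flow = make_ccnf(1)  # argument assumed to be the parameter dimensionality

estim = SFMPE(fns, flow)
```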
7 changes: 6 additions & 1 deletion docs/index.rst
@@ -13,13 +13,18 @@
`JAX <https://github.com/google/jax>`_ using `Haiku <https://github.com/deepmind/dm-haiku>`_,
`Distrax <https://github.com/deepmind/distrax>`_ and `BlackJAX <https://github.com/blackjax-devs/blackjax>`_. Specifically, :code:`sbijax` implements

- `Sequential Monte Carlo ABC <https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728>`_ (:code:`SMCABC`),
- `Sequential Monte Carlo ABC <https://www.routledge.com/Handbook-of-Approximate-Bayesian-Computation/Sisson-Fan-Beaumont/p/book/9780367733728>`_ (:code:`SMCABC`)
- `Neural Likelihood Estimation <https://arxiv.org/abs/1805.07226>`_ (:code:`SNL`)
- `Surjective Neural Likelihood Estimation <https://arxiv.org/abs/2308.01054>`_ (:code:`SSNL`)
- `Neural Posterior Estimation C <https://arxiv.org/abs/1905.07488>`_ (short :code:`SNP`)
- `Contrastive Neural Ratio Estimation <https://arxiv.org/abs/2210.06170>`_ (short :code:`SNR`)
- `Neural Approximate Sufficient Statistics <https://arxiv.org/abs/2010.10079>`_ (:code:`SNASS`)
- `Neural Approximate Slice Sufficient Statistics <https://openreview.net/forum?id=jjzJ768iV1>`_ (:code:`SNASSS`)
- `Flow matching posterior estimation <https://arxiv.org/abs/2305.17161>`_ (:code:`SFMPE`)

.. caution::

⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

Installation
------------
3 changes: 3 additions & 0 deletions docs/sbijax.nn.rst
@@ -12,6 +12,7 @@ networks and normalizing flows.
make_affine_maf
make_surjective_affine_maf
make_resnet
make_ccnf
make_snass_net
make_snasss_net

@@ -22,6 +23,8 @@ networks and normalizing flows.

.. autofunction:: make_resnet

.. autofunction:: make_ccnf

.. autofunction:: make_snass_net

.. autofunction:: make_snasss_net
7 changes: 7 additions & 0 deletions docs/sbijax.rst
@@ -15,6 +15,7 @@ Methods
SNL
SNP
SNR
SFMPE
SNASS
SNASSS

@@ -42,6 +43,12 @@ SNR
.. autoclass:: SNR
:members: fit, simulate_data_and_possibly_append, sample_posterior

SFMPE
~~~~~

.. autoclass:: SFMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

SNASS
~~~~~

Expand Up @@ -10,7 +10,7 @@
from jax import numpy as jnp
from jax import random as jr

from sbijax import FMPE
from sbijax import SFMPE
from sbijax.nn import CCNF


@@ -45,7 +45,7 @@ def run():
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = FMPE(fns, make_model(2))
estim = SFMPE(fns, make_model(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
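The example file is truncated at this point. Continuing from the imports and variables shown above (`estim`, `optimizer`, `data`, `params`), the round loop presumably proceeds as sketched below, using the members documented in docs/sbijax.rst (`fit`, `simulate_data_and_possibly_append`, `sample_posterior`); the argument names, return values, and the observed data `y_observed` are assumptions, not part of this diff.

```python
# Hypothetical continuation of the truncated example above.
y_observed = jnp.array([-1.0, 1.0])  # assumed observation

for i in range(2):
    # After the first round, new parameters are drawn from the current
    # posterior approximation rather than the prior (see the SFMPE docstring).
    data, _ = estim.simulate_data_and_possibly_append(
        jr.fold_in(jr.PRNGKey(1), i),
        params=params,
        observable=y_observed,
        data=data,
    )
    params, losses = estim.fit(
        jr.fold_in(jr.PRNGKey(2), i), data=data, optimizer=optimizer
    )

posterior_samples = estim.sample_posterior(jr.PRNGKey(3), params, y_observed)
```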
4 changes: 2 additions & 2 deletions sbijax/__init__.py
@@ -2,11 +2,11 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.8"
__version__ = "0.1.9"


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.fmpe import FMPE
from sbijax._src.fmpe import SFMPE
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
21 changes: 13 additions & 8 deletions sbijax/_src/fmpe.py
@@ -9,7 +9,6 @@
from tqdm import tqdm

from sbijax._src._sne_base import SNE
from sbijax._src.nn.continuous_normalizing_flow import CCNF
from sbijax._src.util.early_stopping import EarlyStopping


@@ -59,8 +58,14 @@ def _cfm_loss(


# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
class FMPE(SNE):
"""Flow matching posterior estimation.
class SFMPE(SNE):
r"""Sequential flow matching posterior estimation.
Implements a sequential version of the FMPE algorithm introduced in [1]_.
For all rounds :math:`r > 1`, parameter samples
:math:`\theta \sim \hat{p}^r(\theta)` are drawn from the approximate
posterior instead of the prior when computing the flow matching loss.
Args:
model_fns: a tuple of tuples. The first element is a tuple that
@@ -71,23 +76,23 @@ class FMPE(SNE):
Examples:
>>> import distrax
>>> from sbijax import SNP
>>> from sbijax.nn import make_affine_maf
>>> from sbijax import SFMPE
>>> from sbijax.nn import make_ccnf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf()
>>> flow = make_ccnf(1)
>>>
>>> snr = SNP(fns, flow)
>>> estim = SFMPE(fns, flow)
References:
.. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable
Simulation-Based Inference." Advances in Neural Information
Processing Systems, 2024.
"""

def __init__(self, model_fns, density_estimator: CCNF):
def __init__(self, model_fns, density_estimator):
"""Construct a FMPE object.
Args:
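The hunk header above references `_cfm_loss`, i.e. the flow matching loss that the new docstring mentions, but the diff does not show its body. For reference, the conditional flow matching objective of Wildberger et al. [1] is commonly written as below, where a network `v_phi` conditioned on the data `y` regresses onto the conditional vector field `u_t` of a Gaussian probability path over the parameters `theta`; the exact parameterization used in sbijax may differ.

$$
\mathcal{L}_{\mathrm{FMPE}}(\phi) \;=\;
\mathbb{E}_{\,t \sim \mathcal{U}[0,1],\; (\theta, y) \sim p(\theta)\,p(y \mid \theta),\; \theta_t \sim p_t(\theta_t \mid \theta)}
\big\lVert v_\phi(\theta_t, t, y) - u_t(\theta_t \mid \theta) \big\rVert^2 .
$$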
4 changes: 2 additions & 2 deletions sbijax/_src/snp.py
@@ -34,9 +34,9 @@ class SNP(SNE):
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf()
>>> flow = make_affine_maf(1)
>>>
>>> snr = SNP(fns, flow)
>>> estim = SNP(fns, flow)
References:
.. [1] Greenberg, David, et al. "Automatic posterior transformation for
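Put together, the corrected SNP docstring example presumably reads as follows after this commit; the import lines are not shown in this hunk and are inferred from the analogous (pre-rename) example in sbijax/_src/fmpe.py above, so treat them as assumptions.

```python
>>> import distrax
>>> from sbijax import SNP
>>> from sbijax.nn import make_affine_maf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf(1)
>>>
>>> estim = SNP(fns, flow)
```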
