From f4f43d1ad52d7563810f00dd5e8057b9bf9e7c6c Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 26 Feb 2024 10:54:54 +0100 Subject: [PATCH] Add sphinx documentation (#22) --- README.md | 2 +- docs/.gitignore | 6 ++ docs/Makefile | 22 ++++ docs/_static/theme.css | 48 +++++++++ docs/conf.py | 69 +++++++++++++ docs/examples.rst | 7 ++ docs/index.rst | 89 ++++++++++++++++ docs/requirements.txt | 20 ++++ docs/sbijax.nn.rst | 27 +++++ docs/sbijax.rst | 55 ++++++++++ examples/bivariate_gaussian_snl.py | 2 +- examples/bivariate_gaussian_snp.py | 2 +- examples/bivariate_gaussian_snr.py | 2 +- pyproject.toml | 2 +- sbijax/__init__.py | 2 +- sbijax/_src/_sbi_base.py | 9 -- sbijax/_src/_sne_base.py | 31 ++++-- sbijax/_src/nn/make_flows.py | 35 +++++-- sbijax/_src/nn/make_resnet.py | 6 +- sbijax/_src/nn/make_snass_networks.py | 12 +-- sbijax/_src/nn/snass_net.py | 6 +- sbijax/_src/nn/snasss_net.py | 8 +- sbijax/_src/snass.py | 63 ++++++++---- sbijax/_src/snasss.py | 50 +++++---- sbijax/_src/snl.py | 115 ++++++++++----------- sbijax/_src/snp.py | 54 +++++++--- sbijax/_src/snr.py | 140 +++++++++++++++++--------- 27 files changed, 671 insertions(+), 213 deletions(-) create mode 100644 docs/.gitignore create mode 100644 docs/Makefile create mode 100644 docs/_static/theme.css create mode 100644 docs/conf.py create mode 100644 docs/examples.rst create mode 100644 docs/index.rst create mode 100644 docs/requirements.txt create mode 100644 docs/sbijax.nn.rst create mode 100644 docs/sbijax.rst diff --git a/README.md b/README.md index 4731c36..e72514a 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ pip install git+https://github.com/dirmeier/sbijax@ ## Acknowledgements -> 📝 The package draws significant inspiration from the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially more +> 📝 The API of the package is heavily inspired by the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially 
more feature-complete and user-friendly, and better documented. ## Author diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..af6aca3 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,6 @@ +source/examples/ +source/examples/* +build/ +build/* +_autosummary/ +_autosummary/* diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..0ad1935 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,22 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = ./ +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + rm -rf build + rm -rf source/examples + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/theme.css b/docs/_static/theme.css new file mode 100644 index 0000000..1ba652c --- /dev/null +++ b/docs/_static/theme.css @@ -0,0 +1,48 @@ +html[data-theme="light"] { + --pst-color-primary: rgb(121, 40, 161); + --pst-color-primary-bg: #ffe9dd; + --pst-color-secondary: #b26679; + --pst-color-inline-code-links: #b26679; +} + +h1 > code > span { + font-family: var(--pst-font-family-monospace); + font-weight: 700; +} + +nav > li > a > code.literal { + padding-top: 0; + padding-bottom: 0; + background-color: white; + border: 0; +} + +nav.bd-links p.caption { + text-transform: uppercase; +} + +code.literal { + background-color: white; + border: 0; + border-radius: 0; +} +a:hover { + text-decoration-thickness: 1px !important; +} + + +ul.bd-breadcrumbs li.breadcrumb-item a:hover { + text-decoration-thickness: 1px; +} + +nav.bd-links li > a:hover { + 
text-decoration-thickness: 1px; +} + +.prev-next-area a p.prev-next-title { + text-decoration: none !important; +} + +button.theme-switch-button { + display: none !important; +} diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..8fc74e0 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,69 @@ +from datetime import date + +project = "sbijax" +copyright = f"{date.today().year}, the sbijax developers" +author = "the sbijax developers" + +extensions = [ + "nbsphinx", + "sphinx.ext.autodoc", + 'sphinx_autodoc_typehints', + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", + "sphinx_copybutton", + "sphinx_math_dollar", + "IPython.sphinxext.ipython_console_highlighting", + 'sphinx_design', +] + + +templates_path = ["_templates"] +html_static_path = ["_static"] +html_css_files = ['theme.css'] + +autodoc_default_options = { + "member-order": "bysource", + "special-members": True, + "exclude-members": "__repr__, __str__, __weakref__", +} + +exclude_patterns = [ + "_build", + "build", + "Thumbs.db", + ".DS_Store", + "notebooks/.ipynb_checkpoints", + "examples/*ipynb", + "examples/*py" +] + +autodoc_typehints = "both" + +html_theme = "sphinx_book_theme" + +html_theme_options = { + "repository_url": "https://github.com/dirmeier/sbijax", + "use_repository_button": True, + "use_download_button": False, + "extra_navbar": "" +} + +html_title = "sbijax 🚀" + + +def skip(app, what, name, obj, would_skip, options): + if name == "__init__": + return True + return would_skip + + +def setup(app): + app.connect("autodoc-skip-member", skip) + + +bibtex_bibfiles = ['references.bib'] diff --git a/docs/examples.rst b/docs/examples.rst new file mode 100644 index 0000000..e18cc83 --- /dev/null +++ b/docs/examples.rst @@ -0,0 +1,7 @@ +More examples +============= + +.. note:: + + Self-contained example code can be found on GitHub in `examples `_. 
+ The examples are executable from the command line, so forking/cloning the code suffices to run them. diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..e80c3c1 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,89 @@ +:github_url: https://github.com/dirmeier/sbijax + +👋 Welcome to :code:`sbijax`! +============================= + +.. div:: sd-text-left sd-font-italic + + Simulation-based inference in JAX + +---- + +:code:`sbijax` implements several algorithms for simulation-based inference in +`JAX `_ using `Haiku `_, +`Distrax `_ and `BlackJAX `_. Specifically, :code:`sbijax` implements + +- `Sequential Monte Carlo ABC `_ (:code:`SMCABC`), +- `Neural Likelihood Estimation `_ (:code:`SNL`) +- `Surjective Neural Likelihood Estimation `_ (:code:`SSNL`) +- `Neural Posterior Estimation C `_ (short :code:`SNP`) +- `Contrastive Neural Ratio Estimation `_ (short :code:`SNR`) +- `Neural Approximate Sufficient Statistics `_ (:code:`SNASS`) +- `Neural Approximate Slice Sufficient Statistics `_ (:code:`SNASSS`) + +Installation +------------ + +To install from PyPI, call: + +.. code-block:: bash + + pip install sbijax + +To install the latest GitHub <RELEASE>, just call the following on the +command line: + +.. code-block:: bash + + pip install git+https://github.com/dirmeier/sbijax@<RELEASE> + +See also the installation instructions for `JAX `_, if +you plan to use :code:`sbijax` on GPU/TPU. + +Contributing +------------ + +Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled +`"good first issue" `_. 
+ +In order to contribute: + +1) Clone :code:`sbijax` and install :code:`hatch` via :code:`pip install hatch`, +2) create a new branch locally :code:`git checkout -b feature/my-new-feature` or :code:`git checkout -b issue/fixes-bug`, +3) implement your contribution and ideally a test case, +4) test it by calling :code:`hatch run test` on the (Unix) command line, +5) submit a PR 🙂 + +Acknowledgements +---------------- + +.. note:: + + 📝 The API of the package is heavily inspired by the excellent Pytorch-based `sbi `_ package which is + substantially more feature-complete and user-friendly. + +License +------- + +:code:`sbijax` is licensed under the Apache 2.0 License. + +.. toctree:: + :maxdepth: 1 + :hidden: + + 🏠 Home + +.. toctree:: + :caption: 🎓 Examples + :maxdepth: 1 + :hidden: + + Self-contained scripts + +.. toctree:: + :caption: 🧱 API + :maxdepth: 2 + :hidden: + + sbijax + sbijax.nn diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..28f12fe --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,20 @@ +-e . +ipykernel +ipython +matplotlib +nbsphinx +pandas +scikit-learn +seaborn +session_info +sphinx +sphinx-autobuild +sphinx-book-theme>=1.1.0 +sphinx-copybutton +sphinx-math-dollar +sphinx_autodoc_typehints +sphinx_design +sphinx_fontawesome +sphinx_gallery +sphinxcontrib-fulltoc +tqdm diff --git a/docs/sbijax.nn.rst b/docs/sbijax.nn.rst new file mode 100644 index 0000000..362ec1d --- /dev/null +++ b/docs/sbijax.nn.rst @@ -0,0 +1,27 @@ +``sbijax.nn`` +============= + +.. currentmodule:: sbijax.nn + +---- + +``sbijax.nn`` contains utility functions and classes to construct neural +networks and normalizing flows. + +.. autosummary:: + make_affine_maf + make_surjective_affine_maf + make_resnet + make_snass_net + make_snasss_net + + +.. autofunction:: make_affine_maf + +.. autofunction:: make_surjective_affine_maf + +.. autofunction:: make_resnet + +.. autofunction:: make_snass_net + +.. 
autofunction:: make_snasss_net diff --git a/docs/sbijax.rst b/docs/sbijax.rst new file mode 100644 index 0000000..f6d07ae --- /dev/null +++ b/docs/sbijax.rst @@ -0,0 +1,55 @@ +``sbijax`` +========== + +.. currentmodule:: sbijax + +---- + +:code:`sbijax` contains the implemented methods for simulation-based inference. + +Methods +------- + +.. autosummary:: + SMCABC + SNL + SNP + SNR + SNASS + SNASSS + +SMCABC +~~~~~~ + +.. autoclass:: SMCABC + :members: fit, simulate_data_and_possibly_append, sample_posterior + +SNL+SSNL +~~~~~~~~ + +.. autoclass:: SNL + :members: fit, simulate_data_and_possibly_append, sample_posterior + +SNP +~~~ + +.. autoclass:: SNP + :members: fit, simulate_data_and_possibly_append, sample_posterior + +SNR +~~~ + +.. autoclass:: SNR + :members: fit, simulate_data_and_possibly_append, sample_posterior + +SNASS +~~~~~ + +.. autoclass:: SNASS + :members: fit, simulate_data_and_possibly_append, sample_posterior + +SNASSS +~~~~~~ + +.. autoclass:: SNASSS + :members: fit, simulate_data_and_possibly_append, sample_posterior diff --git a/examples/bivariate_gaussian_snl.py b/examples/bivariate_gaussian_snl.py index b41407a..445d760 100644 --- a/examples/bivariate_gaussian_snl.py +++ b/examples/bivariate_gaussian_snl.py @@ -1,5 +1,5 @@ """ -Example using sequential neural likelihood estimation on a bivariate Gaussian +Example using sequential neural likelihood estimation on a bivariate Gaussian """ from functools import partial diff --git a/examples/bivariate_gaussian_snp.py b/examples/bivariate_gaussian_snp.py index 2ad31b7..2ff6338 100644 --- a/examples/bivariate_gaussian_snp.py +++ b/examples/bivariate_gaussian_snp.py @@ -1,5 +1,5 @@ """ -Example using sequential neural posterior estimation on a bivariate Gaussian +Example using sequential neural posterior estimation on a bivariate Gaussian """ import distrax diff --git a/examples/bivariate_gaussian_snr.py b/examples/bivariate_gaussian_snr.py index 2da0eba..856ea24 100644 --- 
a/examples/bivariate_gaussian_snr.py +++ b/examples/bivariate_gaussian_snr.py @@ -1,5 +1,5 @@ """ -Example using sequential neural ratio estimation on a bivariate Gaussian +Example using sequential neural ratio estimation on a bivariate Gaussian """ import distrax diff --git a/pyproject.toml b/pyproject.toml index 2cc200f..6e50aef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ dependencies = [ [tool.hatch.envs.test.scripts] lint = 'pylint sbijax' -test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax' +test = 'pytest -v --cov=./sbijax --cov-report=xml sbijax' [tool.black] line-length = 80 diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 7cfc29f..0de09a8 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,7 +2,7 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.6" +__version__ = "0.1.7" from sbijax._src.abc.smc_abc import SMCABC diff --git a/sbijax/_src/_sbi_base.py b/sbijax/_src/_sbi_base.py index 91aaf0e..61eddb4 100644 --- a/sbijax/_src/_sbi_base.py +++ b/sbijax/_src/_sbi_base.py @@ -17,12 +17,3 @@ def __init__(self, model_fns): self.prior_sampler_fn, self.prior_log_density_fn = model_fns[0] self.simulator_fn = model_fns[1] self._len_theta = len(self.prior_sampler_fn(seed=jr.PRNGKey(123))) - - @abc.abstractmethod - def sample_posterior(self, rng_key, **kwargs): - """Sample from the posterior distribution. - - Args: - rng_key: a random key - kwargs: keyword arguments with sampler specific parameters - """ diff --git a/sbijax/_src/_sne_base.py b/sbijax/_src/_sne_base.py index c69e17a..0d4ada0 100644 --- a/sbijax/_src/_sne_base.py +++ b/sbijax/_src/_sne_base.py @@ -1,3 +1,4 @@ +import abc from abc import ABC import chex @@ -17,8 +18,11 @@ def __init__(self, model_fns, network): """Construct an SNE object. Args: - model_fns: tuple - network: maf + model_fns: a tuple of tuples. 
The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + network: a neural network """ super().__init__(model_fns) self.model = network @@ -60,6 +64,18 @@ def simulate_data_and_possibly_append( d_new = self.stack_data(data, new_data) return d_new, diagnostics + @abc.abstractmethod + def sample_posterior(self, rng_key, params, observable, *args, **kwargs): + """Sample from the approximate posterior. + + Args: + rng_key: a jax random key + params: a pytree of neural network parameters + observable: a data point + *args: argument list + **kwargs: keyword arguments + """ + def simulate_data( self, rng_key, @@ -138,19 +154,20 @@ def stack_data(data, also_data): *[jnp.vstack([a, b]) for a, b in zip(data, also_data)] ) + @staticmethod def as_iterators( - self, rng_key, data, batch_size, percentage_data_as_validation_set + rng_key, data, batch_size, percentage_data_as_validation_set ): """Convert the data set to an iterable for training. Args: - rng_key: random key - data: tuple - batch_size: integer + rng_key: a jax random key + data: a tuple with 'y' and 'theta' elements + batch_size: the size of each batch percentage_data_as_validation_set: fraction Returns: - a batch iterator + two batch iterators """ return as_batch_iterators( rng_key, diff --git a/sbijax/_src/nn/make_flows.py b/sbijax/_src/nn/make_flows.py index 2ab3ca4..a041313 100644 --- a/sbijax/_src/nn/make_flows.py +++ b/sbijax/_src/nn/make_flows.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, List +from typing import Callable, Iterable import distrax import haiku as hk @@ -44,10 +44,15 @@ def make_affine_maf( ): """Create an affine masked autoregressive flow. + The MAFs use `n_layers` layers and are parameterized using MADE networks + with `hidden_sizes` neurons per layer. 
+ Args: n_dimension: dimensionality of data n_layers: number of normalizing flow layers - hidden_sizes: sizes of hidden layers for each normalizing flow + hidden_sizes: sizes of hidden layers for each normalizing flow. E.g., + when the hidden sizes are a tuple (64, 64), then each maf layer + uses a MADE with two layers of size 64 each activation: a jax activation function Returns: @@ -88,24 +93,34 @@ def _flow(method, **kwargs): def make_surjective_affine_maf( n_dimension: int, - n_layer_dimensions: List[int], - n_layers: int = 5, + n_layer_dimensions: Iterable[int], hidden_sizes: Iterable[int] = (64, 64), activation: Callable = jax.nn.tanh, ): """Create a surjective affine masked autoregressive flow. + The flow has one layer per entry of `n_layer_dimensions`. Each layer is + parameterized using a MADE network with `hidden_sizes` neurons per layer. + For each dimensionality-reducing layer, a conditional Gaussian density + with the same number of layers and nodes per layer as `hidden_sizes` is + used. The argument `n_layer_dimensions` determines whether a layer is + dimensionality-preserving or -reducing. For example, for `n_dimension=5` + and `n_layer_dimensions=(5, 5, 3, 3)`, the third layer is a surjection layer + that reduces the dimensionality by two. The other layers preserve it. 
+ Args: n_dimension: a list of integers that determine the dimensionality of each flow layer n_layer_dimensions: list of integers that determine if a layer is dimensionality-preserving or -reducing - n_layers: number of normalizing flow layers hidden_sizes: sizes of hidden layers for each normalizing flow activation: a jax activation function + Examples: + >>> make_surjective_affine_maf(10, (10, 10, 5, 5, 5)) + Returns: - a normalizing flow model + a surjective normalizing flow model """ @hk.without_apply_rng @@ -114,9 +129,7 @@ def _flow(method, **kwargs): layers = [] order = jnp.arange(n_dimension) curr_dim = n_dimension - for i, n_dim_curr_layer in zip( - range(n_layers[:-1]), n_layer_dimensions[:-1] - ): + for i, n_dim_curr_layer in enumerate(n_layer_dimensions): # layer is dimensionality preserving if n_dim_curr_layer == curr_dim: layer = MaskedAutoregressive( @@ -135,10 +148,10 @@ def _flow(method, **kwargs): n_latent = n_dim_curr_layer layer = AffineMaskedAutoregressiveInferenceFunnel( n_latent, - _decoder_fn(curr_dim - n_latent, hidden_sizes), + _decoder_fn(curr_dim - n_latent, list(hidden_sizes)), conditioner=MADE( n_latent, - hidden_sizes + [n_dim_curr_layer * 2], + list(hidden_sizes) + [n_dim_curr_layer * 2], 2, w_init=hk.initializers.TruncatedNormal(0.001), b_init=jnp.zeros, diff --git a/sbijax/_src/nn/make_resnet.py b/sbijax/_src/nn/make_resnet.py index bfe163f..386feb9 100644 --- a/sbijax/_src/nn/make_resnet.py +++ b/sbijax/_src/nn/make_resnet.py @@ -81,9 +81,9 @@ def make_resnet( n_layers: int = 2, hidden_size: int = 64, activation: Callable = jax.nn.tanh, - dropout_rate=0.2, - do_batch_norm=False, - batch_norm_decay=0.2, + dropout_rate: float = 0.2, + do_batch_norm: bool = False, + batch_norm_decay: float = 0.2, ): """Create a resnet. 
diff --git a/sbijax/_src/nn/make_snass_networks.py b/sbijax/_src/nn/make_snass_networks.py index 81dda4a..db7cc83 100644 --- a/sbijax/_src/nn/make_snass_networks.py +++ b/sbijax/_src/nn/make_snass_networks.py @@ -1,4 +1,4 @@ -from typing import Callable, List +from typing import Callable, Iterable import haiku as hk import jax @@ -8,8 +8,8 @@ def make_snass_net( - summary_net_dimensions: List[int], - critic_net_dimensions: List[int], + summary_net_dimensions: Iterable[int], + critic_net_dimensions: Iterable[int], activation: Callable[[jax.Array], jax.Array] = jax.nn.relu, ): """Create a critic network for SNASS. @@ -43,9 +43,9 @@ def _net(method, **kwargs): def make_snasss_net( - summary_net_dimensions: List[int], - sec_summary_net_dimensions: List[int], - critic_net_dimensions: List[int], + summary_net_dimensions: Iterable[int], + sec_summary_net_dimensions: Iterable[int], + critic_net_dimensions: Iterable[int], activation: Callable[[jax.Array], jax.Array] = jax.nn.relu, ): """Create a critic network for SNASSS. 
diff --git a/sbijax/_src/nn/snass_net.py b/sbijax/_src/nn/snass_net.py index 25c72bd..735d619 100644 --- a/sbijax/_src/nn/snass_net.py +++ b/sbijax/_src/nn/snass_net.py @@ -1,4 +1,4 @@ -from typing import Callable, List +from typing import Callable, Iterable import haiku as hk import jax @@ -12,8 +12,8 @@ class SNASSNet(hk.Module): def __init__( self, - summary_net_dimensions: List[int] = None, - critic_net_dimensions: List[int] = None, + summary_net_dimensions: Iterable[int] = None, + critic_net_dimensions: Iterable[int] = None, summary_net: Callable = None, critic_net: Callable = None, ): diff --git a/sbijax/_src/nn/snasss_net.py b/sbijax/_src/nn/snasss_net.py index 5bc3626..051cc36 100644 --- a/sbijax/_src/nn/snasss_net.py +++ b/sbijax/_src/nn/snasss_net.py @@ -1,4 +1,4 @@ -from typing import Callable, List +from typing import Callable, Iterable import haiku as hk import jax @@ -14,9 +14,9 @@ class SNASSSNet(SNASSNet): def __init__( self, - summary_net_dimensions: List[int] = None, - sec_summary_net_dimensions: List[int] = None, - critic_net_dimensions: List[int] = None, + summary_net_dimensions: Iterable[int] = None, + sec_summary_net_dimensions: Iterable[int] = None, + critic_net_dimensions: Iterable[int] = None, summary_net: Callable = None, sec_summary_net: Callable = None, critic_net: Callable = None, diff --git a/sbijax/_src/snass.py b/sbijax/_src/snass.py index 3692de9..6d8b663 100644 --- a/sbijax/_src/snass.py +++ b/sbijax/_src/snass.py @@ -33,21 +33,36 @@ def _jsd_summary_loss(params, rng, apply_fn, **batch): class SNASS(SNL): """Sequential neural approximate summary statistics. + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. 
+ density_estimator: a (neural) conditional density estimator + to model the likelihood function of summary statistics, i.e., + the modelled dimensionality is that of the summaries + summary_net: a SNASSNet object + References: .. [1] Chen, Yanzhi et al. "Neural Approximate Sufficient Statistics for Implicit Models". ICLR, 2021 """ - def __init__(self, model_fns, density_estimator, snass_net): + def __init__(self, model_fns, density_estimator, summary_net): """Construct a SNASS object. Args: - model_fns: tuple - density_estimator: maf - snass_net: mlp + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + density_estimator: a (neural) conditional density estimator + to model the likelihood function of summary statistics, i.e., + the modelled dimensionality is that of the summaries + summary_net: a SNASSNet object """ super().__init__(model_fns, density_estimator) - self.sc_net = snass_net + self.sc_net = summary_net # pylint: disable=arguments-differ,too-many-locals def fit( @@ -64,7 +79,7 @@ def fit( """Fit a SNASS model. Args: - rng_seq: a hk.PRNGSequence + rng_key: a jax random key data: data set obtained from calling `simulate_data_and_possibly_append` optimizer: an optax optimizer object @@ -74,12 +89,15 @@ def fit( that is used for validation and early stopping n_early_stopping_patience: number of iterations of no improvement of training the flow before stopping optimisation - kwargs: keyword arguments with sampler specific parameters. 
For - sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval + + Keyword Args: + sampler (str): either 'nuts', 'slice' or None (defaults to nuts) + n_thin (int): number of thinning steps + (only used if sampler='slice') + n_doubling (int): number of doubling steps of the interval + (only used if sampler='slice') + step_size (float): step size of the initial interval + (only used if sampler='slice') Returns: tuple of parameters and a tuple of the training information @@ -218,22 +236,25 @@ def sample_posterior( r"""Sample from the approximate posterior. Args: - rng_key: a random key - params: a pytree of parameter for the model + rng_key: a jax random key + params: a pytree of neural network parameters observable: observation to condition on n_chains: number of MCMC chains n_samples: number of samples per chain n_warmup: number of samples to discard - kwargs: keyword arguments with sampler specific parameters. 
For - sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval + + Keyword Args: + sampler (str): either 'nuts', 'slice' or None (defaults to nuts) + n_thin (int): number of thinning steps + (only used if sampler='slice') + n_doubling (int): number of doubling steps of the interval + (only used if sampler='slice') + step_size (float): step size of the initial interval + (only used if sampler='slice') Returns: an array of samples from the posterior distribution of dimension - (n_samples \times p) + (n_samples \times p) and posterior diagnostics """ observable = jnp.atleast_2d(observable) summary = self.sc_net.apply( diff --git a/sbijax/_src/snasss.py b/sbijax/_src/snasss.py index 96f0043..cedfeb4 100644 --- a/sbijax/_src/snasss.py +++ b/sbijax/_src/snasss.py @@ -57,6 +57,16 @@ def _jsd_summary_loss(params, rng_key, apply_fn, **batch): class SNASSS(SNL): """Sequential neural approximate slice sufficient statistics. + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + density_estimator: a (neural) conditional density estimator + to model the likelihood function of summary statistics, i.e., + the modelled dimensionality is that of the summaries + summary_net: a SNASSSNet object + References: .. [1] Yanzhi Chen et al. "Is Learning Summary Statistics Necessary for Likelihood-free Inference". ICML, 2023 @@ -66,9 +76,14 @@ def __init__(self, model_fns, density_estimator, summary_net): """Construct a SNASSS object. Args: - model_fns: tuple - density_estimator: maf - summary_net: snass network + model_fns: a tuple of tuples. 
The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + density_estimator: a (neural) conditional density estimator + to model the likelihood function of summary statistics, i.e., + the modelled dimensionality is that of the summaries + summary_net: a SNASSSNet object """ super().__init__(model_fns, density_estimator) self.sc_net = summary_net @@ -88,7 +103,7 @@ def fit( """Fit a SNASSS model. Args: - rng_key: a hk.PRNGSequence + rng_key: a jax random key data: data set obtained from calling `simulate_data_and_possibly_append` optimizer: an optax optimizer object @@ -98,12 +113,6 @@ def fit( that is used for validation and early stopping n_early_stopping_patience: number of iterations of no improvement of training the flow before stopping optimisation - kwargs: keyword arguments with sampler specific parameters. For - sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval Returns: tuple of parameters and a tuple of the training information @@ -242,22 +251,25 @@ def sample_posterior( r"""Sample from the approximate posterior. Args: - rng_key: a random key - params: a pytree of parameter for the model + rng_key: a jax random key + params: a pytree of neural network parameters observable: observation to condition on n_chains: number of MCMC chains n_samples: number of samples per chain n_warmup: number of samples to discard - kwargs: keyword arguments with sampler specific parameters. 
For - sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval + + Keyword Args: + sampler (str): either 'nuts', 'slice' or None (defaults to nuts) + n_thin (int): number of thinning steps + (only used if sampler='slice') + n_doubling (int): number of doubling steps of the interval + (only used if sampler='slice') + step_size (float): step size of the initial interval + (only used if sampler='slice') Returns: an array of samples from the posterior distribution of dimension - (n_samples \times p) + (n_samples \times p) and posterior diagnostics """ observable = jnp.atleast_2d(observable) summary = self.sc_net.apply( diff --git a/sbijax/_src/snl.py b/sbijax/_src/snl.py index 0f47caf..2b3e155 100644 --- a/sbijax/_src/snl.py +++ b/sbijax/_src/snl.py @@ -8,15 +8,9 @@ from jax import numpy as jnp from jax import random as jr +from sbijax._src import mcmc from sbijax._src._sne_base import SNE -from sbijax._src.mcmc import ( - mcmc_diagnostics, - sample_with_nuts, - sample_with_slice, -) -from sbijax._src.mcmc.irmh import sample_with_imh -from sbijax._src.mcmc.mala import sample_with_mala -from sbijax._src.mcmc.rmh import sample_with_rmh +from sbijax._src.mcmc import mcmc_diagnostics from sbijax._src.util.early_stopping import EarlyStopping @@ -24,7 +18,15 @@ class SNL(SNE): """Sequential neural likelihood. - Implements SNL and SSNL estimation methods. + Implements both SNL and SSNL estimation methods. + + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + density_estimator: a (neural) conditional density estimator + to model the likelihood function References: .. [1] Papamakarios, George, et al. 
"Sequential neural likelihood: @@ -36,6 +38,20 @@ class SNL(SNE): arXiv preprint arXiv:2308.01054, 2023. """ + # pylint: disable=useless-parent-delegation + def __init__(self, model_fns, density_estimator): + """Construct a SNL object. + + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + density_estimator: a (neural) conditional density estimator + to model the likelihood function + """ + super().__init__(model_fns, density_estimator) + # pylint: disable=arguments-differ,too-many-locals def fit( self, @@ -48,10 +64,10 @@ def fit( n_early_stopping_patience=10, **kwargs, ): - """Fit a SNL model. + """Fit a SNL or SSNL model. Args: - rng_key: a hk.PRNGSequence + rng_key: a jax random key data: data set obtained from calling `simulate_data_and_possibly_append` optimizer: an optax optimizer object @@ -61,15 +77,9 @@ def fit( that is used for valitation and early stopping n_early_stopping_patience: number of iterations of no improvement of training the flow before stopping optimisation - kwargs: keyword arguments with sampler specific parameters. - For slice sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval Returns: - returns a tuple of parameters and a tuple of the training + a tuple of parameters and a tuple of the training information """ itr_key, rng_key = jr.split(rng_key) @@ -186,17 +196,19 @@ def simulate_data_and_possibly_append( n_simulations: number of newly simulated data n_chains: number of MCMC chains n_samples: number of sa les to draw in total - n_warmup: number of draws to discared - kwargs: keyword arguments - dictionary of ey value pairs passed to `sample_posterior`. 
- The following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps (int) - - n_doubling: number of doubling steps of the interval (int) - - step_size: step size of the initial interval (float) + n_warmup: number of draws to discard + + Keyword Args: + sampler (str): either 'nuts', 'slice' or None (defaults to nuts) + n_thin (int): number of thinning steps + (only used if sampler='slice') + n_doubling (int): number of doubling steps of the interval + (only used if sampler='slice') + step_size (float): step size of the initial interval + (only used if sampler='slice') Returns: - returns a NamedTuple of two axis, y and theta + returns a NamedTuple with two elements, y and theta """ return super().simulate_data_and_possibly_append( rng_key=rng_key, @@ -224,22 +236,25 @@ def sample_posterior( r"""Sample from the approximate posterior. Args: - rng_key: a random key - params: a pytree of parameter for the model + rng_key: a jax random key + params: a pytree of neural network parameters observable: observation to condition on n_chains: number of MCMC chains n_samples: number of samples per chain n_warmup: number of samples to discard - kwargs: keyword arguments with sampler specific parameters. 
For - slice sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval + + Keyword Args: + sampler (str): either 'nuts', 'slice' or None (defaults to nuts) + n_thin (int): number of thinning steps + (only used if sampler='slice') + n_doubling (int): number of doubling steps of the interval + (only used if sampler='slice') + step_size (float): step size of the initial interval + (only used if sampler='slice') Returns: an array of samples from the posterior distribution of dimension - (n_samples \times p) + (n_samples \times p) and posterior diagnostics """ observable = jnp.atleast_2d(observable) return self._sample_posterior( @@ -278,40 +293,20 @@ def _joint_logdensity_fn(theta): return jnp.sum(lp) + jnp.sum(lp_prior) if "sampler" in kwargs and kwargs["sampler"] == "slice": - kwargs.pop("sampler", None) def lp__(theta): return jax.vmap(_joint_logdensity_fn)(theta) - sampling_fn = sample_with_slice - elif "sampler" in kwargs and kwargs["sampler"] == "rmh": - kwargs.pop("sampler", None) - - def lp__(theta): - return _joint_logdensity_fn(**theta) - - sampling_fn = sample_with_rmh - elif "sampler" in kwargs and kwargs["sampler"] == "imh": - kwargs.pop("sampler", None) - - def lp__(theta): - return _joint_logdensity_fn(**theta) - - sampling_fn = sample_with_imh - elif "sampler" in kwargs and kwargs["sampler"] == "mala": - kwargs.pop("sampler", None) - - def lp__(theta): - return _joint_logdensity_fn(**theta) - - sampling_fn = sample_with_mala + sampler = kwargs.pop("sampler", None) else: def lp__(theta): return _joint_logdensity_fn(**theta) - sampling_fn = sample_with_nuts + # take whatever sampler is or per default nuts + sampler = kwargs.pop("sampler", "nuts") + sampling_fn = getattr(mcmc, "sample_with_" + sampler) samples = sampling_fn( rng_key=rng_key, lp=lp__, diff --git 
a/sbijax/_src/snp.py b/sbijax/_src/snp.py index 67613ec..56410d0 100644 --- a/sbijax/_src/snp.py +++ b/sbijax/_src/snp.py @@ -16,20 +16,47 @@ class SNP(SNE): """Sequential neural posterior estimation. + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + density_estimator: a (neural) conditional density estimator + to model the posterior distribution + num_atoms: number of atoms to approximate the proposal posterior + + Examples: + >>> import distrax + >>> from sbijax import SNP + >>> from sbijax.nn import make_affine_maf + >>> + >>> prior = distrax.Normal(0.0, 1.0) + >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed) + >>> fns = (prior.sample, prior.log_prob), s + >>> flow = make_affine_maf() + >>> + >>> snp = SNP(fns, flow) + References: .. [1] Greenberg, David, et al. "Automatic posterior transformation for likelihood-free inference." International Conference on Machine Learning, 2019. """ - def __init__(self, model_fns, density_estimator): + def __init__(self, model_fns, density_estimator, num_atoms=10): """Construct an SNP object. Args: - model_fns: tuple - density_estimator: maf + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + density_estimator: a (neural) conditional density estimator + to model the posterior distribution + num_atoms: number of atoms to approximate the proposal posterior """ super().__init__(model_fns, density_estimator) + self.num_atoms = num_atoms self.n_round = 0 # pylint: disable=arguments-differ,too-many-locals @@ -37,28 +64,27 @@ def fit( self, rng_key, data, + *, optimizer=optax.adam(0.0003), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, - n_atoms=10, **kwargs, ): - """Fit an SNPE model. + """Fit an SNP model. 
Args: - rng_key: a hk.PRNGSequence + rng_key: a jax random key data: data set obtained from calling `simulate_data_and_possibly_append` optimizer: an optax optimizer object n_iter: maximal number of training iterations per round batch_size: batch size used for training the model percentage_data_as_validation_set: percentage of the simulated - data that is used for valitation and early stopping + data that is used for validation and early stopping n_early_stopping_patience: number of iterations of no improvement - of training the flow before stopping optimisation - n_atoms: number of atoms to approximate the proposal posterior + of training the flow before stopping optimisation Returns: a tuple of parameters and a tuple of the training information @@ -74,7 +100,7 @@ def fit( optimizer=optimizer, n_iter=n_iter, n_early_stopping_patience=n_early_stopping_patience, - n_atoms=n_atoms, + n_atoms=self.num_atoms, ) return params, losses @@ -237,14 +263,14 @@ def sample_posterior( r"""Sample from the approximate posterior. 
Args: - rng_key: a random key - params: a pytree of parameter for the model + rng_key: a jax random key + params: a pytree of neural network parameters observable: observation to condition on n_samples: number of samples to draw Returns: - an array of samples from the posterior distribution of dimension - (n_samples \times p) + returns an array of samples from the posterior distribution of + dimension (n_samples \times p) """ observable = jnp.atleast_2d(observable) diff --git a/sbijax/_src/snr.py b/sbijax/_src/snr.py index 4a24bf9..1fe8985 100644 --- a/sbijax/_src/snr.py +++ b/sbijax/_src/snr.py @@ -1,12 +1,14 @@ # Parts of this codebase have been adopted from https://github.com/bkmi/cnre - from functools import partial +from typing import Callable, NamedTuple, Optional, Tuple import chex import jax import numpy as np import optax from absl import logging +from haiku import Params +from jax import Array from jax import numpy as jnp from jax import random as jr from jax import scipy as jsp @@ -86,21 +88,51 @@ def _loss(params, rng_key, model, gamma, num_classes, **batch): # pylint: disable=too-many-arguments,unused-argument class SNR(SNE): - """Sequential (contrastive) neural ratio estimation. + r"""Sequential (contrastive) neural ratio estimation. + + Args: + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. + classifier: a neural network for classification + num_classes: number of classes to classify against + gamma: relative weight of classes + + Examples: + >>> import distrax + >>> from sbijax import SNR + >>> from sbijax.nn import make_resnet + >>> + >>> prior = distrax.Normal(0.0, 1.0) + >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed) + >>> fns = (prior.sample, prior.log_prob), s + >>> resnet = make_resnet() + >>> + >>> snr = SNR(fns, resnet) References: .. [1] Miller, Benjamin K., et al. 
"Contrastive neural ratio estimation." Advances in Neural Information Processing Systems, 2022. """ - def __init__(self, model_fns, classifier, num_classes=10, gamma=1.0): - """Construct an SNP object. + def __init__( + self, + model_fns: Tuple[Tuple[Callable, Callable], Callable], + classifier: Callable, + num_classes: int = 10, + gamma: float = 1.0, + ): + """Construct an SNR object. Args: - model_fns: tuple + model_fns: a tuple of tuples. The first element is a tuple that + consists of functions to sample and evaluate the + log-probability of a data point. The second element is a + simulator function. classifier: a neural network for classification - num_classes: int - gamma: float + num_classes: number of classes to classify against + gamma: relative weight of classes """ super().__init__(model_fns, classifier) self.gamma = gamma @@ -109,30 +141,29 @@ def __init__(self, model_fns, classifier, num_classes=10, gamma=1.0): # pylint: disable=arguments-differ,too-many-locals def fit( self, - rng_key, - data, + rng_key: Array, + data: NamedTuple, *, - optimizer=optax.adam(0.003), - n_iter=1000, - batch_size=100, - percentage_data_as_validation_set=0.1, - n_early_stopping_patience=10, + optimizer: optax.GradientTransformation = optax.adam(0.003), + n_iter: int = 1000, + batch_size: int = 100, + percentage_data_as_validation_set: float = 0.1, + n_early_stopping_patience: float = 10, **kwargs, ): - """Fit an SNPE model. + """Fit an SNR model. 
Args: - rng_key: a hk.PRNGSequence + rng_key: a jax random key data: data set obtained from calling `simulate_data_and_possibly_append` optimizer: an optax optimizer object n_iter: maximal number of training iterations per round batch_size: batch size used for training the model percentage_data_as_validation_set: percentage of the simulated - data that is used for valitation and early stopping + data that is used for validation and early stopping n_early_stopping_patience: number of iterations of no improvement of training the flow before stopping optimisation - n_atoms: number of atoms to approximate the proposal posterior Returns: a tuple of parameters and a tuple of the training information @@ -142,7 +173,7 @@ def fit( itr_key, data, batch_size, percentage_data_as_validation_set ) params, losses = self._fit_model_single_round( - seed=rng_key, + rng_key=rng_key, train_iter=train_iter, val_iter=val_iter, optimizer=optimizer, @@ -155,14 +186,14 @@ def fit( # pylint: disable=undefined-loop-variable def _fit_model_single_round( self, - seed, + rng_key, train_iter, val_iter, optimizer, n_iter, n_early_stopping_patience, ): - init_key, seed = jr.split(seed) + init_key, rng_key = jr.split(rng_key) params = self._init_params(init_key, **train_iter(0)) state = optimizer.init(params) @@ -183,7 +214,7 @@ def step(params, rng, state, **batch): logging.info("training model") for i in tqdm(range(n_iter)): train_loss = 0.0 - rng_key = jr.fold_in(seed, i) + rng_key = jr.fold_in(rng_key, i) for j in range(train_iter.num_batches): train_key, rng_key = jr.split(rng_key) batch = train_iter(j) @@ -232,34 +263,40 @@ def body_fn(rng_key, **batch): def simulate_data_and_possibly_append( self, - rng_key, - params=None, - observable=None, - data=None, - n_simulations=1_000, - n_chains=4, - n_samples=2_000, - n_warmup=1_000, + rng_key: Array, + params: Optional[Params] = None, + observable: Array = None, + data: NamedTuple = None, + n_simulations: int = 1_000, + n_chains: int = 4, + 
n_samples: int = 2_000, + n_warmup: int = 1_000, **kwargs, ): """Simulate data from the prior or posterior. + Simulate new parameters and observables from the prior or posterior + (when params and data given). If a data argument is provided, append + the new samples to the data set and return the old+new data. + Args: - rng_key: a random key + rng_key: a jax random key params: a dictionary of neural network parameters observable: an observation - data: existing data set + data: existing data set or None n_simulations: number of newly simulated data n_chains: number of MCMC chains n_samples: number of sa les to draw in total - n_warmup: number of draws to discared - kwargs: keyword arguments - dictionary of ey value pairs passed to `sample_posterior`. - The following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps (int) - - n_doubling: number of doubling steps of the interval (int) - - step_size: step size of the initial interval (float) + n_warmup: number of draws to discard + + Keyword Args: + sampler (str): either 'nuts', 'slice' or None (defaults to nuts) + n_thin (int): number of thinning steps + (only used if sampler='slice') + n_doubling (int): number of doubling steps of the interval + (only used if sampler='slice') + step_size (float): step size of the initial interval + (only used if sampler='slice') Returns: returns a NamedTuple of two axis, y and theta @@ -290,22 +327,25 @@ def sample_posterior( r"""Sample from the approximate posterior. Args: - rng_key: a random key - params: a pytree of parameter for the model + rng_key: a jax random key + params: a pytree of neural network parameters observable: observation to condition on n_chains: number of MCMC chains n_samples: number of samples per chain n_warmup: number of samples to discard - kwargs: keyword arguments with sampler specific parameters. 
For - slice sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval + + Keyword Args: + sampler (str): either 'nuts', 'slice' or None (defaults to nuts) + n_thin (int): number of thinning steps + (only used if sampler='slice') + n_doubling (int): number of doubling steps of the interval + (only used if sampler='slice') + step_size (float): step size of the initial interval + (only used if sampler='slice') Returns: - an array of samples from the posterior distribution of dimension - (n_samples \times p) + returns an array of samples from the posterior distribution of + dimension (n_samples \times p) and posterior diagnostics """ observable = jnp.atleast_2d(observable) return self._sample_posterior(