From 3f7edcb3d7ec4aa58bf1a94977bd65cbd267ec3a Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 22 Apr 2024 18:59:33 +0200 Subject: [PATCH] Add minor functionality (#29) * Update makefile * Add new functions for sampling --- Makefile | 14 ++++++++++ pyproject.toml | 2 +- sbijax/__init__.py | 2 +- sbijax/_src/_sne_base.py | 59 ++++++++++++++++++++++++++++++++-------- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index f8802c2..841556a 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,19 @@ +.PHONY: tag +.PHONY: tests +.PHONY: lints +.PHONY: docs + PKG_VERSION=`hatch version` tag: git tag -a v${PKG_VERSION} -m v${PKG_VERSION} git push --tag + +tests: + hatch run test:test + +lints: + hatch run test:lint + +docs: + cd docs && make html diff --git a/pyproject.toml b/pyproject.toml index 09845a8..8bb1c38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ dynamic = ["version"] [project.urls] -homepage = "https://github.com/dirmeier/sbijax" +Homepage = "https://github.com/dirmeier/sbijax" [tool.hatch.metadata] allow-direct-references = true diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 0ed0e43..f667a6b 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -1,6 +1,6 @@ """sbijax: Simulation-based inference in JAX.""" -__version__ = "0.2.0" +__version__ = "0.2.0.post0" from sbijax._src.abc.smc_abc import SMCABC from sbijax._src.scmpe import SCMPE diff --git a/sbijax/_src/_sne_base.py b/sbijax/_src/_sne_base.py index 4e834cf..d8a1b49 100644 --- a/sbijax/_src/_sne_base.py +++ b/sbijax/_src/_sne_base.py @@ -37,7 +37,7 @@ def simulate_data_and_possibly_append( n_simulations=1000, **kwargs, ): - """Simulate data from the prior or posterior and append. + """Simulate data and paarameters from the prior or posterior and append. Args: rng_key: a random key @@ -76,7 +76,7 @@ def sample_posterior(self, rng_key, params, observable, *args, **kwargs): **kwargs: keyword arguments """ - def simulate_data( + def simulate_parameters( self, rng_key, *, @@ -85,14 +85,14 @@ def simulate_data( n_simulations=1000, **kwargs, ): - r"""Simulate data from the posterior or prior and append. + r"""Simulate parameters from the posterior or prior. Args: rng_key: a random key params:a dictionary of neural network parameters. If None, will draw from prior. If parameters given, will draw from amortized - posterior using 'observable; - observable: an observation. Needs to be gfiven if posterior draws + posterior using 'observable'. + observable: an observation. Needs to be given if posterior draws are desired n_simulations: number of newly simulated data kwargs: dictionary of ey value pairs passed to `sample_posterior` @@ -100,12 +100,11 @@ def simulate_data( Returns: a NamedTuple of two axis, y and theta """ - sample_key, rng_key = jr.split(rng_key) if params is None or len(params) == 0: diagnostics = None self.n_total_simulations += n_simulations new_thetas = self.prior_sampler_fn( - seed=sample_key, + seed=rng_key, sample_shape=(n_simulations,), ) else: @@ -117,7 +116,7 @@ def simulate_data( if "n_samples" not in kwargs: kwargs["n_samples"] = n_simulations new_thetas, diagnostics = self.sample_posterior( - rng_key=sample_key, + rng_key=rng_key, params=params, observable=jnp.atleast_2d(observable), **kwargs, @@ -126,15 +125,53 @@ def simulate_data( new_thetas = jr.permutation(perm_key, new_thetas) new_thetas = new_thetas[:n_simulations, :] - simulate_key, rng_key = jr.split(rng_key) - new_obs = self.simulator_fn(seed=simulate_key, theta=new_thetas) + return new_thetas, diagnostics + + def simulate_data( + self, + rng_key, + *, + params=None, + observable=None, + n_simulations=1000, + **kwargs, + ): + r"""Simulate data from the posterior or prior and append. + + Args: + rng_key: a random key + params:a dictionary of neural network parameters. If None, will + draw from prior. If parameters given, will draw from amortized + posterior using 'observable; + observable: an observation. Needs to be gfiven if posterior draws + are desired + n_simulations: number of newly simulated data + kwargs: dictionary of ey value pairs passed to `sample_posterior` + + Returns: + a NamedTuple of two axis, y and theta + """ + theta_key, data_key = jr.split(rng_key) + + new_thetas, diagnostics = self.simulate_parameters( + theta_key, + params=params, + observable=observable, + n_simulations=n_simulations, + **kwargs, + ) + + new_obs = self.simulate_observations(data_key, new_thetas) chex.assert_shape(new_thetas, [n_simulations, None]) chex.assert_shape(new_obs, [n_simulations, None]) - new_data = named_dataset(new_obs, new_thetas) return new_data, diagnostics + def simulate_observations(self, rng_key, thetas): + new_obs = self.simulator_fn(seed=rng_key, theta=thetas) + return new_obs + @staticmethod def as_iterators( rng_key, data, batch_size, percentage_data_as_validation_set