From db2a588c1cc313236401c3ac7d7e06c2a130f048 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 26 Feb 2024 20:53:58 +0100 Subject: [PATCH] DRAFT: Move to TF iterators (#23) * Move all classes to use TF iterators * Update unit tests --- examples/bivariate_gaussian_snl.py | 9 +- examples/bivariate_gaussian_snr.py | 3 +- examples/slcp_snass.py | 11 +- examples/slcp_ssnl.py | 35 ++-- pyproject.toml | 2 + sbijax/__init__.py | 2 +- sbijax/_src/_sne_base.py | 24 +-- sbijax/_src/generator.py | 107 ------------- sbijax/_src/mcmc/__init__.py | 2 +- .../_src/mcmc/{sample.py => diagnostics.py} | 0 sbijax/_src/snass.py | 52 +++--- sbijax/_src/snasss.py | 149 ++---------------- sbijax/_src/snl.py | 19 ++- sbijax/_src/snl_test.py | 30 ---- sbijax/_src/snp.py | 18 +-- sbijax/_src/snr.py | 9 +- sbijax/_src/util/data.py | 20 +++ sbijax/_src/util/data_test.py | 91 +++++++++++ sbijax/_src/util/dataloader.py | 103 ++++++++++++ 19 files changed, 304 insertions(+), 382 deletions(-) delete mode 100644 sbijax/_src/generator.py rename sbijax/_src/mcmc/{sample.py => diagnostics.py} (100%) create mode 100644 sbijax/_src/util/data.py create mode 100644 sbijax/_src/util/data_test.py create mode 100644 sbijax/_src/util/dataloader.py diff --git a/examples/bivariate_gaussian_snl.py b/examples/bivariate_gaussian_snl.py index 445d760..68cca5d 100644 --- a/examples/bivariate_gaussian_snl.py +++ b/examples/bivariate_gaussian_snl.py @@ -22,7 +22,7 @@ from surjectors.util import unstack from sbijax import SNL -from sbijax.mcmc import sample_with_slice +from sbijax.mcmc import sample_with_nuts def prior_model_fns(): @@ -84,9 +84,6 @@ def _flow(method, **kwargs): def run(): y_observed = jnp.array([-2.0, 1.0]) - log_density_partial = partial(log_density_fn, y=y_observed) - log_density = lambda x: jax.vmap(log_density_partial)(x) - prior_simulator_fn, prior_logdensity_fn = prior_model_fns() fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn @@ -107,7 +104,9 @@ def run(): ) sample_key, rng_key = jr.split(jr.PRNGKey(123)) - slice_samples = sample_with_slice( + log_density_partial = partial(log_density_fn, y=y_observed) + log_density = lambda x: log_density_partial(**x) + slice_samples = sample_with_nuts( sample_key, log_density, prior_simulator_fn ) slice_samples = slice_samples.reshape(-1, 2) diff --git a/examples/bivariate_gaussian_snr.py b/examples/bivariate_gaussian_snr.py index 856ea24..897152f 100644 --- a/examples/bivariate_gaussian_snr.py +++ b/examples/bivariate_gaussian_snr.py @@ -43,7 +43,7 @@ def run(): optimizer = optax.adam(1e-3) data, params = None, {} - for i in range(5): + for i in range(2): data, _ = snr.simulate_data_and_possibly_append( jr.fold_in(jr.PRNGKey(1), i), params=params, @@ -54,7 +54,6 @@ def run(): jr.fold_in(jr.PRNGKey(2), i), data=data, optimizer=optimizer, - batch_size=100, ) rng_key = jr.PRNGKey(23) diff --git a/examples/slcp_snass.py b/examples/slcp_snass.py index b8400b4..322340b 100644 --- a/examples/slcp_snass.py +++ b/examples/slcp_snass.py @@ -20,8 +20,8 @@ from surjectors.nn import MADE from surjectors.util import unstack -from sbijax import SNASSS -from sbijax.nn import make_snasss_net +from sbijax import SNASS +from sbijax._src.nn.make_snass_networks import make_snass_net def prior_model_fns(): @@ -125,15 +125,15 @@ def run(): prior_simulator_fn, prior_logdensity_fn = prior_model_fns() fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn - estim = SNASSS( + estim = SNASS( fns, make_model(5), - make_snasss_net([64, 64, 5], [64, 64, 1], [64, 64, 1]), + 
make_snass_net([64, 64, 5], [64, 64, 1]), ) optimizer = optax.adam(1e-3) data, params = None, {} - for i in range(5): + for i in range(2): data, _ = estim.simulate_data_and_possibly_append( jr.fold_in(jr.PRNGKey(12), i), params=params, @@ -144,7 +144,6 @@ def run(): jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer, - batch_size=100, ) rng_key = jr.PRNGKey(23) diff --git a/examples/slcp_ssnl.py b/examples/slcp_ssnl.py index 662ebdd..74db420 100644 --- a/examples/slcp_ssnl.py +++ b/examples/slcp_ssnl.py @@ -15,7 +15,6 @@ from jax import numpy as jnp from jax import random as jr from jax import scipy as jsp -from jax import vmap from surjectors import ( AffineMaskedAutoregressiveInferenceFunnel, Chain, @@ -27,7 +26,7 @@ from surjectors.util import unstack from sbijax import SNL -from sbijax.mcmc import sample_with_slice +from sbijax.mcmc import sample_with_nuts def prior_model_fns(): @@ -160,30 +159,15 @@ def _flow(method, **kwargs): def run(use_surjectors): len_theta = 5 - # this is the thetas used in SNL - # thetas = jnp.array([-0.7, -2.9, -1.0, -0.9, 0.6]) - y_observed = jnp.array( - [ - [ - -0.9707123, - -2.9461224, - -0.4494722, - -3.4231849, - -0.13285634, - -3.364017, - -0.85367596, - -2.4271638, - ] - ] + thetas = jnp.linspace(-2.0, 2.0, len_theta) + y_0 = simulator_fn(jr.PRNGKey(0), thetas.reshape(-1, len_theta)).reshape( + -1, 8 ) - log_density_partial = partial(log_density_fn, y=y_observed) - log_density = lambda x: vmap(log_density_partial)(x) - prior_simulator_fn, prior_fn = prior_model_fns() fns = (prior_simulator_fn, prior_fn), simulator_fn - snl = SNL(fns, make_model(y_observed.shape[1], use_surjectors)) + snl = SNL(fns, make_model(y_0.shape[1], use_surjectors)) optimizer = optax.adam(1e-3) data, params = None, {} @@ -191,19 +175,20 @@ def run(use_surjectors): data, _ = snl.simulate_data_and_possibly_append( jr.fold_in(jr.PRNGKey(12), i), params=params, - observable=y_observed, + observable=y_0, data=data, - sampler="slice", ) params, info = snl.fit( jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer ) sample_key, rng_key = jr.split(jr.PRNGKey(123)) - snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed) + snl_samples, _ = snl.sample_posterior(sample_key, params, y_0) sample_key, rng_key = jr.split(rng_key) - slice_samples = sample_with_slice( + log_density_partial = partial(log_density_fn, y=y_0) + log_density = lambda x: log_density_partial(**x) + slice_samples = sample_with_nuts( sample_key, log_density, prior_simulator_fn, diff --git a/pyproject.toml b/pyproject.toml index 6e50aef..60701b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "optax>=0.1.3", "surjectors>=0.3.0", "tfp-nightly>=0.20.0.dev20230404", + "tensorflow==2.15.0", + "tensorflow-datasets==4.9.3", "tqdm>=4.64.1" ] dynamic = ["version"] diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 0de09a8..2bf3b76 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,7 +2,7 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.7" +__version__ = "0.1.8" from sbijax._src.abc.smc_abc import SMCABC diff --git a/sbijax/_src/_sne_base.py b/sbijax/_src/_sne_base.py index 0d4ada0..4e25833 100644 --- a/sbijax/_src/_sne_base.py +++ b/sbijax/_src/_sne_base.py @@ -6,7 +6,8 @@ from jax import random as jr from sbijax._src._sbi_base import SBI -from sbijax._src.generator import as_batch_iterators, named_dataset +from sbijax._src.util.data import stack_data +from sbijax._src.util.dataloader import as_batch_iterators, 
named_dataset # pylint: disable=too-many-arguments,unused-argument @@ -61,7 +62,7 @@ def simulate_data_and_possibly_append( if data is None: d_new = new_data else: - d_new = self.stack_data(data, new_data) + d_new = stack_data(data, new_data) return d_new, diagnostics @abc.abstractmethod @@ -135,25 +136,6 @@ def simulate_data( return new_data, diagnostics - @staticmethod - def stack_data(data, also_data): - """Stack two data sets. - - Args: - data: one data set - also_data: another data set - - Returns: - returns the stack of the two data sets - """ - if data is None: - return also_data - if also_data is None: - return data - return named_dataset( - *[jnp.vstack([a, b]) for a, b in zip(data, also_data)] - ) - @staticmethod def as_iterators( rng_key, data, batch_size, percentage_data_as_validation_set diff --git a/sbijax/_src/generator.py b/sbijax/_src/generator.py deleted file mode 100644 index e36ed29..0000000 --- a/sbijax/_src/generator.py +++ /dev/null @@ -1,107 +0,0 @@ -from collections import namedtuple - -import chex -from jax import lax -from jax import numpy as jnp -from jax import random as jr - -named_dataset = namedtuple("named_dataset", "y theta") - - -# pylint: disable=missing-class-docstring,too-few-public-methods -class DataLoader: - # noqa: D101 - def __init__( - self, num_batches, idxs=None, get_batch=None, batches=None - ): # noqa: D107 - self.num_batches = num_batches - self.idxs = idxs - if idxs is not None: - self.num_samples = len(idxs) - else: - self.num_samples = self.num_batches * batches[0]["y"].shape[0] - self.get_batch = get_batch - self.batches = batches - - def __call__(self, idx, idxs=None): # noqa: D102 - if self.batches is not None: - return self.batches[idx] - - if idxs is None: - idxs = self.idxs - return self.get_batch(idx, idxs) - - -# pylint: disable=missing-function-docstring -def as_batch_iterators( - rng_key: chex.PRNGKey, data: named_dataset, batch_size, split, shuffle -): - """Create two data batch iterators from a data set. - - Args: - rng_key: random key - data: a named tuple containing all dat - batch_size: batch size - split: fraction of data to use for training data set. Rest is used - for validation data set. - shuffle: shuffle the data set or no - - Returns: - two iterators - """ - n = data.y.shape[0] - n_train = int(n * split) - - if shuffle: - idxs = jr.permutation(rng_key, jnp.arange(n)) - data = named_dataset(*[el[idxs] for _, el in enumerate(data)]) - - y_train = named_dataset(*[el[:n_train] for el in data]) - y_val = named_dataset(*[el[n_train:] for el in data]) - train_rng_key, val_rng_key = jr.split(rng_key) - - train_itr = as_batch_iterator(train_rng_key, y_train, batch_size, shuffle) - val_itr = as_batch_iterator(val_rng_key, y_val, batch_size, shuffle) - - return train_itr, val_itr - - -# pylint: disable=missing-function-docstring -def as_batch_iterator( - rng_key: chex.PRNGKey, data: named_dataset, batch_size, shuffle -): - """Create a data batch iterator from a data set. 
- - Args: - rng_key: random key - data: a named tuple containing all dat - batch_size: batch size - shuffle: shuffle the data set or no - - Returns: - an iterator - """ - n = data.y.shape[0] - if n < batch_size: - num_batches = 1 - batch_size = n - elif n % batch_size == 0: - num_batches = int(n // batch_size) - else: - num_batches = int(n // batch_size) + 1 - - idxs = jnp.arange(n) - if shuffle: - idxs = jr.permutation(rng_key, idxs) - - def get_batch(idx, idxs=idxs): - start_idx = idx * batch_size - step_size = jnp.minimum(n - start_idx, batch_size) - ret_idx = lax.dynamic_slice_in_dim(idxs, idx * batch_size, step_size) - batch = { - name: lax.index_take(array, (ret_idx,), axes=(0,)) - for name, array in zip(data._fields, data) - } - return batch - - return DataLoader(num_batches, idxs, get_batch) diff --git a/sbijax/_src/mcmc/__init__.py b/sbijax/_src/mcmc/__init__.py index 3ddef25..0e5a4b8 100644 --- a/sbijax/_src/mcmc/__init__.py +++ b/sbijax/_src/mcmc/__init__.py @@ -1,6 +1,6 @@ +from sbijax._src.mcmc.diagnostics import mcmc_diagnostics from sbijax._src.mcmc.irmh import sample_with_imh from sbijax._src.mcmc.mala import sample_with_mala from sbijax._src.mcmc.nuts import sample_with_nuts from sbijax._src.mcmc.rmh import sample_with_rmh -from sbijax._src.mcmc.sample import mcmc_diagnostics from sbijax._src.mcmc.slice import sample_with_slice diff --git a/sbijax/_src/mcmc/sample.py b/sbijax/_src/mcmc/diagnostics.py similarity index 100% rename from sbijax/_src/mcmc/sample.py rename to sbijax/_src/mcmc/diagnostics.py diff --git a/sbijax/_src/snass.py b/sbijax/_src/snass.py index 6d8b663..fe2170a 100644 --- a/sbijax/_src/snass.py +++ b/sbijax/_src/snass.py @@ -6,9 +6,10 @@ from absl import logging from jax import numpy as jnp from jax import random as jr +from tqdm import tqdm -from sbijax._src.generator import DataLoader from sbijax._src.snl import SNL +from sbijax._src.util.dataloader import as_batch_iterator, named_dataset from sbijax._src.util.early_stopping import EarlyStopping @@ -76,7 +77,7 @@ def fit( n_early_stopping_patience=10, **kwargs, ): - """Fit a SNASS model. + """Fit the model to data. 
Args: rng_key: a jax random key @@ -116,8 +117,13 @@ def fit( n_early_stopping_patience=n_early_stopping_patience, ) - train_iter = self._as_summary(train_iter, snet_params) - val_iter = self._as_summary(val_iter, snet_params) + train_key, val_key, rng_key = jr.split(rng_key, 3) + train_iter = self._as_itr_over_summaries( + train_key, train_iter, snet_params, batch_size + ) + val_iter = self._as_itr_over_summaries( + val_key, val_iter, snet_params, batch_size + ) nde_params, losses = self._fit_model_single_round( seed=rng_key, @@ -133,19 +139,19 @@ def fit( snet_losses, ) - def _as_summary(self, iters, params): - @jax.jit - def as_batch(y, theta): - return { - "y": self.sc_net.apply(params, method="summary", y=y), - "theta": theta, - } - - return DataLoader( - num_batches=iters.num_batches, - batches=[as_batch(**iters(i)) for i in range(iters.num_batches)], + def _as_itr_over_summaries(self, rng_key, iters, params, batch_size): + ys = [] + thetas = [] + for batch in iters: + ys.append(self.sc_net.apply(params, method="summary", y=batch["y"])) + thetas.append(batch["theta"]) + ys = jnp.vstack(ys) + thetas = jnp.vstack(thetas) + return as_batch_iterator( + rng_key, named_dataset(ys, thetas), batch_size, True ) + # pylint: disable=undefined-loop-variable def _fit_summary_net( self, rng_key, @@ -157,7 +163,9 @@ def _fit_summary_net( ): init_key, rng_key = jr.split(rng_key) - params = self._init_summary_net_params(init_key, **train_iter(0)) + params = self._init_summary_net_params( + init_key, **next(iter(train_iter)) + ) state = optimizer.init(params) loss_fn = jax.jit( partial(_jsd_summary_loss, apply_fn=self.sc_net.apply) @@ -174,11 +182,10 @@ def step(rng, params, state, **batch): early_stop = EarlyStopping(1e-3, n_early_stopping_patience) best_params, best_loss = None, np.inf logging.info("training summary net") - for i in range(n_iter): + for i in tqdm(range(n_iter)): train_loss = 0.0 epoch_key, rng_key = jr.split(rng_key) - for j in range(train_iter.num_batches): - batch = train_iter(j) + for j, batch in enumerate(train_iter): batch_loss, params, state = step( jr.fold_in(epoch_key, j), params, state, **batch ) @@ -211,15 +218,14 @@ def _summary_validation_loss(self, params, rng_key, val_iter): partial(_jsd_summary_loss, apply_fn=self.sc_net.apply) ) - def body_fn(i, batch_key): - batch = val_iter(i) + def body_fn(batch_key, **batch): loss = loss_fn(params, batch_key, **batch) return loss * (batch["y"].shape[0] / val_iter.num_samples) losses = 0.0 - for i in range(val_iter.num_batches): + for batch in val_iter: batch_key, rng_key = jr.split(rng_key) - losses += body_fn(i, batch_key) + losses += body_fn(batch_key, **batch) return losses def sample_posterior( diff --git a/sbijax/_src/snasss.py b/sbijax/_src/snasss.py index cedfeb4..17a34d0 100644 --- a/sbijax/_src/snasss.py +++ b/sbijax/_src/snasss.py @@ -7,8 +7,7 @@ from jax import numpy as jnp from jax import random as jr -from sbijax._src.generator import DataLoader -from sbijax._src.snl import SNL +from sbijax._src.snass import SNASS from sbijax._src.util.early_stopping import EarlyStopping @@ -54,7 +53,7 @@ def _jsd_summary_loss(params, rng_key, apply_fn, **batch): # pylint: disable=too-many-arguments,unused-argument -class SNASSS(SNL): +class SNASSS(SNASS): """Sequential neural approximate slice sufficient statistics. Args: @@ -72,6 +71,7 @@ class SNASSS(SNL): Likelihood-free Inference". 
ICML, 2023 """ + # pylint: disable=useless-parent-delegation def __init__(self, model_fns, density_estimator, summary_net): """Construct a SNASSS object. @@ -85,82 +85,9 @@ def __init__(self, model_fns, density_estimator, summary_net): the modelled dimensionality is that of the summaries summary_net: a SNASSSNet object """ - super().__init__(model_fns, density_estimator) - self.sc_net = summary_net - - # pylint: disable=arguments-differ,too-many-locals - def fit( - self, - rng_key, - data, - optimizer=optax.adam(0.0003), - n_iter=1000, - batch_size=128, - percentage_data_as_validation_set=0.1, - n_early_stopping_patience=10, - **kwargs, - ): - """Fit a SNASSS model. - - Args: - rng_key: a jax random key - data: data set obtained from calling - `simulate_data_and_possibly_append` - optimizer: an optax optimizer object - n_iter: maximal number of training iterations per round - batch_size: batch size used for training the model - percentage_data_as_validation_set: percentage of the simulated data - that is used for validation and early stopping - n_early_stopping_patience: number of iterations of no improvement - of training the flow before stopping optimisation - - Returns: - tuple of parameters and a tuple of the training information - """ - itr_key, rng_key = jr.split(rng_key) - train_iter, val_iter = self.as_iterators( - itr_key, data, batch_size, percentage_data_as_validation_set - ) - - snet_params, snet_losses = self._fit_summary_net( - rng_key=rng_key, - train_iter=train_iter, - val_iter=val_iter, - optimizer=optimizer, - n_iter=n_iter, - n_early_stopping_patience=n_early_stopping_patience, - ) - - train_iter = self._as_summary(train_iter, snet_params) - val_iter = self._as_summary(val_iter, snet_params) - - nde_params, losses = self._fit_model_single_round( - seed=rng_key, - train_iter=train_iter, - val_iter=val_iter, - optimizer=optimizer, - n_iter=n_iter, - n_early_stopping_patience=n_early_stopping_patience, - ) - - return {"params": nde_params, "s_params": snet_params}, ( - losses, - snet_losses, - ) - - def _as_summary(self, iters, params): - @jax.jit - def as_batch(y, theta): - return { - "y": self.sc_net.apply(params, method="summary", y=y), - "theta": theta, - } - - return DataLoader( - num_batches=iters.num_batches, - batches=[as_batch(**iters(i)) for i in range(iters.num_batches)], - ) + super().__init__(model_fns, density_estimator, summary_net) + # pylint: disable=undefined-loop-variable def _fit_summary_net( self, rng_key, @@ -172,7 +99,9 @@ def _fit_summary_net( ): init_key, rng_key = jr.split(rng_key) - params = self._init_summary_net_params(init_key, **train_iter(0)) + params = self._init_summary_net_params( + init_key, **next(iter(train_iter)) + ) state = optimizer.init(params) loss_fn = jax.jit( partial(_jsd_summary_loss, apply_fn=self.sc_net.apply) @@ -192,8 +121,7 @@ def step(rng, params, state, **batch): for i in range(n_iter): train_loss = 0.0 epoch_key, rng_key = jr.split(rng_key) - for j in range(train_iter.num_batches): - batch = train_iter(j) + for j, batch in enumerate(train_iter): batch_loss, params, state = step( jr.fold_in(epoch_key, j), params, state, **batch ) @@ -217,70 +145,17 @@ def step(rng, params, state, **batch): losses = jnp.vstack(losses)[: (i + 1), :] return best_params, losses - def _init_summary_net_params(self, rng_key, **init_data): - params = self.sc_net.init(rng_key, method="forward", **init_data) - return params - def _summary_validation_loss(self, params, rng_key, val_iter): loss_fn = jax.jit( partial(_jsd_summary_loss, 
apply_fn=self.sc_net.apply) ) - def body_fn(i, batch_key): - batch = val_iter(i) + def body_fn(batch_key, **batch): loss = loss_fn(params, batch_key, **batch) return loss * (batch["y"].shape[0] / val_iter.num_samples) losses = 0.0 - for i in range(val_iter.num_batches): + for batch in val_iter: batch_key, rng_key = jr.split(rng_key) - losses += body_fn(i, batch_key) + losses += body_fn(batch_key, **batch) return losses - - def sample_posterior( - self, - rng_key, - params, - observable, - *, - n_chains=4, - n_samples=2_000, - n_warmup=1_000, - **kwargs, - ): - r"""Sample from the approximate posterior. - - Args: - rng_key: a jax random key - params: a pytree of neural network parameters - observable: observation to condition on - n_chains: number of MCMC chains - n_samples: number of samples per chain - n_warmup: number of samples to discard - - Keyword Args: - sampler (str): either 'nuts', 'slice' or None (defaults to nuts) - n_thin (int): number of thinning steps - (only used if sampler='slice') - n_doubling (int): number of doubling steps of the interval - (only used if sampler='slice') - step_size (float): step size of the initial interval - (only used if sampler='slice') - - Returns: - an array of samples from the posterior distribution of dimension - (n_samples \times p) and posterior diagnostics - """ - observable = jnp.atleast_2d(observable) - summary = self.sc_net.apply( - params["s_params"], method="summary", y=observable - ) - return super().sample_posterior( - rng_key, - params["params"], - summary, - n_chains=n_chains, - n_samples=n_samples, - n_warmup=n_warmup, - **kwargs, - ) diff --git a/sbijax/_src/snl.py b/sbijax/_src/snl.py index 2b3e155..6e079d6 100644 --- a/sbijax/_src/snl.py +++ b/sbijax/_src/snl.py @@ -7,6 +7,7 @@ from absl import logging from jax import numpy as jnp from jax import random as jr +from tqdm import tqdm from sbijax._src import mcmc from sbijax._src._sne_base import SNE @@ -59,7 +60,7 @@ def fit( data, optimizer=optax.adam(0.0003), n_iter=1000, - batch_size=128, + batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs, @@ -97,7 +98,7 @@ def fit( return params, losses - # pylint: disable=arguments-differ + # pylint: disable=arguments-differ,undefined-loop-variable def _fit_model_single_round( self, seed, @@ -108,7 +109,7 @@ def _fit_model_single_round( n_early_stopping_patience, ): init_key, seed = jr.split(seed) - params = self._init_params(init_key, **train_iter(0)) + params = self._init_params(init_key, **next(iter(train_iter))) state = optimizer.init(params) @jax.jit @@ -128,10 +129,9 @@ def loss_fn(params): early_stop = EarlyStopping(1e-3, n_early_stopping_patience) best_params, best_loss = None, np.inf logging.info("training model") - for i in range(n_iter): + for i in tqdm(range(n_iter)): train_loss = 0.0 - for j in range(train_iter.num_batches): - batch = train_iter(j) + for batch in train_iter: batch_loss, params, state = step(params, state, **batch) train_loss += batch_loss * ( batch["y"].shape[0] / train_iter.num_samples @@ -158,14 +158,13 @@ def loss_fn(**batch): ) return -jnp.mean(lp) - def body_fn(i): - batch = val_iter(i) + def body_fn(batch): loss = loss_fn(**batch) return loss * (batch["y"].shape[0] / val_iter.num_samples) losses = 0.0 - for i in range(val_iter.num_batches): - losses += body_fn(i) + for batch in val_iter: + losses += body_fn(batch) return losses def _init_params(self, rng_key, **init_data): diff --git a/sbijax/_src/snl_test.py b/sbijax/_src/snl_test.py index ed2d703..efacd66 100644 
--- a/sbijax/_src/snl_test.py +++ b/sbijax/_src/snl_test.py @@ -1,5 +1,4 @@ # pylint: skip-file -import chex import distrax import haiku as hk import pytest @@ -92,35 +91,6 @@ def test_snl(): ) -def test_stack_data(): - prior_simulator_fn, prior_logdensity_fn = prior_model_fns() - fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn - - snl = SNL(fns, make_model(2)) - n = 100 - data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n) - also_data, _ = snl.simulate_data(jr.PRNGKey(2), n_simulations=n) - stacked_data = snl.stack_data(data, also_data) - - chex.assert_trees_all_equal(data[0], stacked_data[0][:n]) - chex.assert_trees_all_equal(data[1], stacked_data[1][:n]) - chex.assert_trees_all_equal(also_data[0], stacked_data[0][n:]) - chex.assert_trees_all_equal(also_data[1], stacked_data[1][n:]) - - -def test_stack_data_with_none(): - prior_simulator_fn, prior_logdensity_fn = prior_model_fns() - fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn - - snl = SNL(fns, make_model(2)) - n = 100 - data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n) - stacked_data = snl.stack_data(None, data) - - chex.assert_trees_all_equal(data[0], stacked_data[0]) - chex.assert_trees_all_equal(data[1], stacked_data[1]) - - def test_simulate_data_from_posterior_fail(): rng_seq = hk.PRNGSequence(0) diff --git a/sbijax/_src/snp.py b/sbijax/_src/snp.py index 56410d0..154dbb7 100644 --- a/sbijax/_src/snp.py +++ b/sbijax/_src/snp.py @@ -7,6 +7,7 @@ from jax import numpy as jnp from jax import random as jr from jax import scipy as jsp +from tqdm import tqdm from sbijax._src._sne_base import SNE from sbijax._src.util.early_stopping import EarlyStopping @@ -105,6 +106,7 @@ def fit( return params, losses + # pylint: disable=undefined-loop-variable def _fit_model_single_round( self, seed, @@ -116,7 +118,7 @@ def _fit_model_single_round( n_atoms, ): init_key, seed = jr.split(seed) - params = self._init_params(init_key, **train_iter(0)) + params = self._init_params(init_key, **next(iter(train_iter))) state = optimizer.init(params) n_round = self.n_round @@ -155,12 +157,11 @@ def step(params, rng, state, **batch): early_stop = EarlyStopping(1e-3, n_early_stopping_patience) best_params, best_loss = None, np.inf logging.info("training model") - for i in range(n_iter): + for i in tqdm(range(n_iter)): train_loss = 0.0 rng_key = jr.fold_in(seed, i) - for j in range(train_iter.num_batches): + for batch in train_iter: train_key, rng_key = jr.split(rng_key) - batch = train_iter(j) batch_loss, params, state = step( params, train_key, state, **batch ) @@ -182,7 +183,7 @@ def step(params, rng, state, **batch): best_params = params.copy() self.n_round += 1 - losses = jnp.vstack(losses)[:i, :] + losses = jnp.vstack(losses)[: (i + 1), :] return best_params, losses def _init_params(self, rng_key, **init_data): @@ -246,15 +247,14 @@ def loss_fn(rng, **batch): ) return -jnp.mean(lp) - def body_fn(i, rng_key): - batch = val_iter(i) + def body_fn(batch, rng_key): loss = jax.jit(loss_fn)(rng_key, **batch) return loss * (batch["y"].shape[0] / val_iter.num_samples) loss = 0.0 - for i in range(val_iter.num_batches): + for batch in val_iter: val_key, rng_key = jr.split(rng_key) - loss += body_fn(i, val_key) + loss += body_fn(batch, val_key) return loss def sample_posterior( diff --git a/sbijax/_src/snr.py b/sbijax/_src/snr.py index 1fe8985..bf1dc23 100644 --- a/sbijax/_src/snr.py +++ b/sbijax/_src/snr.py @@ -194,7 +194,7 @@ def _fit_model_single_round( n_early_stopping_patience, ): init_key, rng_key = jr.split(rng_key) 
- params = self._init_params(init_key, **train_iter(0)) + params = self._init_params(init_key, **next(iter(train_iter))) state = optimizer.init(params) loss_fn = partial(_loss, gamma=self.gamma, num_classes=self.num_classes) @@ -215,9 +215,8 @@ def step(params, rng, state, **batch): for i in tqdm(range(n_iter)): train_loss = 0.0 rng_key = jr.fold_in(rng_key, i) - for j in range(train_iter.num_batches): + for batch in train_iter: train_key, rng_key = jr.split(rng_key) - batch = train_iter(j) batch_loss, params, state = step( params, train_key, state, **batch ) @@ -256,9 +255,9 @@ def body_fn(rng_key, **batch): return loss * (batch["y"].shape[0] / val_iter.num_samples) loss = 0.0 - for i in range(val_iter.num_batches): + for batch in val_iter: val_key, rng_key = jr.split(rng_key) - loss += body_fn(val_key, **val_iter(i)) + loss += body_fn(val_key, **batch) return loss def simulate_data_and_possibly_append( diff --git a/sbijax/_src/util/data.py b/sbijax/_src/util/data.py new file mode 100644 index 0000000..1550ece --- /dev/null +++ b/sbijax/_src/util/data.py @@ -0,0 +1,20 @@ +from jax import numpy as jnp + +from sbijax._src.util.dataloader import named_dataset + + +def stack_data(data, also_data): + """Stack two data sets. + + Args: + data: one data set + also_data: another data set + + Returns: + returns the stack of the two data sets + """ + if data is None: + return also_data + if also_data is None: + return data + return named_dataset(*[jnp.vstack([a, b]) for a, b in zip(data, also_data)]) diff --git a/sbijax/_src/util/data_test.py b/sbijax/_src/util/data_test.py new file mode 100644 index 0000000..271948e --- /dev/null +++ b/sbijax/_src/util/data_test.py @@ -0,0 +1,91 @@ +# pylint: skip-file + +import chex +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random as jr +from surjectors import Chain, MaskedCoupling, TransformedDistribution +from surjectors.nn import make_mlp +from surjectors.util import make_alternating_binary_mask + +from sbijax._src.snl import SNL +from sbijax._src.util.data import stack_data + + +def prior_model_fns(): + p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) + return p.sample, p.log_prob + + +def simulator_fn(seed, theta): + p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta)) + y = p.sample(seed=seed) + return y + + +def log_density_fn(theta, y): + prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0)) + likelihood = distrax.MultivariateNormalDiag( + theta, 0.1 * jnp.ones_like(theta) + ) + + lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y)) + return lp + + +def make_model(dim): + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + def _flow(method, **kwargs): + layers = [] + for i in range(2): + mask = make_alternating_binary_mask(dim, i % 2 == 0) + layer = MaskedCoupling( + mask=mask, + bijector_fn=_bijector_fn, + conditioner=make_mlp([8, 8, dim * 2]), + ) + layers.append(layer) + chain = Chain(layers) + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(dim), jnp.ones(dim)), + 1, + ) + td = TransformedDistribution(base_distribution, chain) + return td(method, **kwargs) + + td = hk.transform(_flow) + td = hk.without_apply_rng(td) + return td + + +def test_stack_data(): + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + snl = SNL(fns, make_model(2)) + n = 100 + data, _ = 
snl.simulate_data(jr.PRNGKey(1), n_simulations=n)
+    also_data, _ = snl.simulate_data(jr.PRNGKey(2), n_simulations=n)
+    stacked_data = stack_data(data, also_data)
+
+    chex.assert_trees_all_equal(data[0], stacked_data[0][:n])
+    chex.assert_trees_all_equal(data[1], stacked_data[1][:n])
+    chex.assert_trees_all_equal(also_data[0], stacked_data[0][n:])
+    chex.assert_trees_all_equal(also_data[1], stacked_data[1][n:])
+
+
+def test_stack_data_with_none():
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    snl = SNL(fns, make_model(2))
+    n = 100
+    data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n)
+    stacked_data = stack_data(None, data)
+
+    chex.assert_trees_all_equal(data[0], stacked_data[0])
+    chex.assert_trees_all_equal(data[1], stacked_data[1])
diff --git a/sbijax/_src/util/dataloader.py b/sbijax/_src/util/dataloader.py
new file mode 100644
index 0000000..70a1973
--- /dev/null
+++ b/sbijax/_src/util/dataloader.py
@@ -0,0 +1,103 @@
+from collections import namedtuple
+
+import tensorflow as tf
+from jax import Array
+from jax import numpy as jnp
+from jax import random as jr
+
+named_dataset = namedtuple("named_dataset", "y theta")
+
+
+# pylint: disable=missing-class-docstring,too-few-public-methods
+class DataLoader:
+    # noqa: D101
+    def __init__(self, itr, num_samples):  # noqa: D107
+        self._itr = itr
+        self.num_samples = num_samples
+
+    def __iter__(self):
+        """Iterate over the data set."""
+        yield from self._itr.as_numpy_iterator()
+
+
+# pylint: disable=missing-function-docstring
+def as_batch_iterators(
+    rng_key: Array, data: named_dataset, batch_size, split, shuffle
+):
+    """Create two data batch iterators from a data set.
+
+    Args:
+        rng_key: a jax random key
+        data: a named tuple with elements 'y' and 'theta' containing all data
+        batch_size: size of each batch
+        split: fraction of the data to use for the training set. The rest
+            is used for the validation set.
+        shuffle: whether to shuffle the data set
+
+    Returns:
+        a tuple of the training and the validation iterator
+    """
+    n = data.y.shape[0]
+    n_train = int(n * split)
+
+    if shuffle:
+        idxs = jr.permutation(rng_key, jnp.arange(n))
+        data = named_dataset(*[el[idxs] for el in data])
+
+    y_train = named_dataset(*[el[:n_train] for el in data])
+    y_val = named_dataset(*[el[n_train:] for el in data])
+    train_rng_key, val_rng_key = jr.split(rng_key)
+
+    train_itr = as_batch_iterator(train_rng_key, y_train, batch_size, shuffle)
+    val_itr = as_batch_iterator(val_rng_key, y_val, batch_size, shuffle)
+
+    return train_itr, val_itr
+
+
+def as_batched_numpy_iterator_from_tf(
+    rng_key: Array, data: tf.data.Dataset, iter_size, batch_size, shuffle
+):
+    """Create a data batch iterator from a tensorflow data set.
+
+    Args:
+        rng_key: a jax random key
+        data: a tensorflow dataset with elements 'y' and 'theta'
+        iter_size: total number of elements in the data set
+        batch_size: size of each batch
+        shuffle: whether to shuffle the data set
+
+    Returns:
+        a DataLoader that yields batches as dictionaries of numpy arrays
+    """
+    # tf.data cannot consume JAX PRNG keys, so derive an integer seed from it
+    max_int32 = jnp.iinfo(jnp.int32).max
+    seed = jr.randint(rng_key, shape=(), minval=0, maxval=max_int32)
+    data = (
+        data.shuffle(
+            10 * batch_size,
+            seed=int(seed),
+            reshuffle_each_iteration=shuffle,
+        )
+        .batch(batch_size)
+        .prefetch(buffer_size=batch_size)
+    )
+    return DataLoader(data, iter_size)
+
+
+# pylint: disable=missing-function-docstring
+def as_batch_iterator(rng_key: Array, data: named_dataset, batch_size, shuffle):
+    """Create a data batch iterator from a data set.
+
+    Args:
+        rng_key: a jax random key
+        data: a named tuple with elements 'y' and 'theta' containing all data
+        batch_size: size of each batch
+        shuffle: whether to shuffle the data set
+
+    Returns:
+        a DataLoader object
+    """
+    itr = tf.data.Dataset.from_tensor_slices(dict(zip(data._fields, data)))
+    return as_batched_numpy_iterator_from_tf(
+        rng_key, itr, data[0].shape[0], batch_size, shuffle
+    )
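
Usage sketch (not part of the diff): the snippet below illustrates how the new
TF-backed iterators and the stack_data helper introduced above fit together.
It is a minimal sketch; the array shapes and PRNG key handling are assumptions
for illustration, only the imported names come from this changeset.

    from jax import random as jr

    from sbijax._src.util.data import stack_data
    from sbijax._src.util.dataloader import as_batch_iterators, named_dataset

    rng_key = jr.PRNGKey(0)
    # a toy round of 1000 simulated (y, theta) pairs; shapes are illustrative
    data = named_dataset(
        y=jr.normal(jr.fold_in(rng_key, 0), (1000, 8)),
        theta=jr.normal(jr.fold_in(rng_key, 1), (1000, 5)),
    )
    # 90/10 train/validation split with shuffled batches of size 100
    train_iter, val_iter = as_batch_iterators(rng_key, data, 100, 0.9, True)
    for batch in train_iter:
        # each batch is a dict of numpy arrays keyed by 'y' and 'theta'
        assert batch["y"].shape[0] == batch["theta"].shape[0]
    # simulations from a later round are concatenated with stack_data
    new_round = named_dataset(
        y=jr.normal(jr.fold_in(rng_key, 2), (500, 8)),
        theta=jr.normal(jr.fold_in(rng_key, 3), (500, 5)),
    )
    data = stack_data(data, new_round)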