diff --git a/README.md b/README.md
index 7aac6bb..0926110 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,8 @@
 - [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`)
 - [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`)
 - [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`)
-- [Flow matching posterior estimation](https://openreview.net/forum?id=jjzJ768iV1) (`SFMPE`)
+- [Flow matching posterior estimation](https://arxiv.org/abs/2305.17161) (`SFMPE`)
+- [Consistency model posterior estimation](https://arxiv.org/abs/2312.05440) (`SCMPE`)
 
 where the acronyms in parentheses denote the names of the methods in `sbijax`.
 
@@ -53,6 +54,7 @@ pip install git+https://github.com/dirmeier/sbijax@
 
 ## Acknowledgements
 
+> [!NOTE]
 > 📝 The API of the package is heavily inspired by the excellent Pytorch-based [`sbi`](https://github.com/sbi-dev/sbi) package which is substantially more feature-complete and user-friendly, and better documented.
diff --git a/docs/index.rst b/docs/index.rst
index 4f3b482..b25e991 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -21,6 +21,7 @@
 - `Neural Approximate Sufficient Statistics `_ (:code:`SNASS`)
 - `Neural Approximate Slice Sufficient Statistics `_ (:code:`SNASSS`)
 - `Flow matching posterior estimation `_ (:code:`SFMPE`)
+- `Consistency model posterior estimation `_ (:code:`SCMPE`)
 
 .. caution::
diff --git a/docs/sbijax.rst b/docs/sbijax.rst
index 332fa0c..a074306 100644
--- a/docs/sbijax.rst
+++ b/docs/sbijax.rst
@@ -16,6 +16,7 @@ Methods
     SNP
     SNR
     SFMPE
+    SCMPE
     SNASS
     SNASSS
 
@@ -49,6 +50,12 @@ SFMPE
 .. autoclass:: SFMPE
     :members: fit, simulate_data_and_possibly_append, sample_posterior
 
+SCMPE
+~~~~~
+
+.. autoclass:: SCMPE
+    :members: fit, simulate_data_and_possibly_append, sample_posterior
+
 SNASS
 ~~~~~
diff --git a/examples/bivariate_gaussian_cfmpe.py b/examples/bivariate_gaussian_cfmpe.py
new file mode 100644
index 0000000..01404e4
--- /dev/null
+++ b/examples/bivariate_gaussian_cfmpe.py
@@ -0,0 +1,85 @@
+"""
+Example using consistency model posterior estimation on a bivariate Gaussian
+"""
+
+import distrax
+import haiku as hk
+import matplotlib.pyplot as plt
+import optax
+import seaborn as sns
+from jax import numpy as jnp
+from jax import random as jr
+
+from sbijax import SCMPE
+from sbijax.nn import ConsistencyModel
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.Normal(jnp.zeros_like(theta), 1.0)
+    y = theta + p.sample(seed=seed)
+    return y
+
+
+def make_model(dim):
+    @hk.transform
+    def _mlp(method, **kwargs):
+        def _c_skip(time):
+            return 1 / ((time - 0.001) ** 2 + 1)
+
+        def _c_out(time):
+            return 1.0 * (time - 0.001) / jnp.sqrt(1 + time**2)
+
+        def _nn(theta, time, context, **kwargs):
+            ins = jnp.concatenate([theta, time, context], axis=-1)
+            outs = hk.nets.MLP([64, 64, dim])(ins)
+            out_skip = _c_skip(time) * theta + _c_out(time) * outs
+            return out_skip
+
+        cm = ConsistencyModel(dim, _nn)
+        return cm(method, **kwargs)
+
+    return _mlp
+
+
+def run():
+    y_observed = jnp.array([2.0, -2.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SCMPE(fns, make_model(2))
+    optimizer = optax.adam(1e-3)
+
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            jr.fold_in(jr.PRNGKey(1), i),
+            params=params,
+            observable=y_observed,
+            data=data,
+        )
+        params, info = estim.fit(
+            jr.fold_in(jr.PRNGKey(2), i),
+            data=data,
+            optimizer=optimizer,
+        )
+
+    rng_key = jr.PRNGKey(23)
+    post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
+    print(post_samples)
+    fig, axes = plt.subplots(2)
+    for i, ax in enumerate(axes):
+        sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
+        ax.set_xlim([-3.0, 3.0])
+        sns.despine()
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/sbijax/__init__.py b/sbijax/__init__.py
index 646b951..75ada41 100644
--- a/sbijax/__init__.py
+++ b/sbijax/__init__.py
@@ -2,11 +2,12 @@
 sbijax: Simulation-based inference in JAX
 """
 
-__version__ = "0.1.9"
+__version__ = "0.2.0"
 
 
 from sbijax._src.abc.smc_abc import SMCABC
-from sbijax._src.fmpe import SFMPE
+from sbijax._src.scmpe import SCMPE
+from sbijax._src.sfmpe import SFMPE
 from sbijax._src.snass import SNASS
 from sbijax._src.snasss import SNASSS
 from sbijax._src.snl import SNL
diff --git a/sbijax/_src/nn/consistency_model.py b/sbijax/_src/nn/consistency_model.py
new file mode 100644
index 0000000..24fc43d
--- /dev/null
+++ b/sbijax/_src/nn/consistency_model.py
@@ -0,0 +1,212 @@
+from typing import Callable
+
+import distrax
+import haiku as hk
+import jax
+from jax import numpy as jnp
+
+from sbijax._src.nn.continuous_normalizing_flow import _ResnetBlock
+
+__all__ = ["ConsistencyModel", "make_consistency_model"]
+
+
+class ConsistencyModel(hk.Module):
+    """A consistency model.
+
+    Args:
+        n_dimension: the dimensionality of the modelled space
+        transform: a haiku module. The transform is a callable that has to
+            take as input arguments named 'theta', 'time', 'context' and
+            **kwargs. Theta, time and context are two-dimensional arrays
+            with the same batch dimensions.
+        t_min: minimal time point for ODE integration
+        t_max: maximal time point for ODE integration
+    """
+
+    def __init__(
+        self,
+        n_dimension: int,
+        transform: Callable,
+        t_min: float = 0.001,
+        t_max: float = 50.0,
+    ):
+        """Construct a consistency model.
+
+        Args:
+            n_dimension: the dimensionality of the modelled space
+            transform: a haiku module. The transform is a callable that has to
+                take as input arguments named 'theta', 'time', 'context' and
+                **kwargs. Theta, time and context are two-dimensional arrays
+                with the same batch dimensions.
+            t_min: minimal time point for ODE integration
+            t_max: maximal time point for ODE integration
+        """
+        super().__init__()
+        self._n_dimension = n_dimension
+        self._network = transform
+        self._t_max = t_max
+        self._t_min = t_min
+        self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)
+
+    def __call__(self, method, **kwargs):
+        """Apply the flow.
+
+        Args:
+            method (str): method to call
+
+        Keyword Args:
+            keyword arguments for the called method
+        """
+        return getattr(self, method)(**kwargs)
+
+    def sample(self, context, **kwargs):
+        """Sample from the consistency model.
+
+        Args:
+            context: array of conditioning variables
+            kwargs: keyword arguments like 'is_training'
+        """
+        noise = self._base_distribution.sample(
+            seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
+        )
+        y_hat = self.vector_field(noise, self._t_max, context, **kwargs)
+
+        noise = self._base_distribution.sample(
+            seed=hk.next_rng_key(), sample_shape=(y_hat.shape[0],)
+        )
+        tme = self._t_min + (self._t_max - self._t_min) / 2
+        noise = jnp.sqrt(jnp.square(tme) - jnp.square(self._t_min)) * noise
+        y_tme = y_hat + noise
+        y_hat = self.vector_field(y_tme, tme, context, **kwargs)
+
+        return y_hat
+
+    def vector_field(self, theta, time, context, **kwargs):
+        """Compute the vector field.
+
+        Args:
+            theta: array of parameters
+            time: time variables
+            context: array of conditioning variables
+
+        Keyword Args:
+            keyword arguments that are passed to the neural network
+        """
+        time = jnp.full((theta.shape[0], 1), time)
+        return self._network(theta=theta, time=time, context=context, **kwargs)
+
+
+# pylint: disable=too-many-arguments,too-many-instance-attributes
+class _CMResnet(hk.Module):
+    """A simplified 1-d residual network."""
+
+    def __init__(
+        self,
+        n_layers: int,
+        n_dimension: int,
+        hidden_size: int,
+        activation: Callable = jax.nn.relu,
+        dropout_rate: float = 0.0,
+        do_batch_norm: bool = False,
+        batch_norm_decay: float = 0.1,
+        t_min: float = 0.001,
+        sigma_data: float = 1.0,
+    ):
+        super().__init__()
+        self.n_layers = n_layers
+        self.n_dimension = n_dimension
+        self.hidden_size = hidden_size
+        self.activation = activation
+        self.do_batch_norm = do_batch_norm
+        self.dropout_rate = dropout_rate
+        self.batch_norm_decay = batch_norm_decay
+        self.sigma_data = sigma_data
+        self.var_data = self.sigma_data**2
+        self.t_min = t_min
+
+    def __call__(self, theta, time, context, is_training, **kwargs):
+        outputs = context
+        t_theta_embedding = jnp.concatenate(
+            [
+                hk.Linear(self.n_dimension)(theta),
+                hk.Linear(self.n_dimension)(time),
+            ],
+            axis=-1,
+        )
+        outputs = hk.Linear(self.hidden_size)(outputs)
+        outputs = self.activation(outputs)
+        for _ in range(self.n_layers):
+            outputs = _ResnetBlock(
+                hidden_size=self.hidden_size,
+                activation=self.activation,
+                dropout_rate=self.dropout_rate,
+                do_batch_norm=self.do_batch_norm,
+                batch_norm_decay=self.batch_norm_decay,
+            )(outputs, context=t_theta_embedding, is_training=is_training)
+        outputs = self.activation(outputs)
+        outputs = hk.Linear(self.n_dimension)(outputs)
+
+        # TODO(simon): can we choose sigma automatically?
+        out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs
+        return out_skip
+
+    def _c_skip(self, time):
+        return self.var_data / ((time - self.t_min) ** 2 + self.var_data)
+
+    def _c_out(self, time):
+        return (
+            self.sigma_data
+            * (time - self.t_min)
+            / jnp.sqrt(self.var_data + time**2)
+        )
+
+
+def make_consistency_model(
+    n_dimension: int,
+    n_layers: int = 2,
+    hidden_size: int = 64,
+    activation: Callable = jax.nn.tanh,
+    dropout_rate: float = 0.2,
+    do_batch_norm: bool = False,
+    batch_norm_decay: float = 0.2,
+    t_min: float = 0.001,
+    t_max: float = 50.0,
+    sigma_data: float = 1.0,
+):
+    """Create a consistency model.
+
+    The consistency model uses a residual network as score network.
+
+    Args:
+        n_dimension: dimensionality of modelled space
+        n_layers: number of resnet blocks
+        hidden_size: sizes of hidden layers for each resnet block
+        activation: a jax activation function
+        dropout_rate: dropout rate to use in resnet blocks
+        do_batch_norm: use batch normalization or not
+        batch_norm_decay: decay rate of EMA in batch norm layer
+        t_min: minimal time point for ODE integration
+        t_max: maximal time point for ODE integration
+        sigma_data: the standard deviation of the data
+
+    Returns:
+        a consistency model
+    """
+
+    @hk.transform
+    def _cm(method, **kwargs):
+        nn = _CMResnet(
+            n_layers=n_layers,
+            n_dimension=n_dimension,
+            hidden_size=hidden_size,
+            activation=activation,
+            do_batch_norm=do_batch_norm,
+            dropout_rate=dropout_rate,
+            batch_norm_decay=batch_norm_decay,
+            t_min=t_min,
+            sigma_data=sigma_data,
+        )
+        cm = ConsistencyModel(n_dimension, nn, t_min=t_min, t_max=t_max)
+        return cm(method, **kwargs)
+
+    return _cm
diff --git a/sbijax/_src/nn/continuous_normalizing_flow.py b/sbijax/_src/nn/continuous_normalizing_flow.py
index 1c405e0..741d865 100644
--- a/sbijax/_src/nn/continuous_normalizing_flow.py
+++ b/sbijax/_src/nn/continuous_normalizing_flow.py
@@ -47,7 +47,7 @@ def __call__(self, method, **kwargs):
         """
         return getattr(self, method)(**kwargs)
 
-    def sample(self, context):
+    def sample(self, context, **kwargs):
         """Sample from the pushforward.
 
         Args:
@@ -61,7 +61,7 @@ def ode_func(time, theta_t):
             theta_t = theta_t.reshape(-1, self._n_dimension)
             time = jnp.full((theta_t.shape[0], 1), time)
             ret = self.vector_field(
-                theta=theta_t, time=time, context=context, is_training=False
+                theta=theta_t, time=time, context=context, **kwargs
             )
             return ret.reshape(-1)
diff --git a/sbijax/_src/scmpe.py b/sbijax/_src/scmpe.py
new file mode 100644
index 0000000..f9b2d83
--- /dev/null
+++ b/sbijax/_src/scmpe.py
@@ -0,0 +1,247 @@
+from functools import partial
+
+import jax
+import numpy as np
+import optax
+from absl import logging
+from jax import numpy as jnp
+from jax import random as jr
+from tqdm import tqdm
+
+from sbijax._src.sfmpe import SFMPE
+from sbijax._src.util.early_stopping import EarlyStopping
+
+
+def _alpha_t(time):
+    return 1.0 / (_time_schedule(time + 1) - _time_schedule(time))
+
+
+def _time_schedule(n, rho=7, t_min=0.001, t_max=50, n_inters=1000):
+    left = t_min ** (1 / rho)
+    right = t_max ** (1 / rho) - t_min ** (1 / rho)
+    right = (n - 1) / (n_inters - 1) * right
+    return (left + right) ** rho
+
+
+def _discretization_schedule(n_iter, max_iter=1000):
+    s0, s1 = 10, 50
+    nk = (
+        (n_iter / max_iter) * (jnp.square(s1 + 1) - jnp.square(s0))
+        + jnp.square(s0)
+        - 1
+    )
+    nk = jnp.ceil(jnp.sqrt(nk)) + 1
+    return nk
+
+
+# pylint: disable=too-many-locals,too-many-arguments
+def _consistency_loss(
+    params,
+    ema_params,
+    rng_key,
+    apply_fn,
+    n_iter,
+    t_min,
+    t_max,
+    is_training=False,
+    **batch,
+):
+    theta = batch["theta"]
+    nk = _discretization_schedule(n_iter)
+
+    t_key, rng_key = jr.split(rng_key)
+    time_idx = jr.randint(
+        t_key, shape=(theta.shape[0],), minval=1, maxval=nk - 1
+    )
+    tn = _time_schedule(
+        time_idx, t_min=t_min, t_max=t_max, n_inters=nk
+    ).reshape(-1, 1)
+    tnp1 = _time_schedule(
+        time_idx + 1, t_min=t_min, t_max=t_max, n_inters=nk
+    ).reshape(-1, 1)
+
+    noise_key, rng_key = jr.split(rng_key)
+    noise = jr.normal(noise_key, shape=(*theta.shape,))
+
+    train_rng, rng_key = jr.split(rng_key)
+    fnp1 = apply_fn(
+        params,
+        train_rng,
+        method="vector_field",
+        theta=theta + tnp1 * noise,
+        time=tnp1,
+        context=batch["y"],
+        is_training=is_training,
+    )
+    fn = apply_fn(
+        ema_params,
+        train_rng,
+        method="vector_field",
+        theta=theta + tn * noise,
+        time=tn,
+        context=batch["y"],
+        is_training=is_training,
+    )
+    mse = jnp.sqrt(jnp.mean(jnp.square(fnp1 - fn), axis=1))
+    loss = _alpha_t(time_idx) * mse
+    return jnp.mean(loss)
+
+
+# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
+class SCMPE(SFMPE):
+    r"""Sequential consistency model posterior estimation.
+
+    Implements a sequential version of the CMPE algorithm introduced in [1]_.
+    For all rounds $r > 1$ parameter samples
+    :math:`\theta \sim \hat{p}^r(\theta)` are drawn from
+    the approximate posterior instead of the prior when computing the
+    consistency loss. Note that the implementation does not strictly follow
+    the paper.
+
+    Args:
+        model_fns: a tuple of tuples. The first element is a tuple that
+            consists of functions to sample and evaluate the
+            log-probability of a data point. The second element is a
+            simulator function.
+        network: a neural network
+        t_min: minimal time point for ODE integration
+        t_max: maximal time point for ODE integration
+
+    Examples:
+        >>> import distrax
+        >>> from sbijax import SCMPE
+        >>> from sbijax.nn import make_consistency_model
+        >>>
+        >>> prior = distrax.Normal(0.0, 1.0)
+        >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
+        >>> fns = (prior.sample, prior.log_prob), s
+        >>> net = make_consistency_model(1)
+        >>>
+        >>> estim = SCMPE(fns, net)
+
+    References:
+        .. [1] Schmitt, Marvin, et al. "Consistency Models for Scalable and
+           Fast Simulation-Based Inference".
+           arXiv preprint arXiv:2312.05440, 2023.
+    """
+
+    def __init__(self, model_fns, network, t_max=50.0, t_min=0.001):
+        """Construct a SCMPE object.
+
+        Args:
+            model_fns: a tuple of tuples. The first element is a tuple that
+                consists of functions to sample and evaluate the
+                log-probability of a data point. The second element is a
+                simulator function.
+            network: a neural network
+            t_min: minimal time point for ODE integration
+            t_max: maximal time point for ODE integration
+        """
+        super().__init__(model_fns, network)
+        self._t_min = t_min
+        self._t_max = t_max
+
+    # pylint: disable=undefined-loop-variable
+    def _fit_model_single_round(
+        self,
+        seed,
+        train_iter,
+        val_iter,
+        optimizer,
+        n_iter,
+        n_early_stopping_patience,
+    ):
+        init_key, seed = jr.split(seed)
+        params = self._init_params(init_key, **next(iter(train_iter)))
+        ema_params = params.copy()
+        state = optimizer.init(params)
+
+        loss_fn = jax.jit(
+            partial(
+                _consistency_loss,
+                apply_fn=self.model.apply,
+                is_training=True,
+                t_max=self._t_max,
+                t_min=self._t_min,
+            )
+        )
+
+        @jax.jit
+        def ema_update(params, avg_params):
+            return optax.incremental_update(avg_params, params, step_size=0.01)
+
+        @jax.jit
+        def step(params, ema_params, rng, state, n_iter, **batch):
+            loss, grads = jax.value_and_grad(loss_fn)(
+                params, ema_params, rng, n_iter=n_iter, **batch
+            )
+            updates, new_state = optimizer.update(grads, state, params)
+            new_params = optax.apply_updates(params, updates)
+            new_ema_params = ema_update(new_params, ema_params)
+            return loss, new_params, new_ema_params, new_state
+
+        losses = np.zeros([n_iter, 2])
+        early_stop = EarlyStopping(1e-3, n_early_stopping_patience * 2)
+        best_params, best_loss = None, np.inf
+        logging.info("training model")
+        for i in tqdm(range(n_iter)):
+            train_loss = 0.0
+            rng_key = jr.fold_in(seed, i)
+            for batch in train_iter:
+                train_key, rng_key = jr.split(rng_key)
+                batch_loss, params, ema_params, state = step(
+                    params, ema_params, train_key, state, n_iter + 1, **batch
+                )
+                train_loss += batch_loss * (
+                    batch["y"].shape[0] / train_iter.num_samples
+                )
+            val_key, rng_key = jr.split(rng_key)
+            validation_loss = self._validation_loss(
+                val_key, params, ema_params, n_iter, val_iter
+            )
+            losses[i] = jnp.array([train_loss, validation_loss])
+
+            _, early_stop = early_stop.update(validation_loss)
+            if early_stop.should_stop:
+                logging.info("early stopping criterion found")
+                break
+            if validation_loss < best_loss:
+                best_loss = validation_loss
+                best_params = params.copy()
+
+        losses = jnp.vstack(losses)[: (i + 1), :]
+        return best_params, losses
+
+    def _init_params(self, rng_key, **init_data):
+        times = jr.uniform(jr.PRNGKey(0), shape=(init_data["y"].shape[0], 1))
+        params = self.model.init(
+            rng_key,
+            method="vector_field",
+            theta=init_data["theta"],
+            time=times,
+            context=init_data["y"],
+            is_training=True,
+        )
+        return params
+
+    # pylint: disable=arguments-differ
+    def _validation_loss(self, rng_key, params, ema_params, n_iter, val_iter):
+        loss_fn = jax.jit(
+            partial(
+                _consistency_loss,
+                apply_fn=self.model.apply,
+                is_training=False,
+                t_max=self._t_max,
+                t_min=self._t_min,
+                n_iter=n_iter,
+            )
+        )
+
+        def body_fn(batch_key, **batch):
+            loss = loss_fn(params, ema_params, batch_key, **batch)
+            return loss * (batch["y"].shape[0] / val_iter.num_samples)
+
+        loss = 0.0
+        for batch in val_iter:
+            val_key, rng_key = jr.split(rng_key)
+            loss += body_fn(val_key, **batch)
+        return loss
diff --git a/sbijax/_src/scmpe_test.py b/sbijax/_src/scmpe_test.py
new file mode 100644
index 0000000..8f6a38e
--- /dev/null
+++ b/sbijax/_src/scmpe_test.py
@@ -0,0 +1,60 @@
+# pylint: skip-file
+
+import distrax
+import haiku as hk
+from jax import numpy as jnp
+
+from sbijax import SCMPE
+from sbijax.nn import make_consistency_model
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def test_scmpe():
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.array([-1.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SCMPE(fns, make_consistency_model(2))
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            next(rng_seq),
+            params=params,
+            observable=y_observed,
+            data=data,
+            n_simulations=100,
+            n_chains=2,
+            n_samples=200,
+            n_warmup=100,
+        )
+        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
+    _ = estim.sample_posterior(
+        next(rng_seq),
+        params,
+        y_observed,
+        n_chains=2,
+        n_samples=200,
+        n_warmup=100,
+    )
diff --git a/sbijax/_src/fmpe.py b/sbijax/_src/sfmpe.py
similarity index 98%
rename from sbijax/_src/fmpe.py
rename to sbijax/_src/sfmpe.py
index d426c68..440434d 100644
--- a/sbijax/_src/fmpe.py
+++ b/sbijax/_src/sfmpe.py
@@ -65,7 +65,8 @@ class SFMPE(SNE):
     For all rounds $r > 1$ parameter samples
     :math:`\theta \sim \hat{p}^r(\theta)` are drawn from
     the approximate posterior instead of the prior when computing the flow
-    matching loss.
+    matching loss. Note that the implementation does not strictly follow the
+    paper.
 
     Args:
         model_fns: a tuple of tuples. The first element is a tuple that
@@ -93,7 +94,7 @@ class SFMPE(SNE):
     """
 
     def __init__(self, model_fns, density_estimator):
-        """Construct a FMPE object.
+        """Construct a SFMPE object.
 
         Args:
             model_fns: a tuple of tuples. The first element is a tuple that
@@ -261,6 +262,7 @@ def sample_posterior(
             sample_key,
             method="sample",
             context=jnp.tile(observable, [n_sim, 1]),
+            is_training=False,
         )
         proposal_probs = self.prior_log_density_fn(proposal)
         proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
diff --git a/sbijax/_src/sfmpe_test.py b/sbijax/_src/sfmpe_test.py
new file mode 100644
index 0000000..102ca07
--- /dev/null
+++ b/sbijax/_src/sfmpe_test.py
@@ -0,0 +1,60 @@
+# pylint: skip-file
+
+import distrax
+import haiku as hk
+from jax import numpy as jnp
+
+from sbijax import SFMPE
+from sbijax.nn import make_ccnf
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def test_sfmpe():
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.array([-1.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SFMPE(fns, make_ccnf(2))
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            next(rng_seq),
+            params=params,
+            observable=y_observed,
+            data=data,
+            n_simulations=100,
+            n_chains=2,
+            n_samples=200,
+            n_warmup=100,
+        )
+        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
+    _ = estim.sample_posterior(
+        next(rng_seq),
+        params,
+        y_observed,
+        n_chains=2,
+        n_samples=200,
+        n_warmup=100,
+    )
diff --git a/sbijax/nn/__init__.py b/sbijax/nn/__init__.py
index 635b6f9..1972b41 100644
--- a/sbijax/nn/__init__.py
+++ b/sbijax/nn/__init__.py
@@ -1,5 +1,9 @@
 """Neural network module."""
 
+from sbijax._src.nn.consistency_model import (
+    ConsistencyModel,
+    make_consistency_model,
+)
 from sbijax._src.nn.continuous_normalizing_flow import CCNF, make_ccnf
 from sbijax._src.nn.make_flows import (
     make_affine_maf,