diff --git a/examples/bivariate_gaussian_fmpe.py b/examples/bivariate_gaussian_fmpe.py
new file mode 100644
index 0000000..668b239
--- /dev/null
+++ b/examples/bivariate_gaussian_fmpe.py
@@ -0,0 +1,77 @@
+"""
+Example using flow matching posterior estimation on a bivariate Gaussian
+"""
+
+import distrax
+import haiku as hk
+import matplotlib.pyplot as plt
+import optax
+import seaborn as sns
+from jax import numpy as jnp
+from jax import random as jr
+
+from sbijax import FMPE
+from sbijax.nn import CCNF
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.Normal(jnp.zeros_like(theta), 1.0)
+    y = theta + p.sample(seed=seed)
+    return y
+
+
+def make_model(dim):
+    @hk.transform
+    def _mlp(method, **kwargs):
+        def _nn(theta, time, context, **kwargs):
+            ins = jnp.concatenate([theta, time, context], axis=-1)
+            outs = hk.nets.MLP([64, 64, dim])(ins)
+            return outs
+
+        ccnf = CCNF(dim, _nn)
+        return ccnf(method, **kwargs)
+
+    return _mlp
+
+
+def run():
+    y_observed = jnp.array([2.0, -2.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = FMPE(fns, make_model(2))
+    optimizer = optax.adam(1e-3)
+
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            jr.fold_in(jr.PRNGKey(1), i),
+            params=params,
+            observable=y_observed,
+            data=data,
+        )
+        params, info = estim.fit(
+            jr.fold_in(jr.PRNGKey(2), i),
+            data=data,
+            optimizer=optimizer,
+        )
+
+    rng_key = jr.PRNGKey(23)
+    post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
+    fig, axes = plt.subplots(2)
+    for i, ax in enumerate(axes):
+        sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
+        ax.set_xlim([-3.0, 3.0])
+    sns.despine()
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/sbijax/__init__.py b/sbijax/__init__.py
index 2bf3b76..833dd2f 100644
--- a/sbijax/__init__.py
+++ b/sbijax/__init__.py
@@ -6,6 +6,7 @@
 from sbijax._src.abc.smc_abc import SMCABC
+from sbijax._src.fmpe import FMPE
 from sbijax._src.snass import SNASS
 from sbijax._src.snasss import SNASSS
 from sbijax._src.snl import SNL
diff --git a/sbijax/_src/fmpe.py b/sbijax/_src/fmpe.py
new file mode 100644
index 0000000..f108e80
--- /dev/null
+++ b/sbijax/_src/fmpe.py
@@ -0,0 +1,272 @@
+from functools import partial
+
+import jax
+import numpy as np
+import optax
+from absl import logging
+from jax import numpy as jnp
+from jax import random as jr
+from tqdm import tqdm
+
+from sbijax._src._sne_base import SNE
+from sbijax._src.nn.continuous_normalizing_flow import CCNF
+from sbijax._src.util.early_stopping import EarlyStopping
+
+
+def _sample_theta_t(rng_key, times, theta, sigma_min):
+    mus = times * theta
+    sigmata = 1.0 - (1.0 - sigma_min) * times
+    sigmata = sigmata.reshape(times.shape[0], 1)
+
+    noise = jr.normal(rng_key, shape=(*theta.shape,))
+    theta_t = noise * sigmata + mus
+    return theta_t
+
+
+def _ut(theta_t, theta, times, sigma_min):
+    num = theta - (1.0 - sigma_min) * theta_t
+    denom = 1.0 - (1.0 - sigma_min) * times
+    return num / denom
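+
+
+# Note on the two helpers above: they implement the optimal-transport
+# probability path and target of conditional flow matching, i.e.
+#     theta_t = t * theta + (1 - (1 - sigma_min) * t) * eps,  eps ~ N(0, I),
+#     u_t(theta_t | theta) = (theta - (1 - sigma_min) * theta_t)
+#                            / (1 - (1 - sigma_min) * t).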
+
+
+# pylint: disable=too-many-locals
+def _cfm_loss(
+    params, rng_key, apply_fn, sigma_min=0.001, is_training=True, **batch
+):
+    theta = batch["theta"]
+    n, _ = theta.shape
+
+    t_key, rng_key = jr.split(rng_key)
+    times = jr.uniform(t_key, shape=(n, 1))
+
+    theta_key, rng_key = jr.split(rng_key)
+    theta_t = _sample_theta_t(theta_key, times, theta, sigma_min)
+
+    train_rng, rng_key = jr.split(rng_key)
+    vs = apply_fn(
+        params,
+        train_rng,
+        method="vector_field",
+        theta=theta_t,
+        time=times,
+        context=batch["y"],
+        is_training=is_training,
+    )
+    uts = _ut(theta_t, theta, times, sigma_min)
+
+    loss = jnp.mean(jnp.square(vs - uts))
+    return loss
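+
+
+# `_cfm_loss` is the conditional flow matching objective: the mean squared
+# error between the model's vector field v(theta_t, t, y) and the analytic
+# target u_t(theta_t | theta), averaged over simulated pairs (theta, y),
+# times t ~ U(0, 1) and the Gaussian perturbation of theta.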
+
+
+# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
+class FMPE(SNE):
+    """Flow matching posterior estimation.
+
+    Args:
+        model_fns: a tuple of tuples. The first element is a tuple that
+            consists of functions to sample and evaluate the
+            log-probability of a data point. The second element is a
+            simulator function.
+        density_estimator: a continuous normalizing flow model
+
+    Examples:
+        >>> import distrax
+        >>> from sbijax import FMPE
+        >>> from sbijax.nn import make_ccnf
+        >>>
+        >>> prior = distrax.Normal(0.0, 1.0)
+        >>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
+        >>> fns = (prior.sample, prior.log_prob), s
+        >>> flow = make_ccnf(1)
+        >>>
+        >>> estim = FMPE(fns, flow)
+
+    References:
+        .. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable
+           Simulation-Based Inference." Advances in Neural Information
+           Processing Systems, 2023.
+    """
+
+    def __init__(self, model_fns, density_estimator: CCNF):
+        """Construct an FMPE object.
+
+        Args:
+            model_fns: a tuple of tuples. The first element is a tuple that
+                consists of functions to sample and evaluate the
+                log-probability of a data point. The second element is a
+                simulator function.
+            density_estimator: a (neural) conditional density estimator
+                to model the posterior distribution
+        """
+        super().__init__(model_fns, density_estimator)
+
+    # pylint: disable=arguments-differ,too-many-locals
+    def fit(
+        self,
+        rng_key,
+        data,
+        *,
+        optimizer=optax.adam(0.0003),
+        n_iter=1000,
+        batch_size=100,
+        percentage_data_as_validation_set=0.1,
+        n_early_stopping_patience=10,
+        **kwargs,
+    ):
+        """Fit the model.
+
+        Args:
+            rng_key: a jax random key
+            data: data set obtained from calling
+                `simulate_data_and_possibly_append`
+            optimizer: an optax optimizer object
+            n_iter: maximal number of training iterations per round
+            batch_size: batch size used for training the model
+            percentage_data_as_validation_set: percentage of the simulated
+                data that is used for validation and early stopping
+            n_early_stopping_patience: number of iterations with no
+                improvement in validation loss before stopping optimisation
+
+        Returns:
+            a tuple of parameters and an array of training information
+        """
+        itr_key, rng_key = jr.split(rng_key)
+        train_iter, val_iter = self.as_iterators(
+            itr_key, data, batch_size, percentage_data_as_validation_set
+        )
+        params, losses = self._fit_model_single_round(
+            seed=rng_key,
+            train_iter=train_iter,
+            val_iter=val_iter,
+            optimizer=optimizer,
+            n_iter=n_iter,
+            n_early_stopping_patience=n_early_stopping_patience,
+        )
+
+        return params, losses
+
+    # pylint: disable=undefined-loop-variable
+    def _fit_model_single_round(
+        self,
+        seed,
+        train_iter,
+        val_iter,
+        optimizer,
+        n_iter,
+        n_early_stopping_patience,
+    ):
+        init_key, seed = jr.split(seed)
+        params = self._init_params(init_key, **next(iter(train_iter)))
+        state = optimizer.init(params)
+
+        loss_fn = jax.jit(
+            partial(_cfm_loss, apply_fn=self.model.apply, is_training=True)
+        )
+
+        @jax.jit
+        def step(params, rng, state, **batch):
+            loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch)
+            updates, new_state = optimizer.update(grads, state, params)
+            new_params = optax.apply_updates(params, updates)
+            return loss, new_params, new_state
+
+        losses = np.zeros([n_iter, 2])
+        early_stop = EarlyStopping(1e-3, n_early_stopping_patience)
+        best_params, best_loss = None, np.inf
+        logging.info("training model")
+        for i in tqdm(range(n_iter)):
+            train_loss = 0.0
+            rng_key = jr.fold_in(seed, i)
+            for batch in train_iter:
+                train_key, rng_key = jr.split(rng_key)
+                batch_loss, params, state = step(
+                    params, train_key, state, **batch
+                )
+                train_loss += batch_loss * (
+                    batch["y"].shape[0] / train_iter.num_samples
+                )
+            val_key, rng_key = jr.split(rng_key)
+            validation_loss = self._validation_loss(val_key, params, val_iter)
+            losses[i] = jnp.array([train_loss, validation_loss])
+
+            _, early_stop = early_stop.update(validation_loss)
+            if early_stop.should_stop:
+                logging.info("early stopping criterion found")
+                break
+            if validation_loss < best_loss:
+                best_loss = validation_loss
+                best_params = params.copy()
+
+        losses = jnp.vstack(losses)[: (i + 1), :]
+        return best_params, losses
+
+    def _init_params(self, rng_key, **init_data):
+        times = jr.uniform(jr.PRNGKey(0), shape=(init_data["y"].shape[0], 1))
+        params = self.model.init(
+            rng_key,
+            method="vector_field",
+            theta=init_data["theta"],
+            time=times,
+            context=init_data["y"],
+            is_training=False,
+        )
+        return params
+
+    def _validation_loss(self, rng_key, params, val_iter):
+        loss_fn = jax.jit(
+            partial(_cfm_loss, apply_fn=self.model.apply, is_training=False)
+        )
+
+        def body_fn(batch_key, **batch):
+            loss = loss_fn(params, batch_key, **batch)
+            return loss * (batch["y"].shape[0] / val_iter.num_samples)
+
+        loss = 0.0
+        for batch in val_iter:
+            val_key, rng_key = jr.split(rng_key)
+            loss += body_fn(val_key, **batch)
+        return loss
+
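+    # Sampling note: `sample_posterior` draws candidates from the learned
+    # flow in batches of 200, keeps only draws with finite prior log-density
+    # (i.e. inside the prior support), and returns the first `n_samples`
+    # accepted draws together with the acceptance rate.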
+    def sample_posterior(
+        self, rng_key, params, observable, *, n_samples=4_000, **kwargs
+    ):
+        r"""Sample from the approximate posterior.
+
+        Args:
+            rng_key: a jax random key
+            params: a pytree of neural network parameters
+            observable: observation to condition on
+            n_samples: number of samples to draw
+
+        Returns:
+            a tuple of an array of posterior samples of dimension
+            (n_samples \times p) and the acceptance rate of the sampler
+        """
+        observable = jnp.atleast_2d(observable)
+
+        thetas = None
+        n_curr = n_samples
+        n_total_simulations_round = 0
+        while n_curr > 0:
+            n_sim = jnp.minimum(200, jnp.maximum(200, n_curr))
+            n_total_simulations_round += n_sim
+            sample_key, rng_key = jr.split(rng_key)
+            proposal = self.model.apply(
+                params,
+                sample_key,
+                method="sample",
+                context=jnp.tile(observable, [n_sim, 1]),
+            )
+            proposal_probs = self.prior_log_density_fn(proposal)
+            proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
+            if thetas is None:
+                thetas = proposal_accepted
+            else:
+                thetas = jnp.vstack([thetas, proposal_accepted])
+            n_curr -= proposal_accepted.shape[0]
+
+        self.n_total_simulations += n_total_simulations_round
+        return (
+            thetas[:n_samples],
+            thetas.shape[0] / n_total_simulations_round,
+        )
diff --git a/sbijax/_src/nn/continuous_normalizing_flow.py b/sbijax/_src/nn/continuous_normalizing_flow.py
new file mode 100644
index 0000000..1c405e0
--- /dev/null
+++ b/sbijax/_src/nn/continuous_normalizing_flow.py
@@ -0,0 +1,223 @@
+from typing import Callable
+
+import distrax
+import haiku as hk
+import jax
+from jax import numpy as jnp
+from jax.nn import glu
+from scipy import integrate
+
+__all__ = ["CCNF", "make_ccnf"]
+
+
+class CCNF(hk.Module):
+    """Conditional continuous normalizing flow.
+
+    Args:
+        n_dimension: the dimensionality of the modelled space
+        transform: a haiku module. The transform is a callable that has to
+            take as input arguments named 'theta', 'time', 'context' and
+            **kwargs. Theta, time and context are two-dimensional arrays
+            with the same batch dimensions.
+    """
+
+    def __init__(self, n_dimension: int, transform: Callable):
+        """Conditional continuous normalizing flow.
+
+        Args:
+            n_dimension: the dimensionality of the modelled space
+            transform: a haiku module. The transform is a callable that has to
+                take as input arguments named 'theta', 'time', 'context' and
+                **kwargs. Theta, time and context are two-dimensional arrays
+                with the same batch dimensions.
+        """
+        super().__init__()
+        self._n_dimension = n_dimension
+        self._network = transform
+        self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)
+
+    def __call__(self, method, **kwargs):
+        """Apply the flow.
+
+        Args:
+            method (str): method to call
+
+        Keyword Args:
+            keyword arguments for the called method
+        """
+        return getattr(self, method)(**kwargs)
+
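+    # Sampling note: `sample` integrates the learned vector field as an ODE,
+    # d(theta)/dt = v(theta_t, t, context), from t=0 (draws from the standard
+    # normal base distribution) to t=1 using scipy's RK45 solver.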
+    def sample(self, context):
+        """Sample from the pushforward.
+
+        Args:
+            context: array of conditioning variables
+        """
+        theta_0 = self._base_distribution.sample(
+            seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
+        )
+
+        def ode_func(time, theta_t):
+            theta_t = theta_t.reshape(-1, self._n_dimension)
+            time = jnp.full((theta_t.shape[0], 1), time)
+            ret = self.vector_field(
+                theta=theta_t, time=time, context=context, is_training=False
+            )
+            return ret.reshape(-1)
+
+        res = integrate.solve_ivp(
+            ode_func,
+            (0.0, 1.0),
+            theta_0.reshape(-1),
+            rtol=1e-5,
+            atol=1e-5,
+            method="RK45",
+        )
+
+        ret = res.y[:, -1].reshape(-1, self._n_dimension)
+        return ret
+
+    def vector_field(self, theta, time, context, **kwargs):
+        """Compute the vector field.
+
+        Args:
+            theta: array of parameters
+            time: time variables
+            context: array of conditioning variables
+
+        Keyword Args:
+            keyword arguments that are passed to the neural network
+        """
+        time = jnp.full((theta.shape[0], 1), time)
+        return self._network(theta=theta, time=time, context=context, **kwargs)
+
+
+# pylint: disable=too-many-arguments
+class _ResnetBlock(hk.Module):
+    """A block for a 1d residual network."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        activation: Callable = jax.nn.relu,
+        dropout_rate: float = 0.2,
+        do_batch_norm: bool = False,
+        batch_norm_decay: float = 0.1,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.activation = activation
+        self.do_batch_norm = do_batch_norm
+        self.dropout_rate = dropout_rate
+        self.batch_norm_decay = batch_norm_decay
+
+    def __call__(self, inputs, context, is_training=False):
+        outputs = inputs
+        if self.do_batch_norm:
+            outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(
+                outputs, is_training=is_training
+            )
+        outputs = hk.Linear(self.hidden_size)(outputs)
+        outputs = self.activation(outputs)
+        if is_training:
+            outputs = hk.dropout(
+                rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs
+            )
+        outputs = hk.Linear(self.hidden_size)(outputs)
+        context_proj = hk.Linear(inputs.shape[-1])(context)
+        outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1))
+        return outputs + inputs
+
+
+# pylint: disable=too-many-arguments
+class _CCNFResnet(hk.Module):
+    """A simplified 1-d residual network."""
+
+    def __init__(
+        self,
+        n_layers: int,
+        n_dimension: int,
+        hidden_size: int,
+        activation: Callable = jax.nn.relu,
+        dropout_rate: float = 0.2,
+        do_batch_norm: bool = True,
+        batch_norm_decay: float = 0.1,
+    ):
+        super().__init__()
+        self.n_layers = n_layers
+        self.n_dimension = n_dimension
+        self.hidden_size = hidden_size
+        self.activation = activation
+        self.do_batch_norm = do_batch_norm
+        self.dropout_rate = dropout_rate
+        self.batch_norm_decay = batch_norm_decay
+
+    def __call__(self, theta, time, context, is_training=False, **kwargs):
+        outputs = context
+        # as suggested in the paper: instead of conditioning on time and
+        # context (i.e., y), embed time and theta and use the embedding to
+        # condition the resnet blocks that act on y, since theta is
+        # typically low-dimensional and y is typically high-dimensional
+        t_theta_embedding = jnp.concatenate(
+            [
+                hk.Linear(self.n_dimension)(theta),
+                hk.Linear(self.n_dimension)(time),
+            ],
+            axis=-1,
+        )
+        outputs = hk.Linear(self.hidden_size)(outputs)
+        outputs = self.activation(outputs)
+        for _ in range(self.n_layers):
+            outputs = _ResnetBlock(
+                hidden_size=self.hidden_size,
+                activation=self.activation,
+                dropout_rate=self.dropout_rate,
+                do_batch_norm=self.do_batch_norm,
+                batch_norm_decay=self.batch_norm_decay,
+            )(outputs, context=t_theta_embedding, is_training=is_training)
+        outputs = self.activation(outputs)
+        outputs = hk.Linear(self.n_dimension)(outputs)
+        return outputs
+
+
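+# Illustrative usage (a minimal sketch, assuming a 2-dimensional parameter
+# space as in examples/bivariate_gaussian_fmpe.py, with `fns` defined there):
+#
+#     from sbijax import FMPE
+#     estim = FMPE(fns, make_ccnf(2))
+#
+# i.e. the transformed model returned by `make_ccnf` can be passed directly
+# to FMPE as its density estimator.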
+def make_ccnf(
+    n_dimension: int,
+    n_layers: int = 2,
+    hidden_size: int = 64,
+    activation: Callable = jax.nn.tanh,
+    dropout_rate: float = 0.2,
+    do_batch_norm: bool = False,
+    batch_norm_decay: float = 0.2,
+):
+    """Create a conditional continuous normalizing flow.
+
+    The CCNF uses a residual network as its transform, which is created
+    automatically.
+
+    Args:
+        n_dimension: dimensionality of modelled space
+        n_layers: number of resnet blocks
+        hidden_size: size of the hidden layers of each resnet block
+        activation: a jax activation function
+        dropout_rate: dropout rate to use in resnet blocks
+        do_batch_norm: use batch normalization or not
+        batch_norm_decay: decay rate of EMA in batch norm layer
+
+    Returns:
+        returns a conditional continuous normalizing flow
+    """
+
+    @hk.transform
+    def _flow(method, **kwargs):
+        nn = _CCNFResnet(
+            n_layers=n_layers,
+            n_dimension=n_dimension,
+            hidden_size=hidden_size,
+            activation=activation,
+            do_batch_norm=do_batch_norm,
+            dropout_rate=dropout_rate,
+            batch_norm_decay=batch_norm_decay,
+        )
+        ccnf = CCNF(n_dimension, nn)
+        return ccnf(method, **kwargs)
+
+    return _flow
diff --git a/sbijax/_src/snass_test.py b/sbijax/_src/snass_test.py
new file mode 100644
index 0000000..c4debc6
--- /dev/null
+++ b/sbijax/_src/snass_test.py
@@ -0,0 +1,62 @@
+# pylint: skip-file
+
+import distrax
+import haiku as hk
+from jax import numpy as jnp
+
+from sbijax import SNASS
+from sbijax.nn import make_affine_maf, make_snass_net
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def test_snass():
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.array([-1.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SNASS(
+        fns, make_affine_maf(1, 2, (32, 32)), make_snass_net((32, 1), (32, 1))
+    )
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            next(rng_seq),
+            params=params,
+            observable=y_observed,
+            data=data,
+            n_simulations=100,
+            n_chains=2,
+            n_samples=200,
+            n_warmup=100,
+        )
+        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
+    _ = estim.sample_posterior(
+        next(rng_seq),
+        params,
+        y_observed,
+        n_chains=2,
+        n_samples=200,
+        n_warmup=100,
+    )
diff --git a/sbijax/_src/snasss_test.py b/sbijax/_src/snasss_test.py
new file mode 100644
index 0000000..b00e646
--- /dev/null
+++ b/sbijax/_src/snasss_test.py
@@ -0,0 +1,66 @@
+# pylint: skip-file
+
+import distrax
+import haiku as hk
+from jax import numpy as jnp
+
+from sbijax import SNASSS
+from sbijax._src.nn.make_snass_networks import make_snasss_net
+from sbijax.nn import make_affine_maf
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    y = jnp.repeat(y, 5, axis=1)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def test_snasss():
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.repeat(jnp.array([-1.0, 1.0]).reshape(-1, 2), 5, axis=1)
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SNASSS(
+        fns,
+        make_affine_maf(5, 2, (32, 32)),
+        make_snasss_net((32, 5), (32, 1), (32, 1)),
+    )
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            next(rng_seq),
+            params=params,
+            observable=y_observed,
+            data=data,
+            n_simulations=100,
+            n_chains=2,
+            n_samples=200,
+            n_warmup=100,
+        )
+        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
+    _ = estim.sample_posterior(
+        next(rng_seq),
+        params,
+        y_observed,
+        n_chains=2,
+        n_samples=200,
+        n_warmup=100,
+    )
diff --git a/sbijax/_src/snl_test.py b/sbijax/_src/snl_test.py
index efacd66..e8fd056 100644
--- a/sbijax/_src/snl_test.py
+++ b/sbijax/_src/snl_test.py
@@ -80,7 +80,7 @@ def test_snl():
             n_samples=200,
             n_warmup=100,
         )
-        params, info = snl.fit(next(rng_seq), data=data)
+        params, info = snl.fit(next(rng_seq), data=data, n_iter=2)
     _ = snl.sample_posterior(
         next(rng_seq),
         params,
diff --git a/sbijax/_src/snp_test.py b/sbijax/_src/snp_test.py
index ace28c0..cf39560 100644
--- a/sbijax/_src/snp_test.py
+++ b/sbijax/_src/snp_test.py
@@ -78,7 +78,7 @@ def test_snp():
             n_samples=200,
             n_warmup=100,
         )
-        params, info = snp.fit(next(rng_seq), data=data)
+        params, info = snp.fit(next(rng_seq), data=data, n_iter=2)
     _ = snp.sample_posterior(
         next(rng_seq),
         params,
diff --git a/sbijax/_src/snr.py b/sbijax/_src/snr.py
index bf1dc23..7a408e1 100644
--- a/sbijax/_src/snr.py
+++ b/sbijax/_src/snr.py
@@ -7,7 +7,7 @@
 import numpy as np
 import optax
 from absl import logging
-from haiku import Params
+from haiku import Params, Transformed
 from jax import Array
 from jax import numpy as jnp
 from jax import random as jr
@@ -119,7 +119,7 @@ class SNR(SNE):
     def __init__(
         self,
         model_fns: Tuple[Tuple[Callable, Callable], Callable],
-        classifier: Callable,
+        classifier: Transformed,
         num_classes: int = 10,
         gamma: float = 1.0,
     ):
diff --git a/sbijax/_src/snr_test.py b/sbijax/_src/snr_test.py
new file mode 100644
index 0000000..d215e3a
--- /dev/null
+++ b/sbijax/_src/snr_test.py
@@ -0,0 +1,68 @@
+# pylint: skip-file
+
+import distrax
+import haiku as hk
+from jax import numpy as jnp
+
+from sbijax import SNR
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def make_model():
+    @hk.without_apply_rng
+    @hk.transform
+    def _mlp(inputs, **kwargs):
+        return hk.nets.MLP([64, 64, 1])(inputs)
+
+    return _mlp
+
+
+def test_snr():
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.array([-1.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SNR(fns, make_model())
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            next(rng_seq),
+            params=params,
+            observable=y_observed,
+            data=data,
+            n_simulations=100,
+            n_chains=2,
+            n_samples=200,
+            n_warmup=100,
+        )
+        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
+    _ = estim.sample_posterior(
+        next(rng_seq),
+        params,
+        y_observed,
+        n_chains=2,
+        n_samples=200,
+        n_warmup=100,
+    )
diff --git a/sbijax/nn/__init__.py b/sbijax/nn/__init__.py
index c9082f7..635b6f9 100644
--- a/sbijax/nn/__init__.py
+++ b/sbijax/nn/__init__.py
@@ -1,5 +1,6 @@
 """Neural network module."""
 
+from sbijax._src.nn.continuous_normalizing_flow import CCNF, make_ccnf
 from sbijax._src.nn.make_flows import (
     make_affine_maf,
     make_surjective_affine_maf,