Skip to content

Commit

Permalink
Add FMPE (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Feb 28, 2024
1 parent db2a588 commit 78e8097
Show file tree
Hide file tree
Showing 11 changed files with 774 additions and 4 deletions.
77 changes: 77 additions & 0 deletions examples/bivariate_gaussian_fmpe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Example using flow matching posterior estimation on a bivariate Gaussian
"""

import distrax
import haiku as hk
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import FMPE
from sbijax.nn import CCNF


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y


def make_model(dim):
@hk.transform
def _mlp(method, **kwargs):
def _nn(theta, time, context, **kwargs):
ins = jnp.concatenate([theta, time, context], axis=-1)
outs = hk.nets.MLP([64, 64, dim])(ins)
return outs

ccnf = CCNF(dim, _nn)
return ccnf(method, **kwargs)

return _mlp


def run():
y_observed = jnp.array([2.0, -2.0])

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = FMPE(fns, make_model(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, info = estim.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
)

rng_key = jr.PRNGKey(23)
post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
run()
1 change: 1 addition & 0 deletions sbijax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.fmpe import FMPE
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
Expand Down
272 changes: 272 additions & 0 deletions sbijax/_src/fmpe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
from functools import partial

import jax
import numpy as np
import optax
from absl import logging
from jax import numpy as jnp
from jax import random as jr
from tqdm import tqdm

from sbijax._src._sne_base import SNE
from sbijax._src.nn.continuous_normalizing_flow import CCNF
from sbijax._src.util.early_stopping import EarlyStopping


def _sample_theta_t(rng_key, times, theta, sigma_min):
mus = times * theta
sigmata = 1.0 - (1.0 - sigma_min) * times
sigmata = sigmata.reshape(times.shape[0], 1)

noise = jr.normal(rng_key, shape=(*theta.shape,))
theta_t = noise * sigmata + mus
return theta_t


def _ut(theta_t, theta, times, sigma_min):
num = theta - (1.0 - sigma_min) * theta_t
denom = 1.0 - (1.0 - sigma_min) * times
return num / denom


# pylint: disable=too-many-locals
def _cfm_loss(
params, rng_key, apply_fn, sigma_min=0.001, is_training=True, **batch
):
theta = batch["theta"]
n, _ = theta.shape

t_key, rng_key = jr.split(rng_key)
times = jr.uniform(t_key, shape=(n, 1))

theta_key, rng_key = jr.split(rng_key)
theta_t = _sample_theta_t(theta_key, times, theta, sigma_min)

train_rng, rng_key = jr.split(rng_key)
vs = apply_fn(
params,
train_rng,
method="vector_field",
theta=theta_t,
time=times,
context=batch["y"],
is_training=is_training,
)
uts = _ut(theta_t, theta, times, sigma_min)

loss = jnp.mean(jnp.square(vs - uts))
return loss


# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
class FMPE(SNE):
"""Flow matching posterior estimation.
Args:
model_fns: a tuple of tuples. The first element is a tuple that
consists of functions to sample and evaluate the
log-probability of a data point. The second element is a
simulator function.
density_estimator: a continuous normalizing flow model
Examples:
>>> import distrax
>>> from sbijax import SNP
>>> from sbijax.nn import make_affine_maf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf()
>>>
>>> snr = SNP(fns, flow)
References:
.. [1] Wildberger, Jonas, et al. "Flow Matching for Scalable
Simulation-Based Inference." Advances in Neural Information
Processing Systems, 2024.
"""

def __init__(self, model_fns, density_estimator: CCNF):
"""Construct a FMPE object.
Args:
model_fns: a tuple of tuples. The first element is a tuple that
consists of functions to sample and evaluate the
log-probability of a data point. The second element is a
simulator function.
density_estimator: a (neural) conditional density estimator
to model the posterior distribution
"""
super().__init__(model_fns, density_estimator)

# pylint: disable=arguments-differ,too-many-locals
def fit(
self,
rng_key,
data,
*,
optimizer=optax.adam(0.0003),
n_iter=1000,
batch_size=100,
percentage_data_as_validation_set=0.1,
n_early_stopping_patience=10,
**kwargs,
):
"""Fit the model.
Args:
rng_key: a jax random key
data: data set obtained from calling
`simulate_data_and_possibly_append`
optimizer: an optax optimizer object
n_iter: maximal number of training iterations per round
batch_size: batch size used for training the model
percentage_data_as_validation_set: percentage of the simulated
data that is used for validation and early stopping
n_early_stopping_patience: number of iterations of no improvement
of training the flow before stopping optimisation\
Returns:
a tuple of parameters and a tuple of the training information
"""
itr_key, rng_key = jr.split(rng_key)
train_iter, val_iter = self.as_iterators(
itr_key, data, batch_size, percentage_data_as_validation_set
)
params, losses = self._fit_model_single_round(
seed=rng_key,
train_iter=train_iter,
val_iter=val_iter,
optimizer=optimizer,
n_iter=n_iter,
n_early_stopping_patience=n_early_stopping_patience,
)

return params, losses

# pylint: disable=undefined-loop-variable
def _fit_model_single_round(
self,
seed,
train_iter,
val_iter,
optimizer,
n_iter,
n_early_stopping_patience,
):
init_key, seed = jr.split(seed)
params = self._init_params(init_key, **next(iter(train_iter)))
state = optimizer.init(params)

loss_fn = jax.jit(
partial(_cfm_loss, apply_fn=self.model.apply, is_training=True)
)

@jax.jit
def step(params, rng, state, **batch):
loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch)
updates, new_state = optimizer.update(grads, state, params)
new_params = optax.apply_updates(params, updates)
return loss, new_params, new_state

losses = np.zeros([n_iter, 2])
early_stop = EarlyStopping(1e-3, n_early_stopping_patience)
best_params, best_loss = None, np.inf
logging.info("training model")
for i in tqdm(range(n_iter)):
train_loss = 0.0
rng_key = jr.fold_in(seed, i)
for batch in train_iter:
train_key, rng_key = jr.split(rng_key)
batch_loss, params, state = step(
params, train_key, state, **batch
)
train_loss += batch_loss * (
batch["y"].shape[0] / train_iter.num_samples
)
val_key, rng_key = jr.split(rng_key)
validation_loss = self._validation_loss(val_key, params, val_iter)
losses[i] = jnp.array([train_loss, validation_loss])

_, early_stop = early_stop.update(validation_loss)
if early_stop.should_stop:
logging.info("early stopping criterion found")
break
if validation_loss < best_loss:
best_loss = validation_loss
best_params = params.copy()

losses = jnp.vstack(losses)[: (i + 1), :]
return best_params, losses

def _init_params(self, rng_key, **init_data):
times = jr.uniform(jr.PRNGKey(0), shape=(init_data["y"].shape[0], 1))
params = self.model.init(
rng_key,
method="vector_field",
theta=init_data["theta"],
time=times,
context=init_data["y"],
is_training=False,
)
return params

def _validation_loss(self, rng_key, params, val_iter):
loss_fn = jax.jit(
partial(_cfm_loss, apply_fn=self.model.apply, is_training=False)
)

def body_fn(batch_key, **batch):
loss = loss_fn(params, batch_key, **batch)
return loss * (batch["y"].shape[0] / val_iter.num_samples)

loss = 0.0
for batch in val_iter:
val_key, rng_key = jr.split(rng_key)
loss += body_fn(val_key, **batch)
return loss

def sample_posterior(
self, rng_key, params, observable, *, n_samples=4_000, **kwargs
):
r"""Sample from the approximate posterior.
Args:
rng_key: a jax random key
params: a pytree of neural network parameters
observable: observation to condition on
n_samples: number of samples to draw
Returns:
returns an array of samples from the posterior distribution of
dimension (n_samples \times p)
"""
observable = jnp.atleast_2d(observable)

thetas = None
n_curr = n_samples
n_total_simulations_round = 0
while n_curr > 0:
n_sim = jnp.minimum(200, jnp.maximum(200, n_curr))
n_total_simulations_round += n_sim
sample_key, rng_key = jr.split(rng_key)
proposal = self.model.apply(
params,
sample_key,
method="sample",
context=jnp.tile(observable, [n_sim, 1]),
)
proposal_probs = self.prior_log_density_fn(proposal)
proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
if thetas is None:
thetas = proposal_accepted
else:
thetas = jnp.vstack([thetas, proposal_accepted])
n_curr -= proposal_accepted.shape[0]

self.n_total_simulations += n_total_simulations_round
return (
thetas[:n_samples],
thetas.shape[0] / n_total_simulations_round,
)
Loading

0 comments on commit 78e8097

Please sign in to comment.