Add CMPE (#27)
dirmeier committed Feb 29, 2024
1 parent 9017ab2 commit 513e8fd
Showing 12 changed files with 688 additions and 7 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -19,7 +19,8 @@
- [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`)
- [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`)
- [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`)
- [Flow matching posterior estimation](https://openreview.net/forum?id=jjzJ768iV1) (`SFMPE`)
- [Flow matching posterior estimation](https://arxiv.org/abs/2305.17161) (`SFMPE`)
- [Consistency model posterior estimation](https://arxiv.org/abs/2312.05440) (`SCMPE`)

where the acronyms in parentheses denote the names of the methods in `sbijax`.
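
A minimal sketch of how those acronyms map onto the package (assuming each method is re-exported as a top-level class, as done in `sbijax/__init__.py`):

```python
# each acronym above names an importable estimator class
from sbijax import SCMPE, SFMPE, SNL

# every estimator is built from ((prior_sample_fn, prior_logdensity_fn), simulator_fn)
# plus a neural network, e.g. SCMPE(fns, network); see the example scripts below
```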

@@ -53,6 +54,7 @@ pip install git+https://github.com/dirmeier/sbijax@<RELEASE>

## Acknowledgements

> [!NOTE]
> 📝 The API of the package is heavily inspired by the excellent PyTorch-based [`sbi`](https://github.com/sbi-dev/sbi) package, which is substantially more
> feature-complete and user-friendly, and better documented.

1 change: 1 addition & 0 deletions docs/index.rst
@@ -21,6 +21,7 @@
- `Neural Approximate Sufficient Statistics <https://arxiv.org/abs/2010.10079>`_ (:code:`SNASS`)
- `Neural Approximate Slice Sufficient Statistics <https://openreview.net/forum?id=jjzJ768iV1>`_ (:code:`SNASSS`)
- `Flow matching posterior estimation <https://arxiv.org/abs/2305.17161>`_ (:code:`SFMPE`)
- `Consistency model posterior estimation <https://arxiv.org/abs/2312.05440>`_ (:code:`SCMPE`)

.. caution::

7 changes: 7 additions & 0 deletions docs/sbijax.rst
@@ -16,6 +16,7 @@ Methods
SNP
SNR
SFMPE
SCMPE
SNASS
SNASSS

@@ -49,6 +50,12 @@ SFMPE
.. autoclass:: SFMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

SCMPE
~~~~~

.. autoclass:: SCMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

SNASS
~~~~~

85 changes: 85 additions & 0 deletions examples/bivariate_gaussian_cfmpe.py
@@ -0,0 +1,85 @@
"""
Example using consistency model posterior estimation on a bivariate Gaussian
"""

import distrax
import haiku as hk
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import SCMPE
from sbijax.nn import ConsistencyModel


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y


def make_model(dim):
@hk.transform
def _mlp(method, **kwargs):
def _c_skip(time):
return 1 / ((time - 0.001) ** 2 + 1)

def _c_out(time):
return 1.0 * (time - 0.001) / jnp.sqrt(1 + time**2)

def _nn(theta, time, context, **kwargs):
ins = jnp.concatenate([theta, time, context], axis=-1)
outs = hk.nets.MLP([64, 64, dim])(ins)
out_skip = _c_skip(time) * theta + _c_out(time) * outs
return out_skip

cm = ConsistencyModel(dim, _nn)
return cm(method, **kwargs)

return _mlp


def run():
y_observed = jnp.array([2.0, -2.0])

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SCMPE(fns, make_model(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, info = estim.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
)

rng_key = jr.PRNGKey(23)
post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
print(post_samples)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
run()
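
The `_c_skip` and `_c_out` helpers in `make_model` above appear to implement the usual consistency-model output parameterization, here with `sigma_data = 1` and `t_min = 0.001`. Written out, the network prediction is

$$
f(\theta, t, y) = c_\text{skip}(t)\,\theta + c_\text{out}(t)\,\mathrm{MLP}(\theta, t, y),
\qquad
c_\text{skip}(t) = \frac{\sigma_\text{data}^2}{(t - t_\text{min})^2 + \sigma_\text{data}^2},
\qquad
c_\text{out}(t) = \frac{\sigma_\text{data}\,(t - t_\text{min})}{\sqrt{\sigma_\text{data}^2 + t^2}},
$$

so that $c_\text{skip}(t_\text{min}) = 1$ and $c_\text{out}(t_\text{min}) = 0$, i.e. the model reduces to the identity at $t = t_\text{min}$, the boundary condition consistency models rely on. The same coefficients are used by `_CMResnet` in `sbijax/_src/nn/consistency_model.py` below.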
5 changes: 3 additions & 2 deletions sbijax/__init__.py
@@ -2,11 +2,12 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.9"
__version__ = "0.2.0"


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.fmpe import SFMPE
from sbijax._src.scmpe import SCMPE
from sbijax._src.sfmpe import SFMPE
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
212 changes: 212 additions & 0 deletions sbijax/_src/nn/consistency_model.py
@@ -0,0 +1,212 @@
from typing import Callable

import distrax
import haiku as hk
import jax
from jax import numpy as jnp

from sbijax._src.nn.continuous_normalizing_flow import _ResnetBlock

__all__ = ["ConsistencyModel", "make_consistency_model"]


class ConsistencyModel(hk.Module):
"""A consistency model.
Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
"""

def __init__(
self,
n_dimension: int,
transform: Callable,
t_min: float = 0.001,
t_max: float = 50.0,
):
"""Construct a consistency model.
Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
"""
super().__init__()
self._n_dimension = n_dimension
self._network = transform
self._t_max = t_max
self._t_min = t_min
self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)

def __call__(self, method, **kwargs):
"""Aplpy the flow.
Args:
method (str): method to call
Keyword Args:
keyword arguments for the called method.
"""
return getattr(self, method)(**kwargs)

def sample(self, context, **kwargs):
"""Sample from the consistency model.
Args:
context: array of conditioning variables
kwargs: keyword arguments, e.g., 'is_training'
"""
# one-step generation: draw standard-normal noise and evaluate the
# consistency function at t_max to get an initial sample
noise = self._base_distribution.sample(
seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
)
y_hat = self.vector_field(noise, self._t_max, context, **kwargs)

# one refinement step: perturb the sample to an intermediate time point
# and apply the consistency function again (two-step sampling)
noise = self._base_distribution.sample(
seed=hk.next_rng_key(), sample_shape=(y_hat.shape[0],)
)
tme = self._t_min + (self._t_max - self._t_min) / 2
noise = jnp.sqrt(jnp.square(tme) - jnp.square(self._t_min)) * noise
y_tme = y_hat + noise
y_hat = self.vector_field(y_tme, tme, context, **kwargs)

return y_hat

def vector_field(self, theta, time, context, **kwargs):
"""Compute the vector field.
Args:
theta: array of parameters
time: time variables
context: array of conditioning variables
Keyword Args:
keyword arguments that are passed to the neural network
"""
time = jnp.full((theta.shape[0], 1), time)
return self._network(theta=theta, time=time, context=context, **kwargs)


# pylint: disable=too-many-arguments,too-many-instance-attributes
class _CMResnet(hk.Module):
"""A simplified 1-d residual network."""

def __init__(
self,
n_layers: int,
n_dimension: int,
hidden_size: int,
activation: Callable = jax.nn.relu,
dropout_rate: float = 0.0,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.1,
t_min: float = 0.001,
sigma_data: float = 1.0,
):
super().__init__()
self.n_layers = n_layers
self.n_dimension = n_dimension
self.hidden_size = hidden_size
self.activation = activation
self.do_batch_norm = do_batch_norm
self.dropout_rate = dropout_rate
self.batch_norm_decay = batch_norm_decay
self.sigma_data = sigma_data
self.var_data = self.sigma_data**2
self.t_min = t_min

def __call__(self, theta, time, context, is_training, **kwargs):
outputs = context
t_theta_embedding = jnp.concatenate(
[
hk.Linear(self.n_dimension)(theta),
hk.Linear(self.n_dimension)(time),
],
axis=-1,
)
outputs = hk.Linear(self.hidden_size)(outputs)
outputs = self.activation(outputs)
for _ in range(self.n_layers):
outputs = _ResnetBlock(
hidden_size=self.hidden_size,
activation=self.activation,
dropout_rate=self.dropout_rate,
do_batch_norm=self.do_batch_norm,
batch_norm_decay=self.batch_norm_decay,
)(outputs, context=t_theta_embedding, is_training=is_training)
outputs = self.activation(outputs)
outputs = hk.Linear(self.n_dimension)(outputs)

# TODO(simon): can we choose sigma automatically?
out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs
return out_skip

def _c_skip(self, time):
return self.var_data / ((time - self.t_min) ** 2 + self.var_data)

def _c_out(self, time):
return (
self.sigma_data
* (time - self.t_min)
/ jnp.sqrt(self.var_data + time**2)
)


def make_consistency_model(
n_dimension: int,
n_layers: int = 2,
hidden_size: int = 64,
activation: Callable = jax.nn.tanh,
dropout_rate: float = 0.2,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.2,
t_min: float = 0.001,
t_max: float = 50.0,
sigma_data: float = 1.0,
):
"""Create a consistency model.
The consistency model uses a residual network as the consistency function approximator.
Args:
n_dimension: dimensionality of modelled space
n_layers: number of resnet blocks
hidden_size: sizes of hidden layers for each resnet block
activation: a jax activation function
dropout_rate: dropout rate to use in resnet blocks
do_batch_norm: use batch normalization or not
batch_norm_decay: decay rate of EMA in batch norm layer
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
sigma_data: the standard deviation of the data
Returns:
returns a consistency model
"""

@hk.transform
def _cm(method, **kwargs):
nn = _CMResnet(
n_layers=n_layers,
n_dimension=n_dimension,
hidden_size=hidden_size,
activation=activation,
do_batch_norm=do_batch_norm,
dropout_rate=dropout_rate,
batch_norm_decay=batch_norm_decay,
t_min=t_min,
sigma_data=sigma_data,
)
cm = ConsistencyModel(n_dimension, nn, t_min=t_min, t_max=t_max)
return cm(method, **kwargs)

return _cm
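
For completeness, a minimal sketch of how `make_consistency_model` is presumably meant to be wired into the new `SCMPE` estimator. This mirrors the example script above; it assumes `make_consistency_model` is re-exported from `sbijax.nn` alongside `ConsistencyModel`, and the toy prior and simulator are placeholders:

```python
import distrax
import optax
from jax import numpy as jnp
from jax import random as jr

from sbijax import SCMPE
from sbijax.nn import make_consistency_model  # assumed re-export of the factory above


def prior_model_fns():
    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
    return p.sample, p.log_prob


def simulator_fn(seed, theta):
    # toy simulator: the observation is theta plus unit Gaussian noise
    return theta + distrax.Normal(jnp.zeros_like(theta), 1.0).sample(seed=seed)


fns = prior_model_fns(), simulator_fn
estim = SCMPE(fns, make_consistency_model(2))

y_observed = jnp.array([2.0, -2.0])
data, _ = estim.simulate_data_and_possibly_append(
    jr.PRNGKey(0), params={}, observable=y_observed, data=None
)
params, _ = estim.fit(jr.PRNGKey(1), data=data, optimizer=optax.adam(1e-3))
posterior_samples, _ = estim.sample_posterior(jr.PRNGKey(2), params, y_observed)
```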
4 changes: 2 additions & 2 deletions sbijax/_src/nn/continuous_normalizing_flow.py
@@ -47,7 +47,7 @@ def __call__(self, method, **kwargs):
"""
return getattr(self, method)(**kwargs)

def sample(self, context):
def sample(self, context, **kwargs):
"""Sample from the pushforward.
Args:
@@ -61,7 +61,7 @@ def ode_func(time, theta_t):
theta_t = theta_t.reshape(-1, self._n_dimension)
time = jnp.full((theta_t.shape[0], 1), time)
ret = self.vector_field(
theta=theta_t, time=time, context=context, is_training=False
theta=theta_t, time=time, context=context, **kwargs
)
return ret.reshape(-1)
