Consistency model posterior estimation #27

Merged · 6 commits · Feb 29, 2024
4 changes: 3 additions & 1 deletion README.md
@@ -19,7 +19,8 @@
- [Contrastive Neural Ratio Estimation](https://arxiv.org/abs/2210.06170) (short `SNR`)
- [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`)
- [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`)
- [Flow matching posterior estimation](https://openreview.net/forum?id=jjzJ768iV1) (`SFMPE`)
- [Flow matching posterior estimation](https://arxiv.org/abs/2305.17161) (`SFMPE`)
- [Consistency model posterior estimation](https://arxiv.org/abs/2312.05440) (`SCMPE`)

where the acronyms in parentheses denote the names of the methods in `sbijax`.

@@ -53,6 +54,7 @@ pip install git+https://github.com/dirmeier/sbijax@<RELEASE>

## Acknowledgements

> [!NOTE]
> 📝 The API of the package is heavily inspired by the excellent PyTorch-based [`sbi`](https://github.com/sbi-dev/sbi) package, which is substantially more
feature-complete, user-friendly, and better documented.

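The README now lists `SCMPE` among the available methods; for orientation, here is a minimal usage sketch condensed from `examples/bivariate_gaussian_cfmpe.py` added in this PR (`fns`, `make_model`, and `y_observed` are the names used in that example):

from jax import random as jr
import optax
from sbijax import SCMPE

estim = SCMPE(fns, make_model(2))
data, _ = estim.simulate_data_and_possibly_append(
    jr.PRNGKey(1), params={}, observable=y_observed, data=None
)
params, _ = estim.fit(jr.PRNGKey(2), data=data, optimizer=optax.adam(1e-3))
post_samples, _ = estim.sample_posterior(jr.PRNGKey(23), params, y_observed)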
1 change: 1 addition & 0 deletions docs/index.rst
@@ -21,6 +21,7 @@
- `Neural Approximate Sufficient Statistics <https://arxiv.org/abs/2010.10079>`_ (:code:`SNASS`)
- `Neural Approximate Slice Sufficient Statistics <https://openreview.net/forum?id=jjzJ768iV1>`_ (:code:`SNASSS`)
- `Flow matching posterior estimation <https://arxiv.org/abs/2305.17161>`_ (:code:`SFMPE`)
- `Consistency model posterior estimation <https://arxiv.org/abs/2312.05440>`_ (:code:`SCMPE`)

.. caution::

7 changes: 7 additions & 0 deletions docs/sbijax.rst
@@ -16,6 +16,7 @@ Methods
SNP
SNR
SFMPE
SCMPE
SNASS
SNASSS

@@ -49,6 +50,12 @@ SFMPE
.. autoclass:: SFMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

SCMPE
~~~~~

.. autoclass:: SCMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

SNASS
~~~~~

85 changes: 85 additions & 0 deletions examples/bivariate_gaussian_cfmpe.py
@@ -0,0 +1,85 @@
"""
Example using consistency model posterior estimation on a bivariate Gaussian
"""

import distrax
import haiku as hk
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import SCMPE
from sbijax.nn import ConsistencyModel


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y


def make_model(dim):
@hk.transform
def _mlp(method, **kwargs):
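# Skip/output scalings of the consistency parameterization: at t = 0.001
# (the default t_min) the output reduces to theta, i.e. the boundary
# condition f(theta, t_min) = theta holds by construction.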
def _c_skip(time):
return 1 / ((time - 0.001) ** 2 + 1)

def _c_out(time):
return 1.0 * (time - 0.001) / jnp.sqrt(1 + time**2)

def _nn(theta, time, context, **kwargs):
ins = jnp.concatenate([theta, time, context], axis=-1)
outs = hk.nets.MLP([64, 64, dim])(ins)
out_skip = _c_skip(time) * theta + _c_out(time) * outs
return out_skip

cm = ConsistencyModel(dim, _nn)
return cm(method, **kwargs)

return _mlp


def run():
y_observed = jnp.array([2.0, -2.0])

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SCMPE(fns, make_model(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
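# two sequential rounds: simulate data given the current parameters, then (re)fit the model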
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, info = estim.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
)

rng_key = jr.PRNGKey(23)
post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
print(post_samples)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
run()
5 changes: 3 additions & 2 deletions sbijax/__init__.py
@@ -2,11 +2,12 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.9"
__version__ = "0.2.0"


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.fmpe import SFMPE
from sbijax._src.scmpe import SCMPE
from sbijax._src.sfmpe import SFMPE
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
212 changes: 212 additions & 0 deletions sbijax/_src/nn/consistency_model.py
@@ -0,0 +1,212 @@
from typing import Callable

import distrax
import haiku as hk
import jax
from jax import numpy as jnp

from sbijax._src.nn.continuous_normalizing_flow import _ResnetBlock

__all__ = ["ConsistencyModel", "make_consistency_model"]


class ConsistencyModel(hk.Module):
"""A consistency model.

Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
"""

def __init__(
self,
n_dimension: int,
transform: Callable,
t_min: float = 0.001,
t_max: float = 50.0,
):
"""Construct a consistency model.

Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
"""
super().__init__()
self._n_dimension = n_dimension
self._network = transform
self._t_max = t_max
self._t_min = t_min
self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)

def __call__(self, method, **kwargs):
"""Aplpy the flow.

Args:
method (str): method to call

Keyword Args:
keyword arguments for the called method
"""
return getattr(self, method)(**kwargs)

def sample(self, context, **kwargs):
"""Sample from the consistency model.

Args:
context: array of conditioning variables
kwargs: keyword arguments like 'is_training'
"""
noise = self._base_distribution.sample(
seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
)
y_hat = self.vector_field(noise, self._t_max, context, **kwargs)

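# a second consistency step: re-noise the one-step sample at an intermediate
# time and map it back through the consistency function to refine the draw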
noise = self._base_distribution.sample(
seed=hk.next_rng_key(), sample_shape=(y_hat.shape[0],)
)
tme = self._t_min + (self._t_max - self._t_min) / 2
noise = jnp.sqrt(jnp.square(tme) - jnp.square(self._t_min)) * noise
y_tme = y_hat + noise
y_hat = self.vector_field(y_tme, tme, context, **kwargs)

return y_hat

def vector_field(self, theta, time, context, **kwargs):
"""Compute the vector field.

Args:
theta: array of parameters
time: time variables
context: array of conditioning variables

Keyword Args:
keyword arguments that are passed to the neural network
"""
time = jnp.full((theta.shape[0], 1), time)
return self._network(theta=theta, time=time, context=context, **kwargs)


# pylint: disable=too-many-arguments,too-many-instance-attributes
class _CMResnet(hk.Module):
"""A simplified 1-d residual network."""

def __init__(
self,
n_layers: int,
n_dimension: int,
hidden_size: int,
activation: Callable = jax.nn.relu,
dropout_rate: float = 0.0,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.1,
t_min: float = 0.001,
sigma_data: float = 1.0,
):
super().__init__()
self.n_layers = n_layers
self.n_dimension = n_dimension
self.hidden_size = hidden_size
self.activation = activation
self.do_batch_norm = do_batch_norm
self.dropout_rate = dropout_rate
self.batch_norm_decay = batch_norm_decay
self.sigma_data = sigma_data
self.var_data = self.sigma_data**2
self.t_min = t_min

def __call__(self, theta, time, context, is_training, **kwargs):
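# the conditioning variables (the data) form the trunk input; theta and time
# are embedded below and injected into every resnet block as context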
outputs = context
t_theta_embedding = jnp.concatenate(
[
hk.Linear(self.n_dimension)(theta),
hk.Linear(self.n_dimension)(time),
],
axis=-1,
)
outputs = hk.Linear(self.hidden_size)(outputs)
outputs = self.activation(outputs)
for _ in range(self.n_layers):
outputs = _ResnetBlock(
hidden_size=self.hidden_size,
activation=self.activation,
dropout_rate=self.dropout_rate,
do_batch_norm=self.do_batch_norm,
batch_norm_decay=self.batch_norm_decay,
)(outputs, context=t_theta_embedding, is_training=is_training)
outputs = self.activation(outputs)
outputs = hk.Linear(self.n_dimension)(outputs)

# TODO(simon): can we choose sigma automatically?
out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs
return out_skip

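# scalings of the consistency parameterization: at time t_min, _c_skip is 1
# and _c_out is 0, so the output above reduces to theta (boundary condition)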
def _c_skip(self, time):
return self.var_data / ((time - self.t_min) ** 2 + self.var_data)

def _c_out(self, time):
return (
self.sigma_data
* (time - self.t_min)
/ jnp.sqrt(self.var_data + time**2)
)


def make_consistency_model(
n_dimension: int,
n_layers: int = 2,
hidden_size: int = 64,
activation: Callable = jax.nn.tanh,
dropout_rate: float = 0.2,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.2,
t_min: float = 0.001,
t_max: float = 50.0,
sigma_data: float = 1.0,
):
"""Create a consistency model.

The consistency model uses a residual network to parameterize the consistency function.

Args:
n_dimension: dimensionality of the modelled space
n_layers: number of resnet blocks
hidden_size: size of the hidden layers in each resnet block
activation: a jax activation function
dropout_rate: dropout rate to use in resnet blocks
do_batch_norm: use batch normalization or not
batch_norm_decay: decay rate of EMA in batch norm layer
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
sigma_data: the standard deviation of the data

Returns:
a consistency model
"""

@hk.transform
def _cm(method, **kwargs):
nn = _CMResnet(
n_layers=n_layers,
n_dimension=n_dimension,
hidden_size=hidden_size,
activation=activation,
do_batch_norm=do_batch_norm,
dropout_rate=dropout_rate,
batch_norm_decay=batch_norm_decay,
t_min=t_min,
sigma_data=sigma_data,
)
cm = ConsistencyModel(n_dimension, nn, t_min=t_min, t_max=t_max)
return cm(method, **kwargs)

return _cm
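As a usage note, the transformed network returned by `make_consistency_model` can replace a hand-written haiku transform when constructing the estimator. A sketch, assuming the factory is re-exported from `sbijax.nn` alongside `ConsistencyModel`, with `fns` as in the bivariate Gaussian example above:

from sbijax import SCMPE
from sbijax.nn import make_consistency_model

nn = make_consistency_model(n_dimension=2, n_layers=2, hidden_size=64)
estim = SCMPE(fns, nn)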
4 changes: 2 additions & 2 deletions sbijax/_src/nn/continuous_normalizing_flow.py
@@ -47,7 +47,7 @@ def __call__(self, method, **kwargs):
"""
return getattr(self, method)(**kwargs)

def sample(self, context):
def sample(self, context, **kwargs):
"""Sample from the pushforward.

Args:
@@ -61,7 +61,7 @@ def ode_func(time, theta_t):
theta_t = theta_t.reshape(-1, self._n_dimension)
time = jnp.full((theta_t.shape[0], 1), time)
ret = self.vector_field(
theta=theta_t, time=time, context=context, is_training=False
theta=theta_t, time=time, context=context, **kwargs
)
return ret.reshape(-1)
