Impl consistency model posterior estimation

dirmeier committed Feb 28, 2024
1 parent 9017ab2 commit 41dd5b0
Showing 8 changed files with 623 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -53,6 +53,7 @@ pip install git+https://github.com/dirmeier/sbijax@<RELEASE>

## Acknowledgements

> [!NOTE]
> 📝 The API of the package is heavily inspired by the excellent PyTorch-based [`sbi`](https://github.com/sbi-dev/sbi) package, which is substantially more
> feature-complete, user-friendly, and better documented.

7 changes: 7 additions & 0 deletions docs/sbijax.rst
@@ -16,6 +16,7 @@ Methods
SNP
SNR
SFMPE
SCMPE
SNASS
SNASSS

@@ -46,6 +47,12 @@ SNR
SFMPE
~~~~~

.. autoclass:: SFMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

SCMPE
~~~~~

.. autoclass:: SCMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

85 changes: 85 additions & 0 deletions examples/bivariate_gaussian_cfmpe.py
@@ -0,0 +1,85 @@
"""
Example using consistency model posterior estimation on a bivariate Gaussian
"""

import distrax
import haiku as hk
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import SCMPE
from sbijax.nn import ConsistencyModel


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y


def make_model(dim):
@hk.transform
def _mlp(method, **kwargs):
def _c_skip(time):
return 1 / ((time - 0.001) ** 2 + 1)

def _c_out(time):
return 1.0 * (time - 0.001) / jnp.sqrt(1 + time ** 2)

def _nn(theta, time, context, **kwargs):
ins = jnp.concatenate([theta, time, context], axis=-1)
outs = hk.nets.MLP([64, 64, dim])(ins)
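# skip/out weighting as in consistency models: the output equals theta
# at the boundary time eps = 0.001 (sigma_data is effectively fixed to 1)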
out_skip = _c_skip(time) * theta + _c_out(time) * outs
return out_skip

cm = ConsistencyModel(dim, _nn)
return cm(method, **kwargs)

return _mlp


def run():
y_observed = jnp.array([2.0, -2.0])

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SCMPE(fns, make_model(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, info = estim.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
)


rng_key = jr.PRNGKey(23)
post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
print(post_samples)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
run()
5 changes: 3 additions & 2 deletions sbijax/__init__.py
@@ -2,11 +2,12 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.9"
__version__ = "0.2.0"


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.fmpe import SFMPE
from sbijax._src.scmpe import SCMPE
from sbijax._src.sfmpe import SFMPE
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
229 changes: 229 additions & 0 deletions sbijax/_src/nn/consistency_model.py
@@ -0,0 +1,229 @@
from typing import Callable

import distrax
import haiku as hk
import jax
from jax import numpy as jnp
from jax.nn import glu
from scipy import integrate

__all__ = ["ConsistencyModel", "make_consistency_model"]

from sbijax._src.nn.make_resnet import _Resnet


class ConsistencyModel(hk.Module):
"""Conditional continuous normalizing flow.
Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
"""

def __init__(self, n_dimension: int, transform: Callable, t_max=50):
"""Conditional continuous normalizing flow.
Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
"""
super().__init__()
self._n_dimension = n_dimension
self._network = transform
self._t_max = t_max
self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)

def __call__(self, method, **kwargs):
"""Aplpy the flow.
Args:
method (str): method to call
Keyword Args:
keyword arguments for the called method:
"""
return getattr(self, method)(**kwargs)

def sample(self, context):
"""Sample from the pushforward.
Args:
context: array of conditioning variables
"""
theta_0 = self._base_distribution.sample(
seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
)
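# single-step generation: draw theta_T from the base distribution and
# map it through the consistency function evaluated at the maximal time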
y_hat = self.vector_field(theta_0, self._t_max, context)
return y_hat

def vector_field(self, theta, time, context, **kwargs):
"""Compute the vector field.
Args:
theta: array of parameters
time: time variables
context: array of conditioning variables
Keyword Args:
keyword arguments that are passed to the neural network
"""
time = jnp.full((theta.shape[0], 1), time)
return self._network(theta=theta, time=time, context=context, **kwargs)


# pylint: disable=too-many-arguments
class _ResnetBlock(hk.Module):
"""A block for a 1d residual network."""

def __init__(
self,
hidden_size: int,
activation: Callable = jax.nn.relu,
dropout_rate: float = 0.2,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.1,
):
super().__init__()
self.hidden_size = hidden_size
self.activation = activation
self.do_batch_norm = do_batch_norm
self.dropout_rate = dropout_rate
self.batch_norm_decay = batch_norm_decay

def __call__(self, inputs, context, is_training=False):
outputs = inputs
if self.do_batch_norm:
outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(
outputs, is_training=is_training
)
outputs = hk.Linear(self.hidden_size)(outputs)
outputs = self.activation(outputs)
if is_training:
outputs = hk.dropout(
rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs
)
outputs = hk.Linear(self.hidden_size)(outputs)
context_proj = hk.Linear(inputs.shape[-1])(context)
outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1))
return outputs + inputs


# pylint: disable=too-many-arguments
class _CMResnet(hk.Module):
"""A simplified 1-d residual network."""

def __init__(
self,
n_layers: int,
n_dimension: int,
hidden_size: int,
activation: Callable = jax.nn.relu,
dropout_rate: float = 0.0,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.1,
eps: float = 0.001,
sigma_data: float = 1.0,
):
super().__init__()
self.n_layers = n_layers
self.n_dimension = n_dimension
self.hidden_size = hidden_size
self.activation = activation
self.do_batch_norm = do_batch_norm
self.dropout_rate = dropout_rate
self.batch_norm_decay = batch_norm_decay
self.sigma_data = sigma_data
self.var_data = self.sigma_data**2
self.eps = eps

def __call__(self, theta, time, context, is_training=False, **kwargs):
outputs = context
t_theta_embedding = jnp.concatenate(
[
hk.Linear(self.n_dimension)(theta),
hk.Linear(self.n_dimension)(time),
],
axis=-1,
)
outputs = hk.Linear(self.hidden_size)(outputs)
outputs = self.activation(outputs)
for _ in range(self.n_layers):
outputs = _ResnetBlock(
hidden_size=self.hidden_size,
activation=self.activation,
dropout_rate=self.dropout_rate,
do_batch_norm=self.do_batch_norm,
batch_norm_decay=self.batch_norm_decay,
)(outputs, context=t_theta_embedding, is_training=is_training)
outputs = self.activation(outputs)
outputs = hk.Linear(self.n_dimension)(outputs)

# TODO(simon): how is sigma_data chosen automatically?
# in the meantime set it to 1 and use batch norm before
#outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(outputs, is_training=is_training)
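# consistency-model parameterization: f(theta, t) = c_skip(t) * theta
# + c_out(t) * network(theta, t, y); since c_skip(eps) = 1 and
# c_out(eps) = 0, the boundary condition f(theta, eps) = theta holds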
out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs
return out_skip

def _c_skip(self, time):
return self.var_data / ((time - self.eps) ** 2 + self.var_data)

def _c_out(self, time):
return (
self.sigma_data
* (time - self.eps)
/ jnp.sqrt(self.var_data + time**2)
)


def make_consistency_model(
n_dimension: int,
n_layers: int = 2,
hidden_size: int = 64,
activation: Callable = jax.nn.tanh,
dropout_rate: float = 0.2,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.2,
t_max: float = 50,
epsilon: float = 0.001,
sigma_data: float = 1.0,
):
"""Create a conditional continuous normalizing flow.
The CCNF uses a residual network as transformer which is created
automatically.
Args:
n_dimension: dimensionality of modelled space
n_layers: number of resnet blocks
hidden_size: sizes of hidden layers for each resnet block
activation: a jax activation function
dropout_rate: dropout rate to use in resnet blocks
do_batch_norm: use batch normalization or not
batch_norm_decay: decay rate of EMA in batch norm layer
Returns:
returns a conditional continuous normalizing flow
"""

@hk.transform
def _cm(method, **kwargs):
nn = _CMResnet(
n_layers=n_layers,
n_dimension=n_dimension,
hidden_size=hidden_size,
activation=activation,
do_batch_norm=do_batch_norm,
dropout_rate=dropout_rate,
batch_norm_decay=batch_norm_decay,
eps=epsilon,
sigma_data=sigma_data,
)
cm = ConsistencyModel(n_dimension, nn, t_max=t_max)
return cm(method, **kwargs)

return _cm
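
For reference, a minimal sketch (not part of this commit) of how the new `make_consistency_model` factory could be plugged into `SCMPE`, mirroring examples/bivariate_gaussian_cfmpe.py above. Importing `make_consistency_model` from `sbijax.nn` is an assumption, based on how `ConsistencyModel` is imported there; the remaining calls follow the example script.

import distrax
import optax
from jax import numpy as jnp
from jax import random as jr

from sbijax import SCMPE
from sbijax.nn import make_consistency_model  # assumed re-export of the new factory


def prior_model_fns():
    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
    return p.sample, p.log_prob


def simulator_fn(seed, theta):
    return theta + distrax.Normal(jnp.zeros_like(theta), 1.0).sample(seed=seed)


y_observed = jnp.array([2.0, -2.0])
fns = prior_model_fns(), simulator_fn
estim = SCMPE(fns, make_consistency_model(2))

# one round of simulation and training, as in the example above
data, _ = estim.simulate_data_and_possibly_append(
    jr.PRNGKey(0), params={}, observable=y_observed, data=None
)
params, _ = estim.fit(jr.PRNGKey(1), data=data, optimizer=optax.adam(1e-3))
posterior_samples, _ = estim.sample_posterior(jr.PRNGKey(2), params, y_observed)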