-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
688 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
Example using consistency model posterior estimation on a bivariate Gaussian | ||
""" | ||
|
||
import distrax | ||
import haiku as hk | ||
import matplotlib.pyplot as plt | ||
import optax | ||
import seaborn as sns | ||
from jax import numpy as jnp | ||
from jax import random as jr | ||
|
||
from sbijax import SCMPE | ||
from sbijax.nn import ConsistencyModel | ||
|
||
|
||
def prior_model_fns():
    """Build the prior over the two-dimensional parameter vector.

    The prior is a normal distribution with zero mean and unit scale in each
    of its two components, wrapped so that both components form one event.

    Returns:
        a tuple ``(sample, log_prob)`` holding the prior's bound sampling
        and log-density methods
    """
    component_wise = distrax.Normal(jnp.zeros(2), jnp.ones(2))
    prior = distrax.Independent(component_wise, 1)
    return prior.sample, prior.log_prob
|
||
|
||
def simulator_fn(seed, theta):
    """Simulate an observation by adding unit-scale Gaussian noise to theta.

    Args:
        seed: random key handed to the noise distribution's ``sample``
        theta: array of parameters; the noise has the same shape

    Returns:
        ``theta`` plus the sampled noise
    """
    noise_distribution = distrax.Normal(jnp.zeros_like(theta), 1.0)
    noise = noise_distribution.sample(seed=seed)
    return theta + noise
|
||
|
||
def make_model(dim):
    """Create a haiku-transformed consistency model backed by an MLP.

    Args:
        dim: dimensionality of the modelled space; also the width of the
            MLP's output layer

    Returns:
        the ``hk.transform``-ed function. It takes the name of a
        ``ConsistencyModel`` method plus that method's keyword arguments.
    """

    def _mlp(method, **kwargs):
        # The offset 0.001 equals the default `t_min` of `ConsistencyModel`.
        def _c_skip(time):
            # Weight of the skip connection.
            shifted = time - 0.001
            return 1 / (shifted**2 + 1)

        def _c_out(time):
            # Weight of the network output.
            numerator = 1.0 * (time - 0.001)
            return numerator / jnp.sqrt(1 + time**2)

        def _nn(theta, time, context, **kwargs):
            # Extra keyword arguments are accepted but not used by the MLP.
            features = jnp.concatenate([theta, time, context], axis=-1)
            hidden = hk.nets.MLP([64, 64, dim])(features)
            return _c_skip(time) * theta + _c_out(time) * hidden

        model = ConsistencyModel(dim, _nn)
        return model(method, **kwargs)

    return hk.transform(_mlp)
|
||
|
||
def run():
    """Fit the estimator for two rounds and plot the posterior samples.

    Each round simulates data conditioned on the current parameters and the
    observation, then refits the model on all data collected so far.
    Afterwards posterior samples are drawn, printed and shown as one
    histogram per parameter dimension.
    """
    y_observed = jnp.array([2.0, -2.0])

    prior_sampler, prior_log_density = prior_model_fns()
    fns = ((prior_sampler, prior_log_density), simulator_fn)

    estim = SCMPE(fns, make_model(2))
    optimizer = optax.adam(1e-3)

    simulation_key = jr.PRNGKey(1)
    training_key = jr.PRNGKey(2)
    data = None
    params = {}
    for round_idx in range(2):
        # A distinct key per round is derived from each base key.
        data, _ = estim.simulate_data_and_possibly_append(
            jr.fold_in(simulation_key, round_idx),
            params=params,
            observable=y_observed,
            data=data,
        )
        params, _ = estim.fit(
            jr.fold_in(training_key, round_idx),
            data=data,
            optimizer=optimizer,
        )

    post_samples, _ = estim.sample_posterior(
        jr.PRNGKey(23), params, y_observed
    )
    print(post_samples)

    _, axes = plt.subplots(2)
    for dim_idx, ax in enumerate(axes):
        sns.histplot(post_samples[:, dim_idx], color="darkblue", ax=ax)
        ax.set_xlim([-3.0, 3.0])
    sns.despine()
    plt.tight_layout()
    plt.show()
|
||
|
||
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
from typing import Callable | ||
|
||
import distrax | ||
import haiku as hk | ||
import jax | ||
from jax import numpy as jnp | ||
|
||
from sbijax._src.nn.continuous_normalizing_flow import _ResnetBlock | ||
|
||
__all__ = ["ConsistencyModel", "make_consistency_model"] | ||
|
||
|
||
class ConsistencyModel(hk.Module):
    """A consistency model.

    Wraps a neural network and exposes methods to evaluate it at a given
    time point and to draw samples using two network evaluations.

    Args:
        n_dimension: the dimensionality of the modelled space
        transform: a haiku module. The transform is a callable that has to
            take as input arguments named 'theta', 'time', 'context' and
            **kwargs. Theta, time and context are two-dimensional arrays
            with the same batch dimensions.
        t_min: minimal time point for ODE integration
        t_max: maximal time point for ODE integration
    """

    def __init__(
        self,
        n_dimension: int,
        transform: Callable,
        t_min: float = 0.001,
        t_max: float = 50.0,
    ):
        """Construct a consistency model.

        Args:
            n_dimension: the dimensionality of the modelled space
            transform: a haiku module. The transform is a callable that has
                to take as input arguments named 'theta', 'time', 'context'
                and **kwargs. Theta, time and context are two-dimensional
                arrays with the same batch dimensions.
            t_min: minimal time point for ODE integration
            t_max: maximal time point for ODE integration
        """
        super().__init__()
        self._n_dimension = n_dimension
        self._network = transform
        self._t_min = t_min
        self._t_max = t_max
        # Unit-scale normal noise with one component per modelled dimension.
        self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)

    def __call__(self, method, **kwargs):
        """Apply the model by dispatching to one of its methods.

        Args:
            method (str): name of the method to call

        Keyword Args:
            keyword arguments for the called method
        """
        bound_method = getattr(self, method)
        return bound_method(**kwargs)

    def sample(self, context, **kwargs):
        """Sample from the consistency model.

        The network is evaluated twice. The first evaluation maps base noise
        at ``t_max`` to an initial sample. That sample is then perturbed with
        fresh noise and evaluated again at the midpoint between ``t_min`` and
        ``t_max``.

        Args:
            context: array of conditioning variables; one sample is drawn
                per entry along its leading axis
            kwargs: keyword arguments like 'is_training' that are passed on
                to the neural network

        Returns:
            the result of the second network evaluation
        """
        n_samples = context.shape[0]
        # NOTE(review): the published multistep sampler appears to draw the
        # initial state with standard deviation t_max, whereas unit-scale
        # noise is used here -- confirm this matches the training procedure.
        initial_noise = self._base_distribution.sample(
            seed=hk.next_rng_key(), sample_shape=(n_samples,)
        )
        first_estimate = self.vector_field(
            initial_noise, self._t_max, context, **kwargs
        )

        fresh_noise = self._base_distribution.sample(
            seed=hk.next_rng_key(), sample_shape=(first_estimate.shape[0],)
        )
        t_mid = self._t_min + (self._t_max - self._t_min) / 2
        noise_scale = jnp.sqrt(jnp.square(t_mid) - jnp.square(self._t_min))
        perturbed = first_estimate + noise_scale * fresh_noise
        return self.vector_field(perturbed, t_mid, context, **kwargs)

    def vector_field(self, theta, time, context, **kwargs):
        """Evaluate the network for a batch of parameters at one time point.

        Args:
            theta: array of parameters
            time: time variable; it is expanded to one column with one row
                per entry of ``theta``
            context: array of conditioning variables

        Keyword Args:
            keyword arguments that are passed to the neural network
        """
        batch_size = theta.shape[0]
        time_column = jnp.full((batch_size, 1), time)
        return self._network(
            theta=theta, time=time_column, context=context, **kwargs
        )
|
||
|
||
# pylint: disable=too-many-arguments,too-many-instance-attributes | ||
class _CMResnet(hk.Module):
    """A simplified 1-d residual network.

    The network output is blended with the input ``theta`` through the
    time-dependent weights computed by ``_c_skip`` and ``_c_out``.
    """

    def __init__(
        self,
        n_layers: int,
        n_dimension: int,
        hidden_size: int,
        activation: Callable = jax.nn.relu,
        dropout_rate: float = 0.0,
        do_batch_norm: bool = False,
        batch_norm_decay: float = 0.1,
        t_min: float = 0.001,
        sigma_data: float = 1.0,
    ):
        """Construct the residual network.

        Args:
            n_layers: number of residual blocks
            n_dimension: width of the embeddings and of the output layer
            hidden_size: width of the hidden layers
            activation: activation function
            dropout_rate: dropout rate forwarded to each residual block
            do_batch_norm: forwarded to each residual block
            batch_norm_decay: forwarded to each residual block
            t_min: time offset used in the skip and output weights
            sigma_data: scale used in the skip and output weights
        """
        super().__init__()
        # Architecture.
        self.n_layers = n_layers
        self.n_dimension = n_dimension
        self.hidden_size = hidden_size
        self.activation = activation
        # Settings handed to every residual block.
        self.dropout_rate = dropout_rate
        self.do_batch_norm = do_batch_norm
        self.batch_norm_decay = batch_norm_decay
        # Constants of the skip/output weighting.
        self.t_min = t_min
        self.sigma_data = sigma_data
        self.var_data = sigma_data**2

    def __call__(self, theta, time, context, is_training, **kwargs):
        """Evaluate the network.

        Args:
            theta: array of parameters
            time: array of time points
            context: array of conditioning variables
            is_training: flag forwarded to each residual block
            kwargs: additional keyword arguments; accepted but not used

        Returns:
            the weighted sum of ``theta`` and the network output
        """
        # The linear layers are created in a fixed order: theta embedding,
        # time embedding, input projection, and last the output projection.
        theta_embedding = hk.Linear(self.n_dimension)(theta)
        time_embedding = hk.Linear(self.n_dimension)(time)
        conditioning = jnp.concatenate(
            [theta_embedding, time_embedding], axis=-1
        )

        hidden = hk.Linear(self.hidden_size)(context)
        hidden = self.activation(hidden)
        for _ in range(self.n_layers):
            block = _ResnetBlock(
                hidden_size=self.hidden_size,
                activation=self.activation,
                dropout_rate=self.dropout_rate,
                do_batch_norm=self.do_batch_norm,
                batch_norm_decay=self.batch_norm_decay,
            )
            hidden = block(
                hidden, context=conditioning, is_training=is_training
            )
        hidden = self.activation(hidden)
        projected = hk.Linear(self.n_dimension)(hidden)

        # TODO(simon): can we choose sigma automatically?
        return self._c_skip(time) * theta + self._c_out(time) * projected

    def _c_skip(self, time):
        """Compute the weight applied to ``theta``."""
        shifted = time - self.t_min
        return self.var_data / (shifted**2 + self.var_data)

    def _c_out(self, time):
        """Compute the weight applied to the network output."""
        numerator = self.sigma_data * (time - self.t_min)
        return numerator / jnp.sqrt(self.var_data + time**2)
|
||
|
||
def make_consistency_model(
    n_dimension: int,
    n_layers: int = 2,
    hidden_size: int = 64,
    activation: Callable = jax.nn.tanh,
    dropout_rate: float = 0.2,
    do_batch_norm: bool = False,
    batch_norm_decay: float = 0.2,
    t_min: float = 0.001,
    t_max: float = 50.0,
    sigma_data: float = 1.0,
):
    """Create a consistency model.

    The consistency model uses a residual network as its neural network.

    Args:
        n_dimension: dimensionality of modelled space
        n_layers: number of resnet blocks
        hidden_size: sizes of hidden layers for each resnet block
        activation: a jax activation function
        dropout_rate: dropout rate to use in resnet blocks
        do_batch_norm: use batch normalization or not
        batch_norm_decay: decay rate of EMA in batch norm layer
        t_min: minimal time point for ODE integration
        t_max: maximal time point for ODE integration
        sigma_data: the standard deviation of the data

    Returns:
        a haiku-transformed function that takes the name of a
        ``ConsistencyModel`` method and that method's keyword arguments
    """

    def _cm(method, **kwargs):
        # Both modules are built inside the transformed function.
        network = _CMResnet(
            n_layers=n_layers,
            n_dimension=n_dimension,
            hidden_size=hidden_size,
            activation=activation,
            do_batch_norm=do_batch_norm,
            dropout_rate=dropout_rate,
            batch_norm_decay=batch_norm_decay,
            t_min=t_min,
            sigma_data=sigma_data,
        )
        model = ConsistencyModel(
            n_dimension, network, t_min=t_min, t_max=t_max
        )
        return model(method, **kwargs)

    return hk.transform(_cm)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.