Skip to content

Commit

Permalink
Update API (#14)
Browse files Browse the repository at this point in the history
* Refactor
* Increment API
* Fix lints
  • Loading branch information
dirmeier committed Oct 3, 2023
1 parent 8f42e19 commit 5b6f78b
Show file tree
Hide file tree
Showing 22 changed files with 655 additions and 864 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ SbiJAX so far implements
- Rejection ABC (`RejectionABC`),
- Sequential Monte Carlo ABC (`SMCABC`),
- Sequential Neural Likelihood Estimation (`SNL`)
- Surjective Sequential Neural Likelihood Estimation (`SSNL`)
- Sequential Neural Posterior Estimation C (short `SNP`)

## Examples

You can find several self-contained examples on how to use the algorithms in `examples`.

## Usage

## Installation

Make sure to have a working `JAX` installation. Depending whether you want to use CPU/GPU/TPU,
Expand Down
6 changes: 4 additions & 2 deletions examples/bivariate_gaussian_smcabc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib.pyplot as plt
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import SMCABC

Expand Down Expand Up @@ -42,8 +43,9 @@ def run():
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

smc = SMCABC(fns, summary_fn, distance_fn)
smc.fit(23, y_observed)
smc_samples, _ = smc.sample_posterior(10, 1000, 1000, 0.8, 500)
smc_samples, _ = smc.sample_posterior(
jr.PRNGKey(22), y_observed, 10, 1000, 1000, 0.6, 500
)

fig, axes = plt.subplots(2)
for i in range(2):
Expand Down
40 changes: 20 additions & 20 deletions examples/bivariate_gaussian_snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random
from jax import random as jr
from surjectors import (
Chain,
MaskedAutoregressive,
Expand All @@ -31,16 +31,14 @@ def prior_model_fns():


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 0.1)
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y


def log_density_fn(theta, y):
prior = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
likelihood = distrax.MultivariateNormalDiag(
theta, 0.1 * jnp.ones_like(theta)
)
likelihood = distrax.MultivariateNormalDiag(theta, jnp.ones_like(theta))

lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
return lp
Expand Down Expand Up @@ -94,26 +92,28 @@ def run():

snl = SNL(fns, make_model(2))
optimizer = optax.adam(1e-3)
params, info = snl.fit(
random.PRNGKey(23),
y_observed,
optimizer=optimizer,
n_rounds=3,
max_n_iter=100,
batch_size=64,
n_early_stopping_patience=5,
sampler="slice",
)

data, params = None, {}
for i in range(2):
data, _ = snl.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(12), i),
params=params,
observable=y_observed,
data=data,
)
params, info = snl.fit(
jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
)

sample_key, rng_key = jr.split(jr.PRNGKey(123))
slice_samples = sample_with_slice(
hk.PRNGSequence(0), log_density, 4, 2000, 1000, prior_simulator_fn
sample_key, log_density, prior_simulator_fn
)
slice_samples = slice_samples.reshape(-1, 2)
snl_samples, _ = snl.sample_posterior(
params, 4, 2000, 1000, sampler="slice"
)

print(f"Took n={snl.n_total_simulations} simulations in total")
sample_key, rng_key = jr.split(rng_key)
snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)

fig, axes = plt.subplots(2, 2)
for i in range(2):
sns.histplot(
Expand Down
36 changes: 20 additions & 16 deletions examples/bivariate_gaussian_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random
from jax import random as jr
from surjectors import Chain, TransformedDistribution
from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
from surjectors.bijectors.permutation import Permutation
Expand All @@ -25,7 +25,7 @@ def prior_model_fns():


def simulator_fn(seed, theta):
p = distrax.Normal(jnp.zeros_like(theta), 0.1)
p = distrax.Normal(jnp.zeros_like(theta), 1.0)
y = theta + p.sample(seed=seed)
return y

Expand Down Expand Up @@ -72,21 +72,25 @@ def run():
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

optimizer = optax.adamw(1e-04)
snp = SNP(fns, make_model(2))
params, info = snp.fit(
random.PRNGKey(2),
y_observed,
n_rounds=3,
optimizer=optimizer,
n_early_stopping_patience=10,
batch_size=64,
n_atoms=10,
max_n_iter=100,
)

print(f"Took n={snp.n_total_simulations} simulations in total")
snp_samples, _ = snp.sample_posterior(params, 10000)
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(2):
data, _ = snp.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, info = snp.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
)

rng_key = jr.PRNGKey(23)
snp_samples, _ = snp.sample_posterior(rng_key, params, y_observed)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(snp_samples[:, i], color="darkblue", ax=ax)
Expand Down
147 changes: 0 additions & 147 deletions examples/slcp_smcabc.py

This file was deleted.

Loading

0 comments on commit 5b6f78b

Please sign in to comment.