Skip to content

Commit

Permalink
Add a test that runs everything (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Jul 28, 2023
1 parent 594b6d8 commit 7d85212
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 28 deletions.
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ repos:
language: python
language_version: python3
types: [python]
args: ["-c", "pyproject.toml"]
additional_dependencies: ["toml"]
files: "(sbijax|examples)"

- repo: https://github.com/PyCQA/flake8
Expand All @@ -58,6 +60,17 @@ repos:
args: ["--ignore-missing-imports"]
files: "(sbijax|examples)"

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.3
hooks:
- id: nbqa-black
- id: nbqa-pyupgrade
args: [--py39-plus]
- id: nbqa-isort
args: ['--profile=black']
- id: nbqa-flake8
args: ['--ignore=E501,E203,E302,E402,E731,W503']

- repo: https://github.com/jorisroovers/gitlint
rev: v0.18.0
hooks:
Expand Down
26 changes: 2 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
[![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml)
[![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/)

> Simulation-based inference in JAX
## About

SbiJAX implements several algorithms for simulation-based inference using
[BlackJAX](https://github.com/blackjax-devs/blackjax), [Haiku](https://github.com/deepmind/dm-haiku) and [JAX](https://github.com/google/jax).
[JAX](https://github.com/google/jax), [Haiku](https://github.com/deepmind/dm-haiku) and [BlackJAX](https://github.com/blackjax-devs/blackjax).

SbiJAX so far implements

Expand Down Expand Up @@ -37,29 +38,6 @@ To install the latest GitHub <RELEASE>, use:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
```

## Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
["good first issue"](https://github.com/dirmeier/sbijax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22). In order to contribute:

1) Fork the repository and install `hatch` and `pre-commit`

```bash
pip install hatch pre-commit
pre-commit install
```

2) Create a new branch in your fork and implement your contribution

3) Test your contribution/implementation by calling `hatch run test` on the (Unix) command line before submitting a PR

```bash
hatch run test:lint
hatch run test:test
```

4) Submit a pull request :slightly_smiling_face:

## Author

Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ dependencies = [
"dm-haiku>=0.0.9",
"flax>=0.6.3",
"optax>=0.1.3",
"surjectors@git+https://[email protected]/dirmeier/[email protected]",
]
dynamic = ["version"]

[project.urls]
homepage = "https://github.com/dirmeier/sbijax"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.version]
path = "sbijax/__init__.py"

Expand All @@ -50,7 +54,7 @@ dependencies = [

[tool.hatch.envs.test.scripts]
lint = 'pylint sbijax'
test = 'pytest -v --doctest-modules --cov=./sbi --cov-report=xml sbijax'
test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax'


[tool.black]
Expand Down
2 changes: 1 addition & 1 deletion sbijax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.0.10"
__version__ = "0.0.11"


from sbijax.abc.rejection_abc import RejectionABC
Expand Down
2 changes: 2 additions & 0 deletions sbijax/snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import optax
from absl import logging

# TODO(simon): this is a bit an annoying dependency to have
from flax.training.early_stopping import EarlyStopping
from jax import numpy as jnp

Expand Down
78 changes: 76 additions & 2 deletions sbijax/snl_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,80 @@
# pylint: skip-file
import chex

import distrax
import haiku as hk
import optax
from jax import numpy as jnp
from surjectors import Chain, MaskedCoupling, TransformedDistribution
from surjectors.conditioners import mlp_conditioner
from surjectors.util import make_alternating_binary_mask

from sbijax import SNL


def prior_model_fns():
    """Return the sampling and log-density callables of the 2-D uniform prior.

    The prior is U(-3, 3) on each of the two dimensions, wrapped in
    ``distrax.Independent`` so log-probs are summed over the event axis.
    """
    base = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
    prior = distrax.Independent(base, 1)
    return prior.sample, prior.log_prob


def simulator_fn(seed, theta):
    """Simulate one observation y ~ N(theta, 0.1^2 * I) for the given seed."""
    likelihood = distrax.MultivariateNormalDiag(
        theta, 0.1 * jnp.ones_like(theta)
    )
    return likelihood.sample(seed=seed)


def log_density_fn(theta, y):
    """Unnormalized joint log-density: log p(theta) + log p(y | theta)."""
    # Prior: independent U(-3, 3) per dimension; sum the per-dim log-probs.
    prior_lp = jnp.sum(
        distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0)).log_prob(theta)
    )
    # Likelihood: diagonal Gaussian centered at theta with scale 0.1.
    likelihood_lp = jnp.sum(
        distrax.MultivariateNormalDiag(
            theta, 0.1 * jnp.ones_like(theta)
        ).log_prob(y)
    )
    return prior_lp + likelihood_lp


def make_model(dim):
    """Build a Haiku-transformed masked-coupling normalizing flow.

    Args:
        dim: dimensionality of the flow's event space.

    Returns:
        A ``hk.Transformed`` pair (init/apply, RNG stripped from apply).
    """

    def _bijector_fn(params):
        # Split the conditioner output into location and log-scale halves.
        shift, log_scale = jnp.split(params, 2, -1)
        return distrax.ScalarAffine(shift, jnp.exp(log_scale))

    def _flow(method, **kwargs):
        # Two coupling layers with alternating binary masks.
        coupling_layers = [
            MaskedCoupling(
                mask=make_alternating_binary_mask(dim, idx % 2 == 0),
                bijector=_bijector_fn,
                conditioner=mlp_conditioner([8, 8, dim * 2]),
            )
            for idx in range(2)
        ]
        base = distrax.Independent(
            distrax.Normal(jnp.zeros(dim), jnp.ones(dim)),
            1,
        )
        flow = TransformedDistribution(base, Chain(coupling_layers))
        return flow(method, **kwargs)

    return hk.without_apply_rng(hk.transform(_flow))


def test_snl():
    """Smoke test: run one SNL round end-to-end and sample the posterior.

    Removed the leftover ``chex.assert_equal(1, 1)`` scaffold assertion
    (a no-op) and the unused ``info`` binding from ``snl.fit``.
    """
    rng_seq = hk.PRNGSequence(0)
    y_observed = jnp.array([-1.0, 1.0])

    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

    snl = SNL(fns, make_model(2))
    # Only the fitted params are needed below; discard the training info.
    params, _ = snl.fit(
        next(rng_seq),
        y_observed,
        n_rounds=1,
        optimizer=optax.adam(1e-4),
        sampler="slice",
    )
    # Small chain sizes keep the smoke test fast; result is not inspected.
    _ = snl.sample_posterior(params, 2, 100, 50, sampler="slice")

0 comments on commit 7d85212

Please sign in to comment.