Skip to content

Commit

Permalink
Fix (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Oct 4, 2023
1 parent 01f2e2c commit cb7b35d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sbijax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.1"
__version__ = "0.1.2"


from sbijax.abc.rejection_abc import RejectionABC
Expand Down
4 changes: 4 additions & 0 deletions sbijax/_sne_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def stack_data(data, also_data):
returns the stack of the two data sets
"""

if data is None:
return also_data
if also_data is None:
return data
return named_dataset(
*[jnp.vstack([a, b]) for a, b in zip(data, also_data)]
)
Expand Down
13 changes: 13 additions & 0 deletions sbijax/snl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ def test_stack_data():
chex.assert_trees_all_equal(also_data[1], stacked_data[1][n:])


def test_stack_data_with_none():
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

snl = SNL(fns, make_model(2))
n = 100
data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n)
stacked_data = snl.stack_data(None, data)

chex.assert_trees_all_equal(data[0], stacked_data[0])
chex.assert_trees_all_equal(data[1], stacked_data[1])


def test_simulate_data_from_posterior_fail():
rng_seq = hk.PRNGSequence(0)

Expand Down

0 comments on commit cb7b35d

Please sign in to comment.