Skip to content

Commit

Permalink
more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Feb 26, 2024
1 parent a998f42 commit 8d7d715
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
13 changes: 13 additions & 0 deletions sbijax/_src/_sne_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
from abc import ABC

import chex
Expand Down Expand Up @@ -63,6 +64,18 @@ def simulate_data_and_possibly_append(
d_new = self.stack_data(data, new_data)
return d_new, diagnostics

@abc.abstractmethod
def sample_posterior(self, rng_key, params, observable, *args, **kwargs):
"""Sample from the approximate posterior.
Args:
rng_key: a jax random key
params: a pytree of neural network parameters
observable: a data point
*args: argument list
**kwargs: keyword arguments
"""

def simulate_data(
self,
rng_key,
Expand Down
1 change: 1 addition & 0 deletions sbijax/_src/snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class SNL(SNE):
arXiv preprint arXiv:2308.01054, 2023.
"""

# pylint: disable=useless-parent-delegation
def __init__(self, model_fns, density_estimator):
"""Construct a SNL object.
Expand Down
6 changes: 3 additions & 3 deletions sbijax/_src/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,9 @@ def simulate_data_and_possibly_append(

def sample_posterior(
self,
rng_key: Array,
params: Optional[Params],
observable: Optional[Array],
rng_key,
params,
observable,
*,
n_chains=4,
n_samples=2_000,
Expand Down

0 comments on commit 8d7d715

Please sign in to comment.