Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move to ruff #28

Merged
merged 3 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 5 additions & 33 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,6 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.29.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
files: "(sbijax|examples)"

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--settings-path=pyproject.toml"]
files: "(sbijax|examples)"

- repo: https://github.com/pycqa/bandit
rev: 1.7.1
hooks:
Expand All @@ -44,24 +24,16 @@ repos:
additional_dependencies: ["toml"]
files: "(sbijax|examples)"

- repo: https://github.com/PyCQA/flake8
rev: 5.0.1
hooks:
- id: flake8
additional_dependencies: [
flake8-typing-imports==1.14.0,
flake8-pyproject==1.1.0.post0
]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1
hooks:
- id: mypy
args: ["--ignore-missing-imports"]
files: "(sbijax|examples)"

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
- id: ruff
args: [ --fix ]
- id: ruff-format
53 changes: 13 additions & 40 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"
homepage = "https://github.com/dirmeier/sbijax"
keywords = ["abc", "simulation-based inference", "approximate Bayesian computation", "normalizing flows", "smc-abc"]
classifiers = [
"Development Status :: 1 - Planning",
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -57,54 +57,27 @@ dependencies = [

[tool.hatch.envs.test]
dependencies = [
"pylint>=2.15.10",
"ruff>=0.3.0",
"pytest>=7.2.0",
"pytest-cov>=4.0.0"
]

[tool.hatch.envs.test.scripts]
lint = 'pylint sbijax'
lint = 'ruff check sbijax'
test = 'pytest -v --cov=./sbijax --cov-report=xml sbijax'

[tool.black]
line-length = 80
target-version = ['py311']
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[tool.bandit]
skips = ["B101"]

[tool.isort]
profile = "black"
line_length = 80
include_trailing_comma = true
[tool.ruff]
line-length = 80
exclude = ["*_test.py", "docs/**", "examples/**"]

[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503", "E731", "E231"]
per-file-ignores = [
'__init__.py:F401',
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F"]
extend-select = [
"UP", "D", "I", "PL", "S"
]

[tool.pylint.messages_control]
disable = """
invalid-name,missing-module-docstring,R0801,E0633
"""

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention= 'google'
match = '^sbijax/.*/((?!_test).)*\.py'
7 changes: 3 additions & 4 deletions sbijax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""
sbijax: Simulation-based inference in JAX
"""
"""sbijax: Simulation-based inference in JAX."""

__version__ = "0.2.0"


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.scmpe import SCMPE
from sbijax._src.sfmpe import SFMPE
Expand All @@ -13,3 +10,5 @@
from sbijax._src.snl import SNL
from sbijax._src.snp import SNP
from sbijax._src.snr import SNR

__all__ = ["SMCABC", "SCMPE", "SFMPE", "SNASS", "SNASSS", "SNL", "SNP", "SNR"]
3 changes: 1 addition & 2 deletions sbijax/_src/_sne_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from sbijax._src.util.dataloader import as_batch_iterators, named_dataset


# pylint: disable=too-many-arguments,unused-argument
# pylint: disable=too-many-function-args,arguments-differ
# ruff: noqa: PLR0913
class SNE(SBI, ABC):
"""Sequential neural estimation base class."""

Expand Down
7 changes: 3 additions & 4 deletions sbijax/_src/abc/smc_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from sbijax._src._sbi_base import SBI


# pylint: disable=arguments-differ,too-many-function-args,too-many-locals
# pylint: disable=too-few-public-methods
# ruff: noqa: PLR0913
class SMCABC(SBI):
"""Sequential Monte Carlo approximate Bayesian computation.

Expand All @@ -38,7 +37,6 @@ def __init__(self, model_fns, summary_fn, distance_fn):
self.summarized_observed: chex.Array
self.n_total_simulations = 0

# pylint: disable=too-many-arguments,arguments-differ
def sample_posterior(
self,
rng_key,
Expand All @@ -53,9 +51,10 @@ def sample_posterior(
r"""Sample from the approximate posterior.

Args:
rng_key: a jax random
n_rounds: max number of SMC rounds
observable: the observation to condition on
n_round: number of rounds of SMC
n_rounds: number of rounds of SMC
n_particles: number of n_particles to draw for each parameter
n_simulations_per_theta: number of simulations for each paramrter
sample
Expand Down
9 changes: 9 additions & 0 deletions sbijax/_src/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,12 @@
from sbijax._src.mcmc.nuts import sample_with_nuts
from sbijax._src.mcmc.rmh import sample_with_rmh
from sbijax._src.mcmc.slice import sample_with_slice

__all__ = [
"mcmc_diagnostics",
"sample_with_slice",
"sample_with_nuts",
"sample_with_mala",
"sample_with_rmh",
"sample_with_imh",
]
4 changes: 2 additions & 2 deletions sbijax/_src/mcmc/irmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_imh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
r"""Draw samples using the indepdendent Metropolis-Hastings sampler.

Args:
rng_seq: a hk.PRNGSequence
rng_key: a jax random key
lp: the logdensity you wish to sample from
prior: a function that returns a prior sample
n_chains: number of chains to sample
Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_mala(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_nuts(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/mcmc/rmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_rmh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
Expand Down
4 changes: 2 additions & 2 deletions sbijax/_src/mcmc/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_slice(
rng_key,
lp,
Expand All @@ -20,7 +20,7 @@ def sample_with_slice(
r"""Sample from a distribution using the No-U-Turn sampler.

Args:
rng_seq: a hk.PRNGSequence
rng_key: a jax random key
lp: the logdensity you wish to sample from
prior: a function that returns a prior sample
n_chains: number of chains to sample
Expand Down
2 changes: 2 additions & 0 deletions sbijax/_src/nn/consistency_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
__all__ = ["ConsistencyModel", "make_consistency_model"]


# ruff: noqa: PLR0913,D417
class ConsistencyModel(hk.Module):
"""A consistency model.

Expand Down Expand Up @@ -161,6 +162,7 @@ def _c_out(self, time):
)


# ruff: noqa: PLR0913
def make_consistency_model(
n_dimension: int,
n_layers: int = 2,
Expand Down
2 changes: 2 additions & 0 deletions sbijax/_src/nn/continuous_normalizing_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
__all__ = ["CCNF", "make_ccnf"]


# ruff: noqa: PLR0913,D417
class CCNF(hk.Module):
"""Conditional continuous normalizing flow.

Expand Down Expand Up @@ -180,6 +181,7 @@ def __call__(self, theta, time, context, is_training=False, **kwargs):
return outputs


# ruff: noqa: PLR0913
def make_ccnf(
n_dimension: int,
n_layers: int = 2,
Expand Down
6 changes: 3 additions & 3 deletions sbijax/_src/nn/make_flows.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import distrax
import haiku as hk
Expand All @@ -11,8 +12,7 @@
Permutation,
TransformedDistribution,
)
from surjectors._src.conditioners.mlp import make_mlp
from surjectors._src.conditioners.nn.made import MADE
from surjectors.nn import MADE, make_mlp
from surjectors.util import unstack


Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/nn/make_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __call__(self, inputs, is_training=False):
return outputs + inputs


# pylint: disable=too-many-arguments
# ruff: noqa: PLR0913
class _Resnet(hk.Module):
"""A simplified 1-d residual network."""

Expand Down
3 changes: 2 additions & 1 deletion sbijax/_src/nn/make_snass_networks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import haiku as hk
import jax
Expand Down
6 changes: 3 additions & 3 deletions sbijax/_src/nn/snass_net.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import haiku as hk
import jax
from jax import numpy as jnp


# pylint: disable=missing-function-docstring,missing-class-docstring
# pydocstyle: disable=D102
# ruff: noqa: PLR0913,S101
class SNASSNet(hk.Module):
"""A network for SNASS."""

Expand Down
6 changes: 3 additions & 3 deletions sbijax/_src/nn/snasss_net.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import haiku as hk
import jax
Expand All @@ -7,8 +8,7 @@
from sbijax._src.nn.snass_net import SNASSNet


# pylint: disable=missing-function-docstring,missing-class-docstring
# pylint: disable=too-many-arguments
# ruff: noqa: PLR0913,S101
class SNASSSNet(SNASSNet):
"""A network for SNASSS."""

Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/scmpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _consistency_loss(
return jnp.mean(loss)


# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
# ruff: noqa: PLR0913
class SCMPE(SFMPE):
r"""Sequential consistency model posterior estimation.

Expand Down
5 changes: 3 additions & 2 deletions sbijax/_src/sfmpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _cfm_loss(
return loss


# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
# ruff: noqa: PLR0913
class SFMPE(SNE):
r"""Sequential flow matching posterior estimation.

Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(self, model_fns, density_estimator):
"""
super().__init__(model_fns, density_estimator)

# pylint: disable=arguments-differ,too-many-locals
# ruff: noqa: D417
def fit(
self,
rng_key,
Expand Down Expand Up @@ -233,6 +233,7 @@ def body_fn(batch_key, **batch):
loss += body_fn(val_key, **batch)
return loss

# ruff: noqa: D417
def sample_posterior(
self, rng_key, params, observable, *, n_samples=4_000, **kwargs
):
Expand Down
Loading
Loading