Skip to content

Commit

Permalink
Move to ruff (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Feb 29, 2024
1 parent 513e8fd commit 5851493
Show file tree
Hide file tree
Showing 27 changed files with 95 additions and 118 deletions.
38 changes: 5 additions & 33 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,6 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.29.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
files: "(sbijax|examples)"

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--settings-path=pyproject.toml"]
files: "(sbijax|examples)"

- repo: https://github.com/pycqa/bandit
rev: 1.7.1
hooks:
Expand All @@ -44,24 +24,16 @@ repos:
additional_dependencies: ["toml"]
files: "(sbijax|examples)"

- repo: https://github.com/PyCQA/flake8
rev: 5.0.1
hooks:
- id: flake8
additional_dependencies: [
flake8-typing-imports==1.14.0,
flake8-pyproject==1.1.0.post0
]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1
hooks:
- id: mypy
args: ["--ignore-missing-imports"]
files: "(sbijax|examples)"

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
- id: ruff
args: [ --fix ]
- id: ruff-format
53 changes: 13 additions & 40 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"
homepage = "https://github.com/dirmeier/sbijax"
keywords = ["abc", "simulation-based inference", "approximate Bayesian computation", "normalizing flows", "smc-abc"]
classifiers = [
"Development Status :: 1 - Planning",
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -57,54 +57,27 @@ dependencies = [

[tool.hatch.envs.test]
dependencies = [
"pylint>=2.15.10",
"ruff>=0.3.0",
"pytest>=7.2.0",
"pytest-cov>=4.0.0"
]

[tool.hatch.envs.test.scripts]
lint = 'pylint sbijax'
lint = 'ruff check sbijax'
test = 'pytest -v --cov=./sbijax --cov-report=xml sbijax'

[tool.black]
line-length = 80
target-version = ['py311']
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[tool.bandit]
skips = ["B101"]

[tool.isort]
profile = "black"
line_length = 80
include_trailing_comma = true
[tool.ruff]
line-length = 80
exclude = ["*_test.py", "docs/**", "examples/**"]

[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503", "E731", "E231"]
per-file-ignores = [
'__init__.py:F401',
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F"]
extend-select = [
"UP", "D", "I", "PL", "S"
]

[tool.pylint.messages_control]
disable = """
invalid-name,missing-module-docstring,R0801,E0633
"""

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention= 'google'
match = '^sbijax/.*/((?!_test).)*\.py'
7 changes: 3 additions & 4 deletions sbijax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""
sbijax: Simulation-based inference in JAX
"""
"""sbijax: Simulation-based inference in JAX."""

__version__ = "0.2.0"


from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.scmpe import SCMPE
from sbijax._src.sfmpe import SFMPE
Expand All @@ -13,3 +10,5 @@
from sbijax._src.snl import SNL
from sbijax._src.snp import SNP
from sbijax._src.snr import SNR

__all__ = ["SMCABC", "SCMPE", "SFMPE", "SNASS", "SNASSS", "SNL", "SNP", "SNR"]
3 changes: 1 addition & 2 deletions sbijax/_src/_sne_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from sbijax._src.util.dataloader import as_batch_iterators, named_dataset


# pylint: disable=too-many-arguments,unused-argument
# pylint: disable=too-many-function-args,arguments-differ
# ruff: noqa: PLR0913
class SNE(SBI, ABC):
"""Sequential neural estimation base class."""

Expand Down
7 changes: 3 additions & 4 deletions sbijax/_src/abc/smc_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from sbijax._src._sbi_base import SBI


# pylint: disable=arguments-differ,too-many-function-args,too-many-locals
# pylint: disable=too-few-public-methods
# ruff: noqa: PLR0913
class SMCABC(SBI):
"""Sequential Monte Carlo approximate Bayesian computation.
Expand All @@ -38,7 +37,6 @@ def __init__(self, model_fns, summary_fn, distance_fn):
self.summarized_observed: chex.Array
self.n_total_simulations = 0

# pylint: disable=too-many-arguments,arguments-differ
def sample_posterior(
self,
rng_key,
Expand All @@ -53,9 +51,10 @@ def sample_posterior(
r"""Sample from the approximate posterior.
Args:
rng_key: a jax random
n_rounds: max number of SMC rounds
observable: the observation to condition on
n_round: number of rounds of SMC
n_rounds: number of rounds of SMC
n_particles: number of n_particles to draw for each parameter
n_simulations_per_theta: number of simulations for each paramrter
sample
Expand Down
9 changes: 9 additions & 0 deletions sbijax/_src/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,12 @@
from sbijax._src.mcmc.nuts import sample_with_nuts
from sbijax._src.mcmc.rmh import sample_with_rmh
from sbijax._src.mcmc.slice import sample_with_slice

__all__ = [
"mcmc_diagnostics",
"sample_with_slice",
"sample_with_nuts",
"sample_with_mala",
"sample_with_rmh",
"sample_with_imh",
]
4 changes: 2 additions & 2 deletions sbijax/_src/mcmc/irmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_imh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
r"""Draw samples using the indepdendent Metropolis-Hastings sampler.
Args:
rng_seq: a hk.PRNGSequence
rng_key: a jax random key
lp: the logdensity you wish to sample from
prior: a function that returns a prior sample
n_chains: number of chains to sample
Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_mala(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_nuts(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/mcmc/rmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_rmh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
Expand Down
4 changes: 2 additions & 2 deletions sbijax/_src/mcmc/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
# ruff: noqa: PLR0913,D417
def sample_with_slice(
rng_key,
lp,
Expand All @@ -20,7 +20,7 @@ def sample_with_slice(
r"""Sample from a distribution using the No-U-Turn sampler.
Args:
rng_seq: a hk.PRNGSequence
rng_key: a jax random key
lp: the logdensity you wish to sample from
prior: a function that returns a prior sample
n_chains: number of chains to sample
Expand Down
2 changes: 2 additions & 0 deletions sbijax/_src/nn/consistency_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
__all__ = ["ConsistencyModel", "make_consistency_model"]


# ruff: noqa: PLR0913,D417
class ConsistencyModel(hk.Module):
"""A consistency model.
Expand Down Expand Up @@ -161,6 +162,7 @@ def _c_out(self, time):
)


# ruff: noqa: PLR0913
def make_consistency_model(
n_dimension: int,
n_layers: int = 2,
Expand Down
2 changes: 2 additions & 0 deletions sbijax/_src/nn/continuous_normalizing_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
__all__ = ["CCNF", "make_ccnf"]


# ruff: noqa: PLR0913,D417
class CCNF(hk.Module):
"""Conditional continuous normalizing flow.
Expand Down Expand Up @@ -180,6 +181,7 @@ def __call__(self, theta, time, context, is_training=False, **kwargs):
return outputs


# ruff: noqa: PLR0913
def make_ccnf(
n_dimension: int,
n_layers: int = 2,
Expand Down
6 changes: 3 additions & 3 deletions sbijax/_src/nn/make_flows.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import distrax
import haiku as hk
Expand All @@ -11,8 +12,7 @@
Permutation,
TransformedDistribution,
)
from surjectors._src.conditioners.mlp import make_mlp
from surjectors._src.conditioners.nn.made import MADE
from surjectors.nn import MADE, make_mlp
from surjectors.util import unstack


Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/nn/make_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __call__(self, inputs, is_training=False):
return outputs + inputs


# pylint: disable=too-many-arguments
# ruff: noqa: PLR0913
class _Resnet(hk.Module):
"""A simplified 1-d residual network."""

Expand Down
3 changes: 2 additions & 1 deletion sbijax/_src/nn/make_snass_networks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import haiku as hk
import jax
Expand Down
6 changes: 3 additions & 3 deletions sbijax/_src/nn/snass_net.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import haiku as hk
import jax
from jax import numpy as jnp


# pylint: disable=missing-function-docstring,missing-class-docstring
# pydocstyle: disable=D102
# ruff: noqa: PLR0913,S101
class SNASSNet(hk.Module):
"""A network for SNASS."""

Expand Down
6 changes: 3 additions & 3 deletions sbijax/_src/nn/snasss_net.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Iterable
from collections.abc import Iterable
from typing import Callable

import haiku as hk
import jax
Expand All @@ -7,8 +8,7 @@
from sbijax._src.nn.snass_net import SNASSNet


# pylint: disable=missing-function-docstring,missing-class-docstring
# pylint: disable=too-many-arguments
# ruff: noqa: PLR0913,S101
class SNASSSNet(SNASSNet):
"""A network for SNASSS."""

Expand Down
2 changes: 1 addition & 1 deletion sbijax/_src/scmpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _consistency_loss(
return jnp.mean(loss)


# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
# ruff: noqa: PLR0913
class SCMPE(SFMPE):
r"""Sequential consistency model posterior estimation.
Expand Down
5 changes: 3 additions & 2 deletions sbijax/_src/sfmpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _cfm_loss(
return loss


# pylint: disable=too-many-arguments,unused-argument,useless-parent-delegation
# ruff: noqa: PLR0913
class SFMPE(SNE):
r"""Sequential flow matching posterior estimation.
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(self, model_fns, density_estimator):
"""
super().__init__(model_fns, density_estimator)

# pylint: disable=arguments-differ,too-many-locals
# ruff: noqa: D417
def fit(
self,
rng_key,
Expand Down Expand Up @@ -233,6 +233,7 @@ def body_fn(batch_key, **batch):
loss += body_fn(val_key, **batch)
return loss

# ruff: noqa: D417
def sample_posterior(
self, rng_key, params, observable, *, n_samples=4_000, **kwargs
):
Expand Down
Loading

0 comments on commit 5851493

Please sign in to comment.