diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ac38ca37..24d0e22d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.14 + rev: v0.15.4 hooks: - id: ruff-check args: [ --fix ] diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index a197d151..e956dc0c 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -327,13 +327,10 @@ def _sample_zero(n_photons_data): 1.0, rng, ), - lambda flux, - eta_factor, - max_sb, - poisson_flux, - max_extra_noise, - rng: _calculate_n_photons_flux_nonzero( - flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng + lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: ( + _calculate_n_photons_flux_nonzero( + flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng + ) ), n_photons_data.flux, n_photons_data.flux_per_photon, diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index ddb4e3bb..51325ea3 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -374,8 +374,9 @@ def _applyTo(self, image): def _getVariance(self): return jax.lax.cond( self.gain > 0.0, - lambda gain, sky_level, read_noise: sky_level / gain - + (read_noise / gain) ** 2, + lambda gain, sky_level, read_noise: ( + sky_level / gain + (read_noise / gain) ** 2 + ), lambda gain, sky_level, read_noise: read_noise**2, self.gain, self.sky_level, diff --git a/tests/jax/test_interpolant_jax.py b/tests/jax/test_interpolant_jax.py index 2be425c1..45d328fb 100644 --- a/tests/jax/test_interpolant_jax.py +++ b/tests/jax/test_interpolant_jax.py @@ -141,12 +141,14 @@ def _timeit(lz, ntest=10, jit=False, dox=False): ] + [galsim.Lanczos(i, conserve_dc=False) for i in range(1, 31)] + [galsim.Lanczos(i, conserve_dc=True) for i in range(1, 31)], - ids=lambda x: str(x) - .replace("galsim.", "") - .replace("(", "") - .replace(")", "") - .replace(", ", "-") - + ("" if not isinstance(x, galsim.Lanczos) else f"-{x.conserve_dc}"), + ids=lambda x: ( + str(x) + .replace("galsim.", "") + .replace("(", "") + .replace(")", "") + .replace(", ", "-") + + ("" if not isinstance(x, galsim.Lanczos) else f"-{x.conserve_dc}") + ), ) @pytest.mark.parametrize("kind", ["fluxes", "ranges", "xval", "kval"]) def test_interpolant_jax_same_as_galsim(interp, kind):