diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..0162acc1 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,35 @@ +name: Documentation + +on: + push: + branches: [main] + pull_request: + branches: [main] + +permissions: + contents: write + +concurrency: + group: docs-${{ github.ref }} + cancel-in-progress: true + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Install dependencies + run: pip install -e ".[docs]" + + - name: Build documentation + if: github.event_name == 'pull_request' + run: mkdocs build --strict + + - name: Deploy documentation + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: mkdocs gh-deploy --force diff --git a/docs/api-coverage.md b/docs/api-coverage.md new file mode 100644 index 00000000..702cb085 --- /dev/null +++ b/docs/api-coverage.md @@ -0,0 +1,128 @@ +# API Coverage + +JAX-GalSim has implemented **22.5%** of the GalSim API. The project focuses on +the most commonly used profiles and operations, with coverage expanding over time. + +## Supported APIs + +??? note "Click to expand the full list of implemented APIs" + + - `galsim.Add` + - `galsim.AffineTransform` + - `galsim.Angle` + - `galsim.AngleUnit` + - `galsim.BaseDeviate` + - `galsim.BaseNoise` + - `galsim.BaseWCS` + - `galsim.BinomialDeviate` + - `galsim.Bounds` + - `galsim.BoundsD` + - `galsim.BoundsI` + - `galsim.Box` + - `galsim.CCDNoise` + - `galsim.CelestialCoord` + - `galsim.Chi2Deviate` + - `galsim.Convolution` + - `galsim.Convolve` + - `galsim.Cubic` + - `galsim.Deconvolution` + - `galsim.Deconvolve` + - `galsim.Delta` + - `galsim.DeltaFunction` + - `galsim.DeviateNoise` + - `galsim.Exponential` + - `galsim.FitsHeader` + - `galsim.FitsWCS` + - `galsim.GSFitsWCS` + - `galsim.GSObject` + - `galsim.GSParams` + - `galsim.GalSimBoundsError` + - `galsim.GalSimConfigError` + - `galsim.GalSimConfigValueError` + - `galsim.GalSimDeprecationWarning` + - `galsim.GalSimError` + - `galsim.GalSimFFTSizeError` + - `galsim.GalSimHSMError` + - `galsim.GalSimImmutableError` + - `galsim.GalSimIncompatibleValuesError` + - `galsim.GalSimIndexError` + - `galsim.GalSimKeyError` + - `galsim.GalSimNotImplementedError` + - `galsim.GalSimRangeError` + - `galsim.GalSimSEDError` + - `galsim.GalSimUndefinedBoundsError` + - `galsim.GalSimValueError` + - `galsim.GalSimWarning` + - `galsim.GammaDeviate` + - `galsim.Gaussian` + - `galsim.GaussianDeviate` + - `galsim.GaussianNoise` + - `galsim.Image` + - `galsim.ImageCD` + - `galsim.ImageCF` + - `galsim.ImageD` + - `galsim.ImageF` + - `galsim.ImageI` + - `galsim.ImageS` + - `galsim.ImageUI` + - `galsim.ImageUS` + - `galsim.Interpolant` + - `galsim.InterpolatedImage` + - `galsim.JacobianWCS` + - `galsim.Lanczos` + - `galsim.Linear` + - `galsim.Moffat` + - `galsim.Nearest` + - `galsim.OffsetShearWCS` + - `galsim.OffsetWCS` + - `galsim.PhotonArray` + - `galsim.Pixel` + - `galsim.PixelScale` + - `galsim.PoissonDeviate` + - `galsim.PoissonNoise` + - `galsim.Position` + - `galsim.PositionD` + - `galsim.PositionI` + - `galsim.Quintic` + - `galsim.Sensor` + - `galsim.Shear` + - `galsim.ShearWCS` + - `galsim.SincInterpolant` + - `galsim.Spergel` + - `galsim.Sum` + - `galsim.TanWCS` + - `galsim.Transform` + - `galsim.Transformation` + - `galsim.UniformDeviate` + - `galsim.VariableGaussianNoise` + - `galsim.WeibullDeviate` + - `galsim.bessel.j0` + - `galsim.bessel.kv` + - `galsim.bessel.si` + - `galsim.fits.closeHDUList` + - `galsim.fits.readCube` + - `galsim.fits.readFile` + - `galsim.fits.readMulti` + - `galsim.fits.write` + - `galsim.fits.writeFile` + - `galsim.fitswcs.CelestialWCS` + - `galsim.integ.int1d` + - `galsim.noise.addNoise` + - `galsim.noise.addNoiseSNR` + - `galsim.random.permute` + - `galsim.utilities.g1g2_to_e1e2` + - `galsim.utilities.horner` + - `galsim.utilities.printoptions` + - `galsim.utilities.unweighted_moments` + - `galsim.utilities.unweighted_shape` + - `galsim.wcs.EuclideanWCS` + - `galsim.wcs.LocalWCS` + - `galsim.wcs.UniformWCS` + +## Updating Coverage + +```bash +python scripts/update_api_coverage.py +``` + +Compares GalSim's public API against `jax_galsim`'s implementations and updates the coverage percentage and list in `README.md`. diff --git a/docs/api/composition/convolve.md b/docs/api/composition/convolve.md new file mode 100644 index 00000000..edfe1c3e --- /dev/null +++ b/docs/api/composition/convolve.md @@ -0,0 +1,7 @@ +# Convolution & Deconvolution + +Convolve profiles together (e.g., galaxy with PSF) or deconvolve. + +::: jax_galsim.convolve.Convolution + +::: jax_galsim.convolve.Deconvolution diff --git a/docs/api/composition/sum.md b/docs/api/composition/sum.md new file mode 100644 index 00000000..5106d845 --- /dev/null +++ b/docs/api/composition/sum.md @@ -0,0 +1,5 @@ +# Sum (Add) + +Add surface brightness profiles together. + +::: jax_galsim.sum.Sum diff --git a/docs/api/composition/transform.md b/docs/api/composition/transform.md new file mode 100644 index 00000000..51b89e14 --- /dev/null +++ b/docs/api/composition/transform.md @@ -0,0 +1,5 @@ +# Transform & Transformation + +Affine transformations of surface brightness profiles (shear, shift, rotation, flux scaling). + +::: jax_galsim.transform.Transformation diff --git a/docs/api/config/errors.md b/docs/api/config/errors.md new file mode 100644 index 00000000..3105921d --- /dev/null +++ b/docs/api/config/errors.md @@ -0,0 +1,5 @@ +# Errors & Warnings + +Exception and warning classes for JAX-GalSim error handling. + +::: jax_galsim.errors diff --git a/docs/api/config/gsparams.md b/docs/api/config/gsparams.md new file mode 100644 index 00000000..8ea867b5 --- /dev/null +++ b/docs/api/config/gsparams.md @@ -0,0 +1,5 @@ +# GSParams + +Numerical configuration parameters controlling accuracy and performance trade-offs. + +::: jax_galsim.gsparams.GSParams diff --git a/docs/api/config/utilities.md b/docs/api/config/utilities.md new file mode 100644 index 00000000..009851cf --- /dev/null +++ b/docs/api/config/utilities.md @@ -0,0 +1,5 @@ +# Utilities + +General utility functions. + +::: jax_galsim.utilities diff --git a/docs/api/coordinates/angle.md b/docs/api/coordinates/angle.md new file mode 100644 index 00000000..b49d6e74 --- /dev/null +++ b/docs/api/coordinates/angle.md @@ -0,0 +1,7 @@ +# Angle & AngleUnit + +Angle representation and unit conversion (radians, degrees, arcminutes, arcseconds, hours). + +::: jax_galsim.angle.Angle + +::: jax_galsim.angle.AngleUnit diff --git a/docs/api/coordinates/bounds.md b/docs/api/coordinates/bounds.md new file mode 100644 index 00000000..8bd03672 --- /dev/null +++ b/docs/api/coordinates/bounds.md @@ -0,0 +1,7 @@ +# Bounds + +Rectangular bounding box types for real-valued (`BoundsD`) and integer (`BoundsI`) coordinates. + +::: jax_galsim.bounds.BoundsD + +::: jax_galsim.bounds.BoundsI diff --git a/docs/api/coordinates/celestial.md b/docs/api/coordinates/celestial.md new file mode 100644 index 00000000..610ca0ad --- /dev/null +++ b/docs/api/coordinates/celestial.md @@ -0,0 +1,5 @@ +# CelestialCoord + +Celestial coordinate (RA, Dec) representation and operations. + +::: jax_galsim.celestial.CelestialCoord diff --git a/docs/api/coordinates/position.md b/docs/api/coordinates/position.md new file mode 100644 index 00000000..28510902 --- /dev/null +++ b/docs/api/coordinates/position.md @@ -0,0 +1,7 @@ +# Position + +2D position types for real-valued (`PositionD`) and integer (`PositionI`) coordinates. + +::: jax_galsim.position.PositionD + +::: jax_galsim.position.PositionI diff --git a/docs/api/coordinates/shear.md b/docs/api/coordinates/shear.md new file mode 100644 index 00000000..75f9442a --- /dev/null +++ b/docs/api/coordinates/shear.md @@ -0,0 +1,5 @@ +# Shear + +Gravitational shear representation with multiple parametrizations (g1/g2, e1/e2, eta1/eta2). + +::: jax_galsim.shear.Shear diff --git a/docs/api/core/draw.md b/docs/api/core/draw.md new file mode 100644 index 00000000..2897eb20 --- /dev/null +++ b/docs/api/core/draw.md @@ -0,0 +1,5 @@ +# Core: Drawing + +Internal drawing utilities for rendering profiles to pixel grids. + +::: jax_galsim.core.draw diff --git a/docs/api/core/interpolate.md b/docs/api/core/interpolate.md new file mode 100644 index 00000000..e400e8db --- /dev/null +++ b/docs/api/core/interpolate.md @@ -0,0 +1,5 @@ +# Core: Interpolation + +Internal interpolation utilities (Akima splines, coefficient computation). + +::: jax_galsim.core.interpolate diff --git a/docs/api/core/math.md b/docs/api/core/math.md new file mode 100644 index 00000000..a2188f04 --- /dev/null +++ b/docs/api/core/math.md @@ -0,0 +1,5 @@ +# Core: Math + +Gradient-safe mathematical utilities (`safe_sqrt`, etc.). + +::: jax_galsim.core.math diff --git a/docs/api/core/utils.md b/docs/api/core/utils.md new file mode 100644 index 00000000..6d2a4ff9 --- /dev/null +++ b/docs/api/core/utils.md @@ -0,0 +1,5 @@ +# Core: Utilities + +Core utilities: `@implements` decorator, `has_tracers()`, type casting helpers. + +::: jax_galsim.core.utils diff --git a/docs/api/image.md b/docs/api/image.md new file mode 100644 index 00000000..c965b13d --- /dev/null +++ b/docs/api/image.md @@ -0,0 +1,5 @@ +# Image + +Immutable JAX array wrapper with WCS and bounds metadata. + +::: jax_galsim.image.Image diff --git a/docs/api/interpolation/interpolant.md b/docs/api/interpolation/interpolant.md new file mode 100644 index 00000000..23bcd68d --- /dev/null +++ b/docs/api/interpolation/interpolant.md @@ -0,0 +1,17 @@ +# Interpolants + +Interpolation kernels for image resampling. + +::: jax_galsim.interpolant.Interpolant + +::: jax_galsim.interpolant.Nearest + +::: jax_galsim.interpolant.SincInterpolant + +::: jax_galsim.interpolant.Linear + +::: jax_galsim.interpolant.Cubic + +::: jax_galsim.interpolant.Quintic + +::: jax_galsim.interpolant.Lanczos diff --git a/docs/api/interpolation/interpolatedimage.md b/docs/api/interpolation/interpolatedimage.md new file mode 100644 index 00000000..0cb4b52e --- /dev/null +++ b/docs/api/interpolation/interpolatedimage.md @@ -0,0 +1,5 @@ +# InterpolatedImage + +Surface brightness profile defined by interpolation over a given image. + +::: jax_galsim.interpolatedimage.InterpolatedImage diff --git a/docs/api/math/bessel.md b/docs/api/math/bessel.md new file mode 100644 index 00000000..8b59134f --- /dev/null +++ b/docs/api/math/bessel.md @@ -0,0 +1,5 @@ +# Bessel Functions + +Bessel and related special functions. + +::: jax_galsim.bessel diff --git a/docs/api/math/integ.md b/docs/api/math/integ.md new file mode 100644 index 00000000..50278f00 --- /dev/null +++ b/docs/api/math/integ.md @@ -0,0 +1,5 @@ +# Integration + +Numerical integration utilities. + +::: jax_galsim.integ diff --git a/docs/api/noise/noise.md b/docs/api/noise/noise.md new file mode 100644 index 00000000..d15ad5fb --- /dev/null +++ b/docs/api/noise/noise.md @@ -0,0 +1,15 @@ +# Noise Models + +Noise classes for adding realistic noise to images. + +::: jax_galsim.noise.BaseNoise + +::: jax_galsim.noise.GaussianNoise + +::: jax_galsim.noise.PoissonNoise + +::: jax_galsim.noise.CCDNoise + +::: jax_galsim.noise.DeviateNoise + +::: jax_galsim.noise.VariableGaussianNoise diff --git a/docs/api/noise/random.md b/docs/api/noise/random.md new file mode 100644 index 00000000..88042531 --- /dev/null +++ b/docs/api/noise/random.md @@ -0,0 +1,19 @@ +# Random Deviates + +Random number generators. + +::: jax_galsim.random.BaseDeviate + +::: jax_galsim.random.UniformDeviate + +::: jax_galsim.random.GaussianDeviate + +::: jax_galsim.random.PoissonDeviate + +::: jax_galsim.random.Chi2Deviate + +::: jax_galsim.random.GammaDeviate + +::: jax_galsim.random.WeibullDeviate + +::: jax_galsim.random.BinomialDeviate diff --git a/docs/api/photons/photon_array.md b/docs/api/photons/photon_array.md new file mode 100644 index 00000000..30d34dfc --- /dev/null +++ b/docs/api/photons/photon_array.md @@ -0,0 +1,5 @@ +# PhotonArray + +Array of photon positions, fluxes, and other properties for photon shooting. + +::: jax_galsim.photon_array.PhotonArray diff --git a/docs/api/photons/sensor.md b/docs/api/photons/sensor.md new file mode 100644 index 00000000..6c45d17d --- /dev/null +++ b/docs/api/photons/sensor.md @@ -0,0 +1,5 @@ +# Sensor + +Sensor model for converting photons to pixel counts. + +::: jax_galsim.sensor.Sensor diff --git a/docs/api/profiles/box.md b/docs/api/profiles/box.md new file mode 100644 index 00000000..cecaa943 --- /dev/null +++ b/docs/api/profiles/box.md @@ -0,0 +1,7 @@ +# Box & Pixel + +Box (uniform rectangular) and Pixel (unit-width box) surface brightness profiles. + +::: jax_galsim.box.Box + +::: jax_galsim.box.Pixel diff --git a/docs/api/profiles/deltafunction.md b/docs/api/profiles/deltafunction.md new file mode 100644 index 00000000..69ce9f05 --- /dev/null +++ b/docs/api/profiles/deltafunction.md @@ -0,0 +1,5 @@ +# DeltaFunction + +Delta function (point source) surface brightness profile. + +::: jax_galsim.deltafunction.DeltaFunction diff --git a/docs/api/profiles/exponential.md b/docs/api/profiles/exponential.md new file mode 100644 index 00000000..9e93cdde --- /dev/null +++ b/docs/api/profiles/exponential.md @@ -0,0 +1,5 @@ +# Exponential + +Exponential surface brightness profile, commonly used for galaxy disk components. + +::: jax_galsim.exponential.Exponential diff --git a/docs/api/profiles/gaussian.md b/docs/api/profiles/gaussian.md new file mode 100644 index 00000000..99360444 --- /dev/null +++ b/docs/api/profiles/gaussian.md @@ -0,0 +1,5 @@ +# Gaussian + +Circular or elliptical Gaussian surface brightness profile. + +::: jax_galsim.gaussian.Gaussian diff --git a/docs/api/profiles/gsobject.md b/docs/api/profiles/gsobject.md new file mode 100644 index 00000000..9c3cdda9 --- /dev/null +++ b/docs/api/profiles/gsobject.md @@ -0,0 +1,5 @@ +# GSObject + +Base class for all surface brightness profiles. + +::: jax_galsim.gsobject.GSObject diff --git a/docs/api/profiles/moffat.md b/docs/api/profiles/moffat.md new file mode 100644 index 00000000..b7c1f34d --- /dev/null +++ b/docs/api/profiles/moffat.md @@ -0,0 +1,5 @@ +# Moffat + +Moffat surface brightness profile, commonly used for PSF modeling. + +::: jax_galsim.moffat.Moffat diff --git a/docs/api/profiles/spergel.md b/docs/api/profiles/spergel.md new file mode 100644 index 00000000..13c19314 --- /dev/null +++ b/docs/api/profiles/spergel.md @@ -0,0 +1,5 @@ +# Spergel + +Spergel surface brightness profile, a flexible model for galaxy light distributions. + +::: jax_galsim.spergel.Spergel diff --git a/docs/api/wcs/fits.md b/docs/api/wcs/fits.md new file mode 100644 index 00000000..07cb05d5 --- /dev/null +++ b/docs/api/wcs/fits.md @@ -0,0 +1,5 @@ +# FITS I/O + +FITS file reading, writing, and header handling. + +::: jax_galsim.fits diff --git a/docs/api/wcs/fitswcs.md b/docs/api/wcs/fitswcs.md new file mode 100644 index 00000000..b5721c38 --- /dev/null +++ b/docs/api/wcs/fitswcs.md @@ -0,0 +1,5 @@ +# FITS WCS + +FITS-based World Coordinate Systems. + +::: jax_galsim.fitswcs.GSFitsWCS diff --git a/docs/api/wcs/wcs.md b/docs/api/wcs/wcs.md new file mode 100644 index 00000000..b9ecf8e5 --- /dev/null +++ b/docs/api/wcs/wcs.md @@ -0,0 +1,17 @@ +# WCS Base Classes + +World Coordinate System hierarchy. + +::: jax_galsim.wcs.BaseWCS + +::: jax_galsim.wcs.PixelScale + +::: jax_galsim.wcs.ShearWCS + +::: jax_galsim.wcs.JacobianWCS + +::: jax_galsim.wcs.OffsetWCS + +::: jax_galsim.wcs.OffsetShearWCS + +::: jax_galsim.wcs.AffineTransform diff --git a/docs/getting-started/index.md b/docs/getting-started/index.md new file mode 100644 index 00000000..062d9480 --- /dev/null +++ b/docs/getting-started/index.md @@ -0,0 +1,7 @@ +# Getting Started + +New to JAX-GalSim? Start here. + +- [Installation](installation.md) — Install JAX-GalSim and set up GPU support +- [Quick Start](quickstart.md) — Simulate a galaxy image in a few lines of code +- [Notable Differences](../notable-differences.md) — What changes when GalSim runs on JAX diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md new file mode 100644 index 00000000..88142dd2 --- /dev/null +++ b/docs/getting-started/installation.md @@ -0,0 +1,69 @@ +# Installation + +## Quick Install + +```bash +pip install jax-galsim +``` + +This installs JAX-GalSim and its dependencies (JAX, NumPy, GalSim, Astropy). + +## GPU Support + +JAX-GalSim inherits GPU support from JAX. To use NVIDIA GPUs, install the appropriate JAX variant: + +```bash +pip install -U "jax[cuda12]" +``` + +See the [JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html) for other accelerators and platform-specific instructions. + +## Development Install + +To contribute to JAX-GalSim or run the test suite: + +```bash +# Clone with submodules (required for GalSim reference tests) +git clone --recurse-submodules https://github.com/GalSim-developers/JAX-GalSim +cd JAX-GalSim + +# Create a virtual environment +python -m venv .venv && source .venv/bin/activate + +# Install in editable mode with dev dependencies +pip install -e ".[dev]" + +# Install pre-commit hooks +pre-commit install +``` + +### Running Tests + +```bash +# Run all tests +pytest + +# Run a specific test file +pytest tests/jax/test_api.py + +# Run a specific test +pytest tests/jax/test_api.py::test_api_same + +# Verbose output with timing +pytest -vv --durations=100 +``` + +### Linting and Formatting + +```bash +# Lint +ruff check . --fix + +# Format +ruff format . + +# Or run both via pre-commit +pre-commit run --all-files +``` + +See [CONTRIBUTING.md](https://github.com/GalSim-developers/JAX-GalSim/blob/main/CONTRIBUTING.md) for full contribution guidelines. diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md new file mode 100644 index 00000000..d51f3f60 --- /dev/null +++ b/docs/getting-started/quickstart.md @@ -0,0 +1,98 @@ +# Quick Start + +A complete galaxy image simulation, then JAX transformations (`jit`, `grad`, `vmap`) on top. + +## A Simple Simulation + +A Gaussian galaxy convolved with a Gaussian PSF, drawn and noised -- equivalent to GalSim's `demo1.py`. + +```python +import jax_galsim + +# Galaxy parameters +gal_flux = 1e5 # total counts +gal_sigma = 2.0 # arcsec +psf_sigma = 1.0 # arcsec +pixel_scale = 0.2 # arcsec/pixel +noise_sigma = 30.0 # counts per pixel + +# Define profiles +gal = jax_galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) +psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) + +# Convolve galaxy with PSF +final = jax_galsim.Convolve([gal, psf]) + +# Draw the image +image = final.drawImage(scale=pixel_scale) + +# Add Gaussian noise +image = image.addNoise(jax_galsim.GaussianNoise(sigma=noise_sigma)) + +# Write to FITS +image.write("output/demo1.fits") +``` + +Most GalSim code translates directly by replacing `import galsim` with `import jax_galsim`. + +## JIT Compilation + +Wrap your simulation in `jax.jit` to compile it into an optimized XLA computation: + +```python +import jax + +@jax.jit +def simulate(flux, sigma): + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]) + return final.drawImage(scale=0.2) + +# First call compiles; subsequent calls are fast +image = simulate(1e5, 2.0) +``` + +## Automatic Differentiation + +Compute gradients of any scalar output with respect to parameters: + +```python +def total_flux(gal_sigma, psf_sigma): + gal = jax_galsim.Gaussian(flux=1e5, sigma=gal_sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) + final = jax_galsim.Convolve([gal, psf]) + image = final.drawImage(scale=0.2) + return image.array.sum() + +# Gradient of total image flux with respect to both sigmas +grad_fn = jax.grad(total_flux, argnums=(0, 1)) +d_gal, d_psf = grad_fn(2.0, 1.0) +``` + +Useful for fitting galaxy models to data via gradient-based optimization. + +## Vectorization with vmap + +Batch-simulate galaxies with different parameters without explicit loops: + +```python +import jax.numpy as jnp + +sigmas = jnp.linspace(1.0, 4.0, 10) + +@jax.vmap +def batch_simulate(sigma): + gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]) + return final.drawImage(scale=0.2, nx=64, ny=64).array + +# Simulate all 10 galaxies in parallel +images = batch_simulate(sigmas) # shape: (10, 64, 64) +``` + +## Next Steps + +- [Notable Differences](../notable-differences.md) — What changes when GalSim runs on JAX +- [API Reference](../api/profiles/gaussian.md) — Full API documentation diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..6796a026 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,80 @@ +# JAX-GalSim + +**JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.** + +[![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main) + +!!! warning "Early Development" + + This project is still in an early development phase. Please use the + [reference GalSim implementation](https://github.com/GalSim-developers/GalSim) + for any scientific applications. + +--- + +## Why JAX-GalSim? + +JAX-GalSim reimplements [GalSim](https://github.com/GalSim-developers/GalSim) in pure JAX, unlocking: + +!!! tip "JIT Compilation" + + Compile simulation pipelines with `jax.jit` for significant speedups, especially on GPU. + +!!! tip "Automatic Differentiation" + + Compute gradients of simulation outputs with respect to galaxy parameters using `jax.grad`. + +!!! tip "Vectorization" + + Batch simulations over parameter grids with `jax.vmap` --- no explicit loops needed. + +--- + +## Quick Install + +```bash +pip install jax-galsim +``` + +## Minimal Example + +```python +import jax +import jax_galsim + +# Define a galaxy and PSF +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + +# Convolve and draw +final = jax_galsim.Convolve([gal, psf]) +image = final.drawImage(scale=0.2) + +# Add noise +image = image.addNoise(jax_galsim.GaussianNoise(sigma=30.0)) +``` + +JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the entire pipeline: + +```python +@jax.jit +def simulate(flux, sigma): + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + return jax_galsim.Convolve([gal, psf]).drawImage(scale=0.2).array.sum() + +# Compute gradients with respect to galaxy parameters +grad_fn = jax.grad(simulate, argnums=(0, 1)) +dflux, dsigma = grad_fn(1e5, 2.0) +``` + +--- + +## Next Steps + +- [Installation](getting-started/installation.md) --- Set up JAX-GalSim with GPU support +- [Quick Start](getting-started/quickstart.md) --- Walk through a complete simulation +- [Notable Differences](notable-differences.md) --- What changes when GalSim runs on JAX +- [API Reference](api/profiles/gaussian.md) --- Browse the full API diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js new file mode 100644 index 00000000..117b0460 --- /dev/null +++ b/docs/javascripts/mathjax.js @@ -0,0 +1,16 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true, + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex", + }, +}; + +document$.subscribe(() => { + MathJax.typesetPromise(); +}); diff --git a/docs/notable-differences.md b/docs/notable-differences.md new file mode 100644 index 00000000..8d8b38d6 --- /dev/null +++ b/docs/notable-differences.md @@ -0,0 +1,228 @@ +# Notable Differences from GalSim + +JAX-GalSim is designed as a drop-in replacement for GalSim --- replacing +`import galsim` with `import jax_galsim` works for all supported features. +However, JAX's execution model introduces several fundamental differences +that you should understand before porting code or writing new simulations. + +--- + +## Immutability + +JAX arrays are **immutable**. Any GalSim operation that modifies data in-place +returns a new object in JAX-GalSim instead. + +```python +# GalSim — mutates the image in-place +image.addNoise(noise) +image.array[10, 10] = 0.0 + +# JAX-GalSim — returns a new image each time +image = image.addNoise(noise) + +# Direct array element mutation is not supported. +# Use jax.numpy operations to produce a new array: +new_array = image.array.at[10, 10].set(0.0) +``` + +This is the most common change when porting GalSim code. Every call that +modifies an image, adds noise, or updates a value must capture the return value. +If you forget the assignment, the original object is unchanged and no error is +raised --- a subtle source of bugs. + +--- + +## Array Views + +NumPy supports **array views** --- slices that share memory with the original +array. JAX does not. In GalSim, you can obtain a real-valued view of a complex +image (e.g., the real part shares memory with the underlying complex buffer). +In JAX-GalSim, these operations return **copies** instead. Modifying the copy +does not affect the original. + +```python +# GalSim — real_part is a view, shares memory with complex_image +real_part = complex_image.real + +# JAX-GalSim — real_part is a copy +real_part = complex_image.real # independent array +``` + +--- + +## Random Number Generation + +JAX uses a **functional PRNG** --- random state is explicit and must be passed +through computations. This has several consequences: + +**Determinism**: Given the same seed, JAX-GalSim produces identical results +across runs and platforms (CPU, GPU, TPU). GalSim's results may vary by platform. + +**Explicit state**: Random deviates carry their state explicitly. Under the hood, +JAX-GalSim wraps JAX's key-based PRNG in GalSim's familiar noise API, so the +user-facing interface looks the same: + +```python +noise = jax_galsim.GaussianNoise(sigma=30.0) +image = image.addNoise(noise) # state is managed internally +``` + +**Different sequences**: Even with the same seed value, the actual random number +sequences differ from GalSim. Results will not match GalSim number-for-number. +This is expected --- the underlying PRNG algorithms are completely different. + +**No in-place fill**: GalSim deviates can "fill" existing arrays. JAX deviates +always return new arrays, consistent with JAX's immutability model. + +--- + +## PyTree Registration + +All JAX-GalSim objects are registered as JAX **PyTrees**. This is what allows +you to pass them directly to `jax.jit`, `jax.grad`, and `jax.vmap`. + +A PyTree splits each object into two parts: + +| Part | What it contains | Examples | Effect of changing | +|------|-----------------|----------|--------------------| +| **Children** (traced) | Values JAX differentiates through | `flux`, `sigma`, `half_light_radius` | Re-evaluation, not recompilation | +| **Auxiliary data** (static) | Structure and configuration | `GSParams`, enum flags | Full recompilation under `jit` | + +In practice, profile parameters live in a `_params` dict (children) and +numerical configuration lives in `_gsparams` (auxiliary): + +```python +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +# gal._params = {"flux": 1e5, "sigma": 2.0} — traced by JAX +# gal._gsparams = GSParams(...) — static, triggers recompile +``` + +Because `GSParams` is static auxiliary data, changing it between calls to a +`jit`-compiled function triggers recompilation. Keep `GSParams` constant across +calls when possible. + +```python +import jax + +gsparams = jax_galsim.GSParams(maximum_fft_size=8192) + +@jax.jit +def simulate(flux, sigma): + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma, gsparams=gsparams) + return gal.drawImage(scale=0.2).array.sum() + +# Changing gsparams here would cause recompilation on next call +``` + +--- + +## Control Flow and Tracing + +JAX's JIT compiler works by **tracing** --- it records operations on abstract +values to build a computation graph. This restricts what Python code can do +inside `jit`-compiled functions. + +### No branching on traced values + +You cannot use Python `if`/`else` on values that JAX is tracing (e.g., profile +parameters passed into a `jit`-compiled function): + +```python +@jax.jit +def bad(sigma): + if sigma > 1.0: # ERROR: sigma is a tracer, not a concrete value + return sigma * 2 + return sigma + +@jax.jit +def good(sigma): + return jax.lax.cond(sigma > 1.0, lambda s: s * 2, lambda s: s, sigma) +``` + +JAX-GalSim uses an internal `has_tracers()` utility to detect tracing and +avoid problematic control flow in its own implementations. + +### Fixed output shapes + +Under `jit`, the **shape** of every array must be determinable at compile time. +Operations whose output size depends on input values (e.g., adaptive image +sizing based on a traced parameter) may not work. When using `jax.vmap`, you +must specify fixed image dimensions: + +```python +@jax.vmap +def batch(sigma): + gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma) + # Must specify nx, ny so all images have the same shape + return gal.drawImage(scale=0.2, nx=64, ny=64).array +``` + +### The `__init__` gotcha + +During `jit` tracing, JAX calls constructors with **tracer objects** rather than +concrete Python numbers. Type checks like `isinstance(sigma, float)` will fail +on tracers. JAX-GalSim handles this internally, but if you subclass any +JAX-GalSim object, be aware that `__init__` may receive tracers: + +```python +from jax_galsim.core.utils import has_tracers + +class MyProfile(jax_galsim.GSObject): + def __init__(self, sigma, gsparams=None): + if not has_tracers(sigma): + # Only validate with concrete values + if sigma <= 0: + raise ValueError("sigma must be positive") + ... +``` + +--- + +## Profile Restrictions + +Some GalSim features are not yet implemented in JAX-GalSim: + +- **Truncated Moffat profiles**: The `trunc` parameter is not supported. +- **ChromaticObject**: All chromatic functionality (wavelength-dependent profiles) is not available. +- **InterpolatedKImage**: Not implemented. +- **Airy, Kolmogorov, OpticalPSF, RealGalaxy**: And other profiles --- see [API Coverage](api-coverage.md) for the full list. + +The project currently implements **22.5%** of the GalSim public API, focused on +the most commonly used profiles and operations. Coverage is expanding. + +--- + +## Numerical Precision + +Simulation results may differ slightly from GalSim at the floating-point level: + +- **Operation reordering**: JAX (via XLA) may reorder floating-point operations for performance. Floating-point addition is not associative, so different orderings produce slightly different results. +- **Different math kernels**: XLA-compiled math kernels may differ from system math libraries (e.g., `libm`) that GalSim uses via NumPy/C++. +- **Gradient-safe functions**: JAX-GalSim uses special implementations (e.g., `safe_sqrt` that avoids `NaN` gradients at zero) where GalSim uses standard library functions. These may produce slightly different values at edge cases. +- **Default precision**: JAX defaults to 32-bit floats. Enable 64-bit with `jax.config.update("jax_enable_x64", True)` for higher precision matching GalSim's default behavior. + +These differences are typically at the level of floating-point round-off +($\sim 10^{-7}$ for float32, $\sim 10^{-15}$ for float64) and should not +affect scientific conclusions. + +--- + +## The `@implements` Decorator + +JAX-GalSim reuses GalSim's docstrings rather than duplicating them. Every public +class and function uses an `@implements` decorator that copies the docstring from +the corresponding GalSim object and appends a note about JAX-specific differences: + +```python +from jax_galsim.core.utils import implements +import galsim as _galsim + +@implements(_galsim.Gaussian, + lax_description="LAX: Does not support ChromaticObject.") +class Gaussian(GSObject): + ... +``` + +This means the [API Reference](api/profiles/gaussian.md) shows GalSim's +documentation with an added "LAX-backend" note. If you see RST-formatted cross-references +like `:func:` or `:class:` in the docs, they come from GalSim's original docstrings. diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..c6cbdd79 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,124 @@ +site_name: JAX-GalSim +site_description: JAX port of GalSim for GPU-accelerated, differentiable galaxy image simulations +site_url: https://galsim-developers.github.io/JAX-GalSim/ +repo_url: https://github.com/GalSim-developers/JAX-GalSim +repo_name: GalSim-developers/JAX-GalSim + +theme: + name: material + palette: + - media: "(prefers-color-scheme: light)" + scheme: default + primary: indigo + accent: indigo + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: indigo + accent: indigo + toggle: + icon: material/brightness-4 + name: Switch to light mode + features: + - content.code.copy + - navigation.sections + - navigation.expand + - navigation.top + - navigation.indexes + - search.highlight + - toc.follow + - toc.integrate + +plugins: + - search + - mkdocstrings: + handlers: + python: + options: + docstring_style: numpy + show_source: true + merge_init_into_class: true + members_order: source + show_root_heading: true + show_root_full_path: false + show_if_no_docstring: true + show_symbol_type_heading: true + show_symbol_type_toc: true + +markdown_extensions: + - admonition + - pymdownx.details + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.tabbed: + alternate_style: true + - pymdownx.arithmatex: + generic: true + - toc: + permalink: true + +extra_javascript: + - javascripts/mathjax.js + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js + +nav: + - Home: index.md + - Getting Started: + - getting-started/index.md + - Installation: getting-started/installation.md + - Quick Start: getting-started/quickstart.md + - Notable Differences: notable-differences.md + - API Coverage: api-coverage.md + - API Reference: + - Profiles: + - api/profiles/gsobject.md + - api/profiles/gaussian.md + - api/profiles/moffat.md + - api/profiles/spergel.md + - api/profiles/exponential.md + - api/profiles/deltafunction.md + - api/profiles/box.md + - Composition: + - api/composition/convolve.md + - api/composition/sum.md + - api/composition/transform.md + - Image: api/image.md + - Coordinates: + - api/coordinates/position.md + - api/coordinates/bounds.md + - api/coordinates/angle.md + - api/coordinates/shear.md + - api/coordinates/celestial.md + - WCS: + - api/wcs/wcs.md + - api/wcs/fitswcs.md + - api/wcs/fits.md + - Noise & Random: + - api/noise/random.md + - api/noise/noise.md + - Interpolation: + - api/interpolation/interpolant.md + - api/interpolation/interpolatedimage.md + - Photon Shooting: + - api/photons/photon_array.md + - api/photons/sensor.md + - Configuration: + - api/config/gsparams.md + - api/config/errors.md + - api/config/utilities.md + - Math: + - api/math/bessel.md + - api/math/integ.md + - Core Internals: + - api/core/draw.md + - api/core/interpolate.md + - api/core/math.md + - api/core/utils.md diff --git a/pyproject.toml b/pyproject.toml index 3a726884..7cf120fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,14 @@ dependencies = [ [project.optional-dependencies] dev = ["pytest", "pytest-codspeed"] +docs = [ + "mkdocs-material>=9.0", + "mkdocstrings[python]>=0.24", +] [project.urls] home = "https://github.com/GalSim-developers/JAX-GalSim" +documentation = "https://galsim-developers.github.io/JAX-GalSim/" [tool.setuptools.packages.find] include = ["jax_galsim*"]