94 changes: 6 additions & 88 deletions .pre-commit-config.yaml
@@ -6,13 +6,17 @@ repos:
- id: trailing-whitespace
- id: no-commit-to-branch
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.10
rev: v0.14.11
hooks:
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix]
args: [--fix]
- id: ruff-check
args: [--preview, --select=CPY]
- id: ruff-format
- repo: https://github.com/allganize/ty-pre-commit
rev: v0.0.11
hooks:
- id: ty-check
- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.11.1
hooks:
@@ -21,89 +25,3 @@ repos:
rev: v2.3.10
hooks:
- id: biome-format
- repo: https://github.com/H4rryK4ne/update-mypy-hook
rev: v0.3.0
hooks:
- id: update-mypy-hook
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.1
hooks:
- id: mypy
args: [--config-file=pyproject.toml, .]
pass_filenames: false
language_version: '3.13'
additional_dependencies:
- alabaster==1.0.0
- anndata==0.12.6
- array-api-compat==1.12.0
- babel==2.17.0
- certifi==2025.11.12
- cffi==2.0.0
- charset-normalizer==3.4.4
- click==8.3.1
- cloudpickle==3.1.2
- colorama==0.4.6 ; sys_platform == 'win32'
- coverage==7.13.0
- dask==2025.11.0
- docutils==0.22.3
- donfig==0.8.1.post1
- execnet==2.1.2
- fsspec==2025.12.0
- google-crc32c==1.7.1
- h5py==3.15.1
- idna==3.11
- imagesize==1.4.1
- iniconfig==2.3.0
- jinja2==3.1.6
- joblib==1.5.2
- legacy-api-wrap==1.5
- llvmlite==0.46.0
- locket==1.0.0
- markdown-it-py==4.0.0
- markupsafe==3.0.3
- mdurl==0.1.2
- natsort==8.4.0
- numba==0.63.1
- numcodecs==0.16.5
- numpy==2.3.5
- numpy-typing-compat==20251206.2.3
- optype==0.15.0
- packaging==25.0
- pandas==2.3.3
- partd==1.4.2
- pluggy==1.6.0
- pycparser==2.23 ; implementation_name != 'PyPy'
- pygments==2.19.2
- pytest==9.0.2
- pytest-codspeed==4.2.0
- pytest-doctestplus==1.6.0
- pytest-xdist==3.8.0
- python-dateutil==2.9.0.post0
- pytz==2025.2
- pyyaml==6.0.3
- requests==2.32.5
- rich==14.2.0
- roman-numerals==3.1.0
- scikit-learn==1.8.0
- scipy==1.16.3
- scipy-stubs==1.16.3.3
- six==1.17.0
- snowballstemmer==3.0.1
- sphinx==9.0.4
- sphinxcontrib-applehelp==2.0.0
- sphinxcontrib-devhelp==2.0.0
- sphinxcontrib-htmlhelp==2.1.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==2.0.0
- sphinxcontrib-serializinghtml==2.0.0
- threadpoolctl==3.6.0
- toolz==1.1.0
- types-docutils==0.22.3.20251115
- tzdata==2025.2
- urllib3==2.6.1
- zarr==3.1.5
ci:
autoupdate_commit_msg: 'ci: pre-commit autoupdate'
skip:
- mypy # too big
- update-mypy-hook # offline?
8 changes: 8 additions & 0 deletions .vscode/extensions.json
@@ -0,0 +1,8 @@
{
"recommendations": [
"astral-sh.ty",
"biomejs.biome",
"tamasfe.even-better-toml",
"charliermarsh.ruff",
],
}
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -17,4 +17,5 @@
},
"python.testing.pytestArgs": ["-vv", "--color=yes", "-m", "not benchmark"],
"python.testing.pytestEnabled": true,
"python.languageServer": "None", // ty instead
}
30 changes: 15 additions & 15 deletions pyproject.toml
@@ -56,14 +56,9 @@ doc = [
"sphinx>=9.0.1",
"sphinx-autofixture>=0.4.1",
]
# for update-mypy-hook
mypy = [
"fast-array-utils[full]",
typing = [
"scipy-stubs",
# TODO: replace sphinx with this: { include-group = "doc" },
"sphinx",
"types-docutils",
{ include-group = "test" },
]

[tool.hatch.version]
@@ -84,15 +79,21 @@ packages = [ "src/testing", "src/fast_array_utils" ]
[tool.hatch.envs.default]
installer = "uv"

[tool.hatch.envs.typecheck]
dependencies = [ "ty" ]
features = [ "full" ]
dependency-groups = [ "test", "doc", "typing" ]
scripts.run = "ty check"

[tool.hatch.envs.docs]
dependency-groups = [ "doc" ]
dependency-groups = [ "doc", "typing" ]
scripts.build = "sphinx-build -M html docs docs/_build"
scripts.clean = "git clean -fdX docs"
scripts.open = "python -m webbrowser -t docs/_build/html/index.html"

[tool.hatch.envs.hatch-test]
default-args = [ ]
dependency-groups = [ "test-min" ]
dependency-groups = [ "test-min", "typing" ]
# TODO: remove scipy once https://github.com/pypa/hatch/pull/2127 is released
extra-dependencies = [ "ipykernel", "ipycytoscape", "scipy" ]
env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
@@ -161,6 +162,9 @@ lint.isort.lines-after-imports = 2
lint.pydocstyle.convention = "numpy"
lint.future-annotations = true

# Override the project-wide Python version for a developer scripts directory:
per-file-target-version."**/*.pyi" = "py313"

[tool.pytest]
strict = true
addopts = [
@@ -194,13 +198,9 @@ html.directory = "test-data/htmlcov"
run.omit = [ "src/testing/*", "tests/*" ]
report.exclude_also = [ "if TYPE_CHECKING:", "@numba[.]njit", "[.]{3}" ]

[tool.mypy]
strict = true
# https://github.com/dask/dask/issues/8853
implicit_reexport = true
explicit_package_bases = true
mypy_path = [ "$MYPY_CONFIG_FILE_DIR/typings", "$MYPY_CONFIG_FILE_DIR/src" ]

[tool.pyright]
stubPath = "./typings"
reportPrivateUsage = false

[tool.ty.environment]
extra-paths = [ "typings" ]
2 changes: 1 addition & 1 deletion src/fast_array_utils/_plugins/numba_sparse.py
@@ -247,7 +247,7 @@ def overload_sparse_copy(inst: CSType) -> None | Callable[[CSType], CSType]:

# nopython code:
def copy(inst: CSType) -> CSType: # pragma: no cover
return _sparse_copy(inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape) # type: ignore[return-value]
return _sparse_copy(inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape)

return copy

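Aside for readers unfamiliar with the pattern in this hunk: overload_sparse_copy is a typing-time function that returns the nopython implementation numba should compile, which is why its body carries a "pragma: no cover". A minimal, self-contained sketch of that overload mechanism, using a hypothetical toy function rather than the plugin's actual sparse .copy() registration:

    import numba
    import numpy as np
    from numba.extending import overload

    def double(x):  # pure-Python fallback, used outside nopython code
        return x * 2

    @overload(double)
    def _overload_double(x):
        # Typing-time: return a nopython implementation, or None if unsupported.
        if isinstance(x, numba.types.Array):
            def impl(x):  # compiled by numba, never traced by coverage
                return x * 2
            return impl
        return None

    @numba.njit
    def use(x):
        return double(x)

    print(use(np.arange(3)))  # [0 2 4]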
8 changes: 4 additions & 4 deletions src/fast_array_utils/conv/_to_dense.py
@@ -3,7 +3,7 @@

import warnings
from functools import partial, singledispatch
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

import numpy as np

@@ -47,7 +47,7 @@ def _to_dense_dask(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"]
msg = f"{order=!r} will probably be ignored: Dask can not be made to emit F-contiguous arrays reliably."
warnings.warn(msg, RuntimeWarning, stacklevel=4)
x = x.map_blocks(partial(to_dense, order=order, to_cpu_memory=to_cpu_memory))
return x.compute() if to_cpu_memory else x # type: ignore[return-value]
return x.compute() if to_cpu_memory else x


@to_dense_.register(types.CSDataset)
@@ -58,7 +58,7 @@ def _to_dense_ooc(x: types.CSDataset, /, *, order: Literal["K", "A", "C", "F"] =
msg = "to_cpu_memory must be True if x is an CS{R,C}Dataset"
raise ValueError(msg)
# TODO(flying-sheep): why is to_memory of type Any? # noqa: TD003
return to_dense(cast("types.CSBase", x.to_memory()), order=sparse_order(x, order=order))
return to_dense(x.to_memory(), order=sparse_order(x, order=order))


@to_dense_.register(types.CupyArray | types.CupySpMatrix)
@@ -77,4 +77,4 @@ def sparse_order(x: types.spmatrix | types.sparray | types.CupySpMatrix | types.

if order in {"K", "A"}:
order = "F" if x.format == "csc" else "C"
return cast("Literal['C', 'F']", order)
return order
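A short usage sketch of the order handling above. The import path assumes to_dense is re-exported from fast_array_utils.conv, which this diff does not show:

    import scipy.sparse as sp
    from fast_array_utils.conv import to_dense  # assumed public re-export

    x = sp.random_array((4, 3), density=0.5, format="csc")
    d = to_dense(x, order="K")
    # sparse_order maps "K"/"A" to "F" for CSC input, so d should be F-contiguous
    print(d.flags.f_contiguous)  # True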
22 changes: 16 additions & 6 deletions src/fast_array_utils/stats/__init__.py
@@ -25,7 +25,6 @@
from ._generic_ops import Ops
from ._typing import NoDtypeOps, StatFunDtype, StatFunNoDtype


__all__ = ["is_constant", "max", "mean", "mean_var", "min", "sum"]


@@ -87,9 +86,9 @@ def mean(x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTy
@overload
def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> NDArray[np.number[Any]]: ...
@overload
def mean(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> types.CupyArray: ...
def mean(x: GpuArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None) -> types.CupyArray: ...
@overload
def mean(x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None) -> types.DaskArray: ...
def mean(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: ToDType[Any] | None = None) -> types.DaskArray: ...


def mean(
@@ -144,7 +143,17 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup
@overload
def mean_var(x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[types.CupyArray, types.CupyArray]: ...
@overload
def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ...
def mean_var(
x: types.DaskArray[CpuArray | GpuArray], /, *, axis: None = None, correction: int = 0
) -> tuple[types.DaskArray[np.float64], types.DaskArray[np.float64]]: ...
@overload
def mean_var(
x: types.DaskArray[CpuArray], /, *, axis: Literal[0, 1], correction: int = 0
) -> tuple[types.DaskArray[NDArray[np.float64]], types.DaskArray[NDArray[np.float64]]]: ...
@overload
def mean_var(
x: types.DaskArray[GpuArray], /, *, axis: Literal[0, 1], correction: int = 0
) -> tuple[types.DaskArray[types.CupyArray], types.DaskArray[types.CupyArray]]: ...


def mean_var(
@@ -157,7 +166,8 @@ def mean_var(
tuple[np.float64, np.float64]
| tuple[NDArray[np.float64], NDArray[np.float64]]
| tuple[types.CupyArray, types.CupyArray]
| tuple[types.DaskArray, types.DaskArray]
| tuple[types.DaskArray[np.float64], types.DaskArray[np.float64]]
| tuple[types.DaskArray[NDArray[np.float64]], types.DaskArray[NDArray[np.float64]]]
):
"""Mean and variance over both or one axis.

@@ -201,7 +211,7 @@ def mean_var(
from ._mean_var import mean_var_

validate_axis(x.ndim, axis)
return mean_var_(x, axis=axis, correction=correction) # type: ignore[no-any-return]
return mean_var_(x, axis=axis, correction=correction)


@overload
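For context on how the widened mean_var overloads resolve: with in-memory NumPy input, axis=0/1 yields an array pair and axis=None a scalar pair; the dask and CuPy overloads above follow the same shape logic. A sketch, with values purely illustrative:

    import numpy as np
    from fast_array_utils.stats import mean_var

    x = np.arange(12, dtype=np.float64).reshape(3, 4)
    m, v = mean_var(x, axis=0)  # pair of NDArray[np.float64], one entry per column
    s_m, s_v = mean_var(x)      # axis=None reduces to a pair of np.float64 scalars
    print(m.shape, type(s_m))   # (4,) <class 'numpy.float64'>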
2 changes: 1 addition & 1 deletion src/fast_array_utils/stats/_generic_ops.py
@@ -88,7 +88,7 @@ def _generic_op_cs(
if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
assert isinstance(dtype, np.dtype | type | None)
# convert to array so dimensions collapse as expected
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op)) # type: ignore[arg-type]
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))
rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
# old scipy versions’ sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv
2 changes: 1 addition & 1 deletion src/fast_array_utils/stats/_is_constant.py
@@ -92,7 +92,7 @@ def _is_constant_dask(a: types.DaskArray, /, *, axis: Literal[0, 1] | None = Non
(a == a[0, 0].compute()).all()
if isinstance(a._meta, np.ndarray) # noqa: SLF001
else da.map_overlap(
lambda a: np.array([[is_constant(a)]]), # type: ignore[arg-type]
lambda a: np.array([[is_constant(a)]]),
a,
# use asymmetric overlaps to avoid unnecessary computation
depth=dict.fromkeys(range(a.ndim), (0, 1)),
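A quick sketch of the public behavior the dask branch above parallelizes; outputs assume the eager NumPy path, not anything shown in this diff:

    import numpy as np
    from fast_array_utils.stats import is_constant

    print(is_constant(np.zeros((3, 3))))   # True: every entry is equal
    print(is_constant(np.eye(2), axis=0))  # [False False]: no column is constant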
4 changes: 2 additions & 2 deletions src/fast_array_utils/stats/_mean.py
@@ -24,6 +24,6 @@ def mean_(
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
total = sum(x, axis=axis, dtype=dtype) # type: ignore[misc,arg-type]
total = sum(x, axis=axis, dtype=dtype)
n = np.prod(x.shape) if axis is None else x.shape[axis]
return total / n # type: ignore[no-any-return]
return total / n
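The hunk above is just the sum over the requested axis divided by that axis' length; a worked restatement:

    import numpy as np

    x = np.arange(6).reshape(2, 3)
    total = x.sum(axis=0)  # the sum() call in mean_ above
    n = x.shape[0]         # n = x.shape[axis]
    print(total / n)       # [1.5 2.5 3.5], identical to np.mean(x, axis=0)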
2 changes: 1 addition & 1 deletion src/fast_array_utils/stats/_mean_var.py
@@ -29,7 +29,7 @@ def mean_var_(
tuple[NDArray[np.float64], NDArray[np.float64]]
| tuple[types.CupyArray, types.CupyArray]
| tuple[np.float64, np.float64]
| tuple[types.DaskArray, types.DaskArray]
| tuple[types.DaskArray[NDArray[np.float64]], types.DaskArray[NDArray[np.float64]]]
):
from . import mean

8 changes: 4 additions & 4 deletions src/fast_array_utils/stats/_power.py
@@ -21,23 +21,23 @@
def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
"""Take array or matrix to a power."""
# This wrapper is necessary because TypeVars can’t be used in `singledispatch` functions
return _power(x, n, dtype=dtype) # type: ignore[return-value]
return _power(x, n, dtype=dtype)


@singledispatch
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
if TYPE_CHECKING:
assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
return x**n if dtype is None else np.power(x, n, dtype=dtype) # type: ignore[operator]
return x**n if dtype is None else np.power(x, n, dtype=dtype)


@_power.register(types.CSBase | types.CupyCSMatrix)
def _power_cs[Mat: types.CSBase | types.CupyCSMatrix](x: Mat, n: int, /, dtype: DTypeLike | None = None) -> Mat:
new_data = power(x.data, n, dtype=dtype)
return type(x)((new_data, x.indices, x.indptr), shape=x.shape, dtype=new_data.dtype) # type: ignore[call-overload,return-value]
return type(x)((new_data, x.indices, x.indptr), shape=x.shape, dtype=new_data.dtype)


@_power.register(types.DaskArray)
def _power_dask(x: types.DaskArray, n: int, /, dtype: DTypeLike | None = None) -> types.DaskArray:
meta = x._meta.astype(dtype or x.dtype) # noqa: SLF001
return x.map_blocks(lambda c: power(c, n, dtype=dtype), dtype=dtype, meta=meta) # type: ignore[type-var,arg-type]
return x.map_blocks(lambda c: power(c, n, dtype=dtype), dtype=dtype, meta=meta)
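Usage sketch for the sparse branch above. Note that power is not in stats' public __all__, so the private import path here is an assumption based on this diff:

    import numpy as np
    import scipy.sparse as sp
    from fast_array_utils.stats._power import power  # assumed import path

    x = sp.random_array((3, 3), density=0.5, format="csr", dtype=np.float32)
    y = power(x, 2, dtype=np.float64)
    # only the stored .data values are exponentiated; sparsity structure is kept
    print(type(y).__name__, y.dtype)  # csr_array float64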
8 changes: 4 additions & 4 deletions src/fast_array_utils/stats/_utils.py
@@ -52,9 +52,9 @@ def _dask_inner(x: types.DaskArray, op: Ops, /, *, axis: Literal[0, 1] | None, d
def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]:
if isinstance(a, types.CupyArray):
a = a.get()
return a.reshape(())[()] # type: ignore[return-value]
return a.reshape(())[()]

return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type]
return rv.map_blocks(to_scalar, meta=x.dtype.type(0))


def _dask_block(
@@ -75,7 +75,7 @@ def _dask_block(
fns = {fn.__name__: fn for fn in (min, max, sum)}

axis = _normalize_axis(axis, a.ndim)
rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[call-overload]
rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op))
shape = _get_shape(rv, axis=axis, keepdims=keepdims)
return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))

@@ -90,7 +90,7 @@ def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1] | None:
case (0, 1) | (1, 0):
axis = None
case _: # pragma: no cover
raise AxisError(axis, ndim) # type: ignore[call-overload]
raise AxisError(axis, ndim)
if axis == 0 and ndim == 1:
return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays
return axis
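A condensed restatement of the axis normalization rules in the last hunk (a standalone reimplementation for illustration, not the module's code; the AxisError branch is omitted):

    from typing import Literal

    def normalize_axis(axis, ndim: int) -> Literal[0, 1] | None:
        if axis in ((0, 1), (1, 0)):
            axis = None  # a tuple covering all axes means a full reduction
        if axis == 0 and ndim == 1:
            return None  # 1D: treat axis=0 as a full reduction (see the comment in the hunk)
        return axis

    print(normalize_axis((1, 0), 2), normalize_axis(0, 1), normalize_axis(1, 2))
    # prints: None None 1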
2 changes: 1 addition & 1 deletion src/fast_array_utils/types.py
@@ -106,7 +106,7 @@


if TYPE_CHECKING:
from anndata.abc import CSCDataset, CSRDataset # type: ignore[import-untyped]
from anndata.abc import CSCDataset, CSRDataset
else: # pragma: no cover
try: # only exists in anndata 0.11+
from anndata.abc import CSCDataset, CSRDataset