diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8ff0d0d..42b30b0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -6,13 +6,17 @@ repos:
   - id: trailing-whitespace
   - id: no-commit-to-branch
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.14.10
+  rev: v0.14.11
   hooks:
   - id: ruff-check
-    args: [--fix, --exit-non-zero-on-fix]
+    args: [--fix]
   - id: ruff-check
     args: [--preview, --select=CPY]
   - id: ruff-format
+- repo: https://github.com/allganize/ty-pre-commit
+  rev: v0.0.11
+  hooks:
+  - id: ty-check
 - repo: https://github.com/tox-dev/pyproject-fmt
   rev: v2.11.1
   hooks:
@@ -21,89 +25,3 @@ repos:
   rev: v2.3.10
   hooks:
   - id: biome-format
-- repo: https://github.com/H4rryK4ne/update-mypy-hook
-  rev: v0.3.0
-  hooks:
-  - id: update-mypy-hook
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v1.19.1
-  hooks:
-  - id: mypy
-    args: [--config-file=pyproject.toml, .]
-    pass_filenames: false
-    language_version: '3.13'
-    additional_dependencies:
-    - alabaster==1.0.0
-    - anndata==0.12.6
-    - array-api-compat==1.12.0
-    - babel==2.17.0
-    - certifi==2025.11.12
-    - cffi==2.0.0
-    - charset-normalizer==3.4.4
-    - click==8.3.1
-    - cloudpickle==3.1.2
-    - colorama==0.4.6 ; sys_platform == 'win32'
-    - coverage==7.13.0
-    - dask==2025.11.0
-    - docutils==0.22.3
-    - donfig==0.8.1.post1
-    - execnet==2.1.2
-    - fsspec==2025.12.0
-    - google-crc32c==1.7.1
-    - h5py==3.15.1
-    - idna==3.11
-    - imagesize==1.4.1
-    - iniconfig==2.3.0
-    - jinja2==3.1.6
-    - joblib==1.5.2
-    - legacy-api-wrap==1.5
-    - llvmlite==0.46.0
-    - locket==1.0.0
-    - markdown-it-py==4.0.0
-    - markupsafe==3.0.3
-    - mdurl==0.1.2
-    - natsort==8.4.0
-    - numba==0.63.1
-    - numcodecs==0.16.5
-    - numpy==2.3.5
-    - numpy-typing-compat==20251206.2.3
-    - optype==0.15.0
-    - packaging==25.0
-    - pandas==2.3.3
-    - partd==1.4.2
-    - pluggy==1.6.0
-    - pycparser==2.23 ; implementation_name != 'PyPy'
-    - pygments==2.19.2
-    - pytest==9.0.2
-    - pytest-codspeed==4.2.0
-    - pytest-doctestplus==1.6.0
-    - pytest-xdist==3.8.0
-    - python-dateutil==2.9.0.post0
-    - pytz==2025.2
-    - pyyaml==6.0.3
-    - requests==2.32.5
-    - rich==14.2.0
-    - roman-numerals==3.1.0
-    - scikit-learn==1.8.0
-    - scipy==1.16.3
-    - scipy-stubs==1.16.3.3
-    - six==1.17.0
-    - snowballstemmer==3.0.1
-    - sphinx==9.0.4
-    - sphinxcontrib-applehelp==2.0.0
-    - sphinxcontrib-devhelp==2.0.0
-    - sphinxcontrib-htmlhelp==2.1.0
-    - sphinxcontrib-jsmath==1.0.1
-    - sphinxcontrib-qthelp==2.0.0
-    - sphinxcontrib-serializinghtml==2.0.0
-    - threadpoolctl==3.6.0
-    - toolz==1.1.0
-    - types-docutils==0.22.3.20251115
-    - tzdata==2025.2
-    - urllib3==2.6.1
-    - zarr==3.1.5
-ci:
-  autoupdate_commit_msg: 'ci: pre-commit autoupdate'
-  skip:
-  - mypy  # too big
-  - update-mypy-hook  # offline?
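The mypy mirror hook and its fully pinned `additional_dependencies` block are replaced by the much lighter `ty-check` hook, so the `ci:` skip list can go away entirely. The suppression comments in the source diffs below migrate accordingly; here is a minimal sketch of that comment migration, with hypothetical assignments (only the comment syntax is taken from this diff):

```python
# Before: mypy scopes suppressions by error code.
x: int = "nope"  # type: ignore[assignment]

# After: ty uses its own rule names; this diff's convention also appends
# a human-readable reason after the closing bracket.
y: int = "nope"  # ty:ignore[invalid-assignment]: deliberate bad value
```

Most of the source changes below are exactly this mechanical swap, plus deleting `# type: ignore[...]` comments that ty no longer requires. Given the `[tool.hatch.envs.typecheck]` table added to pyproject.toml below (`scripts.run = "ty check"`), the same check should also be invocable locally as `hatch run typecheck:run`.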
diff --git a/.vscode/extensions.json b/.vscode/extensions.json
new file mode 100644
index 0000000..7d2f8bc
--- /dev/null
+++ b/.vscode/extensions.json
@@ -0,0 +1,8 @@
+{
+  "recommendations": [
+    "astral-sh.ty",
+    "biomejs.biome",
+    "tamasfe.even-better-toml",
+    "charliermarsh.ruff",
+  ],
+}
diff --git a/.vscode/settings.json b/.vscode/settings.json
index ebc3a5d..a4efe30 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -17,4 +17,5 @@
   },
   "python.testing.pytestArgs": ["-vv", "--color=yes", "-m", "not benchmark"],
   "python.testing.pytestEnabled": true,
+  "python.languageServer": "None", // ty instead
 }
diff --git a/pyproject.toml b/pyproject.toml
index f7d5a2f..6ed45e9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,14 +56,9 @@ doc = [
   "sphinx>=9.0.1",
   "sphinx-autofixture>=0.4.1",
 ]
-# for update-mypy-hook
-mypy = [
-  "fast-array-utils[full]",
+typing = [
   "scipy-stubs",
-  # TODO: replace sphinx with this: { include-group = "doc" },
-  "sphinx",
   "types-docutils",
-  { include-group = "test" },
 ]
 
 [tool.hatch.version]
@@ -84,15 +79,21 @@ packages = [ "src/testing", "src/fast_array_utils" ]
 [tool.hatch.envs.default]
 installer = "uv"
 
+[tool.hatch.envs.typecheck]
+dependencies = [ "ty" ]
+features = [ "full" ]
+dependency-groups = [ "test", "doc", "typing" ]
+scripts.run = "ty check"
+
 [tool.hatch.envs.docs]
-dependency-groups = [ "doc" ]
+dependency-groups = [ "doc", "typing" ]
 scripts.build = "sphinx-build -M html docs docs/_build"
 scripts.clean = "git clean -fdX docs"
 scripts.open = "python -m webbrowser -t docs/_build/html/index.html"
 
 [tool.hatch.envs.hatch-test]
 default-args = [ ]
-dependency-groups = [ "test-min" ]
+dependency-groups = [ "test-min", "typing" ]
 # TODO: remove scipy once https://github.com/pypa/hatch/pull/2127 is released
 extra-dependencies = [ "ipykernel", "ipycytoscape", "scipy" ]
 env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
@@ -161,6 +162,9 @@ lint.isort.lines-after-imports = 2
 lint.pydocstyle.convention = "numpy"
 lint.future-annotations = true
 
+# Override the project-wide Python version for the type stubs:
+per-file-target-version."**/*.pyi" = "py313"
+
 [tool.pytest]
 strict = true
 addopts = [
@@ -194,13 +198,9 @@ html.directory = "test-data/htmlcov"
 run.omit = [ "src/testing/*", "tests/*" ]
 report.exclude_also = [ "if TYPE_CHECKING:", "@numba[.]njit", "[.]{3}" ]
 
-[tool.mypy]
-strict = true
-# https://github.com/dask/dask/issues/8853
-implicit_reexport = true
-explicit_package_bases = true
-mypy_path = [ "$MYPY_CONFIG_FILE_DIR/typings", "$MYPY_CONFIG_FILE_DIR/src" ]
-
 [tool.pyright]
 stubPath = "./typings"
 reportPrivateUsage = false
+
+[tool.ty.environment]
+extra-paths = [ "typings" ]
diff --git a/src/fast_array_utils/_plugins/numba_sparse.py b/src/fast_array_utils/_plugins/numba_sparse.py
index 3796af1..64bf63e 100644
--- a/src/fast_array_utils/_plugins/numba_sparse.py
+++ b/src/fast_array_utils/_plugins/numba_sparse.py
@@ -247,7 +247,7 @@ def overload_sparse_copy(inst: CSType) -> None | Callable[[CSType], CSType]:
 
     # nopython code:
     def copy(inst: CSType) -> CSType:  # pragma: no cover
-        return _sparse_copy(inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape)  # type: ignore[return-value]
+        return _sparse_copy(inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape)
 
     return copy
diff --git a/src/fast_array_utils/conv/_to_dense.py b/src/fast_array_utils/conv/_to_dense.py
index 6b5de88..21e2284 100644
--- a/src/fast_array_utils/conv/_to_dense.py
+++ b/src/fast_array_utils/conv/_to_dense.py
@@ -3,7 +3,7 @@
 
 import warnings
 from functools import partial, singledispatch
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING
 
 import numpy as np
 
@@ -47,7 +47,7 @@ def _to_dense_dask(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"]
         msg = f"{order=!r} will probably be ignored: Dask can not be made to emit F-contiguous arrays reliably."
         warnings.warn(msg, RuntimeWarning, stacklevel=4)
     x = x.map_blocks(partial(to_dense, order=order, to_cpu_memory=to_cpu_memory))
-    return x.compute() if to_cpu_memory else x  # type: ignore[return-value]
+    return x.compute() if to_cpu_memory else x
 
 
 @to_dense_.register(types.CSDataset)
@@ -58,7 +58,7 @@ def _to_dense_ooc(x: types.CSDataset, /, *, order: Literal["K", "A", "C", "F"] =
         msg = "to_cpu_memory must be True if x is an CS{R,C}Dataset"
         raise ValueError(msg)
     # TODO(flying-sheep): why is to_memory of type Any?  # noqa: TD003
-    return to_dense(cast("types.CSBase", x.to_memory()), order=sparse_order(x, order=order))
+    return to_dense(x.to_memory(), order=sparse_order(x, order=order))
 
 
 @to_dense_.register(types.CupyArray | types.CupySpMatrix)
@@ -77,4 +77,4 @@ def sparse_order(x: types.spmatrix | types.sparray | types.CupySpMatrix | types.
     if order in {"K", "A"}:
         order = "F" if x.format == "csc" else "C"
-    return cast("Literal['C', 'F']", order)
+    return order
diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py
index 24712d8..015b7ea 100644
--- a/src/fast_array_utils/stats/__init__.py
+++ b/src/fast_array_utils/stats/__init__.py
@@ -25,7 +25,6 @@
 from ._generic_ops import Ops
 from ._typing import NoDtypeOps, StatFunDtype, StatFunNoDtype
 
-
 __all__ = ["is_constant", "max", "mean", "mean_var", "min", "sum"]
 
@@ -87,9 +86,9 @@ def mean(x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTy
 @overload
 def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> NDArray[np.number[Any]]: ...
 @overload
-def mean(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> types.CupyArray: ...
+def mean(x: GpuArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None) -> types.CupyArray: ...
 @overload
-def mean(x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None) -> types.DaskArray: ...
+def mean(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: ToDType[Any] | None = None) -> types.DaskArray: ...
 
 
 def mean(
@@ -144,7 +143,17 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup
 @overload
 def mean_var(x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[types.CupyArray, types.CupyArray]: ...
 @overload
-def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ...
+def mean_var(
+    x: types.DaskArray[CpuArray | GpuArray], /, *, axis: None = None, correction: int = 0
+) -> tuple[types.DaskArray[np.float64], types.DaskArray[np.float64]]: ...
+@overload
+def mean_var(
+    x: types.DaskArray[CpuArray], /, *, axis: Literal[0, 1], correction: int = 0
+) -> tuple[types.DaskArray[NDArray[np.float64]], types.DaskArray[NDArray[np.float64]]]: ...
+@overload
+def mean_var(
+    x: types.DaskArray[GpuArray], /, *, axis: Literal[0, 1], correction: int = 0
+) -> tuple[types.DaskArray[types.CupyArray], types.DaskArray[types.CupyArray]]: ...
 
 
 def mean_var(
@@ -157,7 +166,8 @@ def mean_var(
     tuple[np.float64, np.float64]
     | tuple[NDArray[np.float64], NDArray[np.float64]]
     | tuple[types.CupyArray, types.CupyArray]
-    | tuple[types.DaskArray, types.DaskArray]
+    | tuple[types.DaskArray[np.float64], types.DaskArray[np.float64]]
+    | tuple[types.DaskArray[NDArray[np.float64]], types.DaskArray[NDArray[np.float64]]]
 ):
     """Mean and variance over both or one axis.
 
@@ -201,7 +211,7 @@ def mean_var(
     from ._mean_var import mean_var_
 
     validate_axis(x.ndim, axis)
-    return mean_var_(x, axis=axis, correction=correction)  # type: ignore[no-any-return]
+    return mean_var_(x, axis=axis, correction=correction)
 
 
 @overload
diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py
index 79df974..b4cccfc 100644
--- a/src/fast_array_utils/stats/_generic_ops.py
+++ b/src/fast_array_utils/stats/_generic_ops.py
@@ -88,7 +88,7 @@ def _generic_op_cs(
     if TYPE_CHECKING:  # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
         assert isinstance(dtype, np.dtype | type | None)
     # convert to array so dimensions collapse as expected
-    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[arg-type]
+    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))
     rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
     # old scipy versions’ sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
     return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv
diff --git a/src/fast_array_utils/stats/_is_constant.py b/src/fast_array_utils/stats/_is_constant.py
index 1ac95d3..4bd386c 100644
--- a/src/fast_array_utils/stats/_is_constant.py
+++ b/src/fast_array_utils/stats/_is_constant.py
@@ -92,7 +92,7 @@ def _is_constant_dask(a: types.DaskArray, /, *, axis: Literal[0, 1] | None = Non
         (a == a[0, 0].compute()).all()
         if isinstance(a._meta, np.ndarray)  # noqa: SLF001
         else da.map_overlap(
-            lambda a: np.array([[is_constant(a)]]),  # type: ignore[arg-type]
+            lambda a: np.array([[is_constant(a)]]),
             a,
             # use asymmetric overlaps to avoid unnecessary computation
             depth=dict.fromkeys(range(a.ndim), (0, 1)),
diff --git a/src/fast_array_utils/stats/_mean.py b/src/fast_array_utils/stats/_mean.py
index ba08164..af51b9d 100644
--- a/src/fast_array_utils/stats/_mean.py
+++ b/src/fast_array_utils/stats/_mean.py
@@ -24,6 +24,6 @@ def mean_(
     axis: Literal[0, 1] | None = None,
     dtype: DTypeLike | None = None,
 ) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
-    total = sum(x, axis=axis, dtype=dtype)  # type: ignore[misc,arg-type]
+    total = sum(x, axis=axis, dtype=dtype)
     n = np.prod(x.shape) if axis is None else x.shape[axis]
-    return total / n  # type: ignore[no-any-return]
+    return total / n
diff --git a/src/fast_array_utils/stats/_mean_var.py b/src/fast_array_utils/stats/_mean_var.py
index 9037567..fa9ce57 100644
--- a/src/fast_array_utils/stats/_mean_var.py
+++ b/src/fast_array_utils/stats/_mean_var.py
@@ -29,7 +29,7 @@ def mean_var_(
     tuple[NDArray[np.float64], NDArray[np.float64]]
     | tuple[types.CupyArray, types.CupyArray]
    | tuple[np.float64, np.float64]
-    | tuple[types.DaskArray, types.DaskArray]
+    | tuple[types.DaskArray[NDArray[np.float64]], types.DaskArray[NDArray[np.float64]]]
 ):
     from . import mean
diff --git a/src/fast_array_utils/stats/_power.py b/src/fast_array_utils/stats/_power.py
index 8387836..af6a59a 100644
--- a/src/fast_array_utils/stats/_power.py
+++ b/src/fast_array_utils/stats/_power.py
@@ -21,23 +21,23 @@
 def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
     """Take array or matrix to a power."""
     # This wrapper is necessary because TypeVars can’t be used in `singledispatch` functions
-    return _power(x, n, dtype=dtype)  # type: ignore[return-value]
+    return _power(x, n, dtype=dtype)
 
 
 @singledispatch
 def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
     if TYPE_CHECKING:
         assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
-    return x**n if dtype is None else np.power(x, n, dtype=dtype)  # type: ignore[operator]
+    return x**n if dtype is None else np.power(x, n, dtype=dtype)
 
 
 @_power.register(types.CSBase | types.CupyCSMatrix)
 def _power_cs[Mat: types.CSBase | types.CupyCSMatrix](x: Mat, n: int, /, dtype: DTypeLike | None = None) -> Mat:
     new_data = power(x.data, n, dtype=dtype)
-    return type(x)((new_data, x.indices, x.indptr), shape=x.shape, dtype=new_data.dtype)  # type: ignore[call-overload,return-value]
+    return type(x)((new_data, x.indices, x.indptr), shape=x.shape, dtype=new_data.dtype)
 
 
 @_power.register(types.DaskArray)
 def _power_dask(x: types.DaskArray, n: int, /, dtype: DTypeLike | None = None) -> types.DaskArray:
     meta = x._meta.astype(dtype or x.dtype)  # noqa: SLF001
-    return x.map_blocks(lambda c: power(c, n, dtype=dtype), dtype=dtype, meta=meta)  # type: ignore[type-var,arg-type]
+    return x.map_blocks(lambda c: power(c, n, dtype=dtype), dtype=dtype, meta=meta)
diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py
index bc1f650..2ba8114 100644
--- a/src/fast_array_utils/stats/_utils.py
+++ b/src/fast_array_utils/stats/_utils.py
@@ -52,9 +52,9 @@ def _dask_inner(x: types.DaskArray, op: Ops, /, *, axis: Literal[0, 1] | None, d
     def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]:
         if isinstance(a, types.CupyArray):
             a = a.get()
-        return a.reshape(())[()]  # type: ignore[return-value]
+        return a.reshape(())[()]
 
-    return rv.map_blocks(to_scalar, meta=x.dtype.type(0))  # type: ignore[arg-type]
+    return rv.map_blocks(to_scalar, meta=x.dtype.type(0))
 
 
 def _dask_block(
@@ -75,7 +75,7 @@ def _dask_block(
     fns = {fn.__name__: fn for fn in (min, max, sum)}
     axis = _normalize_axis(axis, a.ndim)
-    rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op))  # type: ignore[call-overload]
+    rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op))
     shape = _get_shape(rv, axis=axis, keepdims=keepdims)
     return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))
 
@@ -90,7 +90,7 @@ def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1] | None:
         case (0, 1) | (1, 0):
             axis = None
         case _:  # pragma: no cover
-            raise AxisError(axis, ndim)  # type: ignore[call-overload]
+            raise AxisError(axis, ndim)
     if axis == 0 and ndim == 1:
         return None  # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays
     return axis
diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py
index c1fbf26..0c9d19d 100644
--- a/src/fast_array_utils/types.py
+++ b/src/fast_array_utils/types.py
@@ -106,7 +106,7 @@
 if TYPE_CHECKING:
-    from anndata.abc import CSCDataset, CSRDataset  # type: ignore[import-untyped]
+    from anndata.abc import CSCDataset, CSRDataset
 else:  # pragma: no cover
     try:  # only exists in anndata 0.11+
         from anndata.abc import CSCDataset, CSRDataset
diff --git a/src/fast_array_utils/typing.py b/src/fast_array_utils/typing.py
index aee112c..80caa4f 100644
--- a/src/fast_array_utils/typing.py
+++ b/src/fast_array_utils/typing.py
@@ -20,5 +20,5 @@
 """Arrays and matrices stored in GPU memory."""
 
 # TODO(flying-sheep): types.CSDataset  # noqa: TD003
-type DiskArray = types.H5Dataset | types.ZarrArray  # type: ignore[type-arg]
+type DiskArray = types.H5Dataset | types.ZarrArray
 """Arrays and matrices stored on disk."""
diff --git a/src/testing/fast_array_utils/__init__.py b/src/testing/fast_array_utils/__init__.py
index e2b7876..eb6bfaa 100644
--- a/src/testing/fast_array_utils/__init__.py
+++ b/src/testing/fast_array_utils/__init__.py
@@ -38,13 +38,10 @@
         for (mod, flags) in [("scipy.sparse", Flags(0)), ("cupyx.scipy.sparse", Flags.Gpu)]
     ),
 )
-_TP_DASK = tuple(
-    ArrayType("dask.array", "Array", Flags.Dask | t.flags, inner=t)  # type: ignore[type-var]
-    for t in cast("tuple[ArrayType[CpuArray | GpuArray, None], ...]", _TP_MEM)
-)
+_TP_DASK = tuple(ArrayType("dask.array", "Array", Flags.Dask | t.flags, inner=t) for t in cast("tuple[ArrayType[CpuArray | GpuArray, None], ...]", _TP_MEM))
 _TP_DISK_DENSE = tuple(ArrayType(m, n, Flags.Any | Flags.Disk) for m, n in [("h5py", "Dataset"), ("zarr", "Array")])
 _TP_DISK_SPARSE = tuple(
-    ArrayType("anndata.abc", n, Flags.Any | Flags.Disk | Flags.Sparse, inner=t)  # type: ignore[type-var]
+    ArrayType("anndata.abc", n, Flags.Any | Flags.Disk | Flags.Sparse, inner=t)
     for t in cast("tuple[ArrayType[DiskArray, None], ...]", _TP_DISK_DENSE)
     for n in ["CSRDataset", "CSCDataset"]
 )
diff --git a/src/testing/fast_array_utils/_array_type.py b/src/testing/fast_array_utils/_array_type.py
index f4ae663..aa214ab 100644
--- a/src/testing/fast_array_utils/_array_type.py
+++ b/src/testing/fast_array_utils/_array_type.py
@@ -4,7 +4,6 @@
 from __future__ import annotations
 
 import enum
-import sys
 from dataclasses import KW_ONLY, dataclass, field
 from functools import cached_property, partial
 from importlib.metadata import version
@@ -74,10 +73,12 @@ class ConversionContext:
     hdf5_file: h5py.File  # TODO(flying-sheep): ReadOnly
 
 
-if TYPE_CHECKING or sys.version_info >= (3, 13):
-    # TODO(flying-sheep): move vars into type parameter syntax  # noqa: TD003
-    Arr = TypeVar("Arr", bound="ExtendedArray", default="Array")
-    Inner = TypeVar("Inner", bound="ArrayType[InnerArray, None] | None", default="Any")
+if TYPE_CHECKING:
+    # TODO(flying-sheep): Python 3.13: move vars into type parameter syntax  # noqa: TD003
+    import typing_extensions as te
+
+    Arr = te.TypeVar("Arr", bound="ExtendedArray", default="Array")
+    Inner = te.TypeVar("Inner", bound="InnerArray | None", default="InnerArray")
 else:
     Arr = TypeVar("Arr")
     Inner = TypeVar("Inner")
@@ -106,7 +107,7 @@ class ArrayType(Generic[Arr, Inner]):  # noqa: UP046
 
     _: KW_ONLY
 
-    inner: Inner = None  # type: ignore[assignment]
+    inner: ArrayType[Inner, None] = None  # ty:ignore[invalid-assignment,invalid-type-arguments]: can’t do `if Inner extends None`
     """Inner array type (e.g. for dask)."""
     conversion_context: ConversionContext | None = field(default=None, compare=False)
     """Conversion context required for converting to h5py."""
@@ -146,7 +147,7 @@ def cls(self) -> type[Arr]:  # noqa: PLR0911
                 return cast("type[Arr]", zarr.Array)
             case "anndata.abc", ("CSCDataset" | "CSRDataset") as cls_name, _:
-                import anndata.abc  # type: ignore[import-untyped]
+                import anndata.abc
 
                 return cast("type[Arr]", getattr(anndata.abc, cls_name))
             case _:
@@ -191,7 +192,7 @@ def random(
             return cast(
                 "Arr",
                 arr.map_blocks(
-                    lambda x: self.random(x.shape, dtype=x.dtype, gen=gen, density=density),  # type: ignore[attr-defined]
+                    lambda x: self.random(x.shape, dtype=x.dtype, gen=gen, density=density),
                     dtype=dtype,
                 ),
             )
@@ -277,7 +278,7 @@ def _to_zarr_array(cls, x: ArrayLike | Array, /, *, dtype: DTypeLike | None = No
 
     def _to_cs_dataset(self, x: ArrayLike | Array, /, *, dtype: DTypeLike | None = None) -> types.CSDataset:
         """Convert to a scipy sparse dataset."""
-        import anndata.io  # type: ignore[import-untyped]
+        import anndata.io
         from scipy.sparse import csc_array, csr_array
 
         assert self.inner is not None
@@ -317,7 +318,7 @@ def _to_scipy_sparse(
             x = to_dense(x, to_cpu_memory=True)
         cls = cast("type[types.CSBase]", cls or self.cls)
-        return cls(x, dtype=dtype)  # type: ignore[arg-type]
+        return cls(x, dtype=dtype)
 
     def _to_cupy_array(self, x: ArrayLike | Array, /, *, dtype: DTypeLike | None = None) -> types.CupyArray:
         import cupy as cu
@@ -341,7 +342,7 @@ def _to_cupy_sparse(
         if not isinstance(x, types.spmatrix | types.sparray | types.CupyArray | types.CupySpMatrix):
             x = self._to_cupy_array(x, dtype=dtype)
-        return self.cls(x)  # type: ignore[call-arg,arg-type, return-value]
+        return self.cls(x)
 
 
 def random_array(
@@ -355,7 +356,7 @@ def random_array(
     f: MkArray
     match np.dtype(dtype or "f").kind:
         case "f":
-            f = rng.random  # type: ignore[assignment]
+            f = rng.random
         case "i" | "u":
             f = partial(rng.integers, 0, 10_000)
         case _:
diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py
index 2500922..b0fb08f 100644
--- a/src/testing/fast_array_utils/pytest.py
+++ b/src/testing/fast_array_utils/pytest.py
@@ -145,7 +145,7 @@ def __init__(self, request: pytest.FixtureRequest) -> None:
         self._request = request
 
     @property  # This is intentionally not cached and creates a new file on each access
-    def hdf5_file(self) -> h5py.File:  # type: ignore[override]
+    def hdf5_file(self) -> h5py.File:
         import h5py
 
         try:  # If we’re being called in a test or function-scoped fixture, use the test `tmp_path`
diff --git a/tests/conftest.py b/tests/conftest.py
index d389f5c..999b04b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,5 +43,5 @@ def viz(obj: object) -> None:
 
 @pytest.fixture(scope="session", params=COO_PARAMS)
-def coo_matrix_type(request: pytest.FixtureRequest) -> ArrayType[types.COOBase | types.CupyCOOMatrix]:
-    return cast("ArrayType[types.COOBase | types.CupyCOOMatrix]", request.param)
+def coo_matrix_type(request: pytest.FixtureRequest) -> ArrayType[types.COOBase | types.CupyCOOMatrix, None]:
+    return request.param
diff --git a/tests/test_stats.py b/tests/test_stats.py
index 334d1fd..606ae23 100644
--- a/tests/test_stats.py
+++ b/tests/test_stats.py
@@ -26,10 +26,13 @@
     from numpy.typing import NDArray
     from pytest_codspeed import BenchmarkFixture
 
-    from fast_array_utils.stats._typing import Array, DTypeIn, DTypeOut, NdAndAx, StatFunNoDtype
+    from fast_array_utils.stats._typing import DTypeIn, DTypeOut, NdAndAx, StatFunNoDtype
+    from fast_array_utils.types import CSBase, CupyCSMatrix
     from fast_array_utils.typing import CpuArray, DiskArray, GpuArray
     from testing.fast_array_utils import ArrayType
 
+    type ArrayNoSparseDS = CpuArray | GpuArray | DiskArray | types.DaskArray
+
 pytestmark = [pytest.mark.skipif(not find_spec("numba"), reason="numba not installed")]
 
@@ -41,7 +44,7 @@
 ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)}
 
 
-def _xfail_if_old_scipy(array_type: ArrayType[Any], ndim: Literal[1, 2]) -> pytest.MarkDecorator:
+def _xfail_if_old_scipy(array_type: ArrayType[Any, Any], ndim: Literal[1, 2]) -> pytest.MarkDecorator:
     cond = ndim == 1 and bool(array_type.flags & Flags.Sparse) and Version(version("scipy")) < Version("1.14")
     return pytest.mark.xfail(cond, reason="Sparse matrices don’t support 1d arrays")
 
@@ -64,7 +67,7 @@ def ndim(ndim_and_axis: NdAndAx, array_type: ArrayType) -> Literal[1, 2]:
     return check_ndim(array_type, ndim_and_axis[0])
 
 
-def check_ndim(array_type: ArrayType, ndim: Literal[1, 2]) -> Literal[1, 2]:
+def check_ndim(array_type: ArrayType[Any, Any], ndim: Literal[1, 2]) -> Literal[1, 2]:
     inner_cls = array_type.inner.cls if array_type.inner else array_type.cls
     if ndim != 2 and issubclass(inner_cls, types.CSMatrix | types.CupyCSMatrix):
         pytest.skip("CSMatrix only supports 2D")
@@ -94,9 +97,9 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
 
 @pytest.fixture
 def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
-    np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in))
+    np_arr = np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in)
     if np.dtype(dtype_in).kind == "f":
-        np_arr /= 4  # type: ignore[misc]
+        np_arr /= 4
     np_arr.flags.writeable = False
     if ndim == 1:
         np_arr = np_arr.flatten()
@@ -104,23 +107,24 @@
 
 
 def to_np_dense_checked(
-    stat: NDArray[DTypeOut] | np.number[Any] | types.DaskArray, axis: Literal[0, 1] | None, arr: CpuArray | GpuArray | DiskArray | types.DaskArray
+    stat: NDArray[DTypeOut] | np.number[Any] | types.DaskArray[NDArray[DTypeOut]], axis: Literal[0, 1] | None, arr: ArrayNoSparseDS
 ) -> NDArray[DTypeOut] | np.number[Any]:
     match axis, arr:
         case _, types.DaskArray():
             assert isinstance(stat, types.DaskArray), type(stat)
-            stat = stat.compute()  # type: ignore[assignment]
-            return to_np_dense_checked(stat, axis, arr.compute())
+            stat = cast("NDArray[DTypeOut] | np.number[Any]", stat.compute())
+            return to_np_dense_checked(stat, axis, arr.compute())  # ty:ignore[possibly-missing-attribute]: https://github.com/astral-sh/ty/issues/561
         case None, _:
             assert isinstance(stat, np.floating | np.integer), type(stat)
+            return stat
         case 0 | 1, types.CupyArray() | types.CupyCSRMatrix() | types.CupyCSCMatrix() | types.CupyCOOMatrix():
             assert isinstance(stat, types.CupyArray), type(stat)
-            return to_np_dense_checked(stat.get(), axis, arr.get())
+            return to_np_dense_checked(stat.get(), axis, arr.get())  # ty:ignore[possibly-missing-attribute]: https://github.com/astral-sh/ty/issues/561
         case 0 | 1, _:
             assert isinstance(stat, np.ndarray), type(stat)
+            return cast("NDArray[DTypeOut] | np.number[Any]", stat)
         case _:
             pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(stat)}")
-    return stat
 
 
 @pytest.fixture(scope="session")
@@ -142,7 +146,7 @@ def pbmc64k_reduced_raw() -> sps.csr_array[np.float32]:
 @pytest.mark.parametrize("func", STAT_FUNCS)
 @pytest.mark.parametrize(("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"])
 def test_ndim_error(
-    request: pytest.FixtureRequest, array_type: ArrayType[Array], func: StatFunNoDtype, ndim: Literal[1, 2], axis: Literal[0, 1] | None
+    request: pytest.FixtureRequest, array_type: ArrayType[ArrayNoSparseDS], func: StatFunNoDtype, ndim: Literal[1, 2], axis: Literal[0, 1] | None
 ) -> None:
     request.applymarker(_xfail_if_old_scipy(array_type, ndim))
     check_ndim(array_type, ndim)
@@ -159,7 +163,7 @@ def test_ndim_error(
 @pytest.mark.array_type(skip=ATS_SPARSE_DS)
 def test_sum(
     request: pytest.FixtureRequest,
-    array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray],
+    array_type: ArrayType[ArrayNoSparseDS],
     dtype_in: type[DTypeIn],
     dtype_arg: type[DTypeOut] | None,
     axis: Literal[0, 1] | None,
@@ -173,7 +177,7 @@ def test_sum(
     assert arr.dtype == dtype_in
 
     sum_ = stats.sum(arr, axis=axis, dtype=dtype_arg)
-    sum_ = to_np_dense_checked(sum_, axis, arr)  # type: ignore[arg-type]
+    sum_ = to_np_dense_checked(sum_, axis, arr)
 
     assert sum_.shape == () if axis is None else arr.shape[axis], (sum_.shape, arr.shape)
 
@@ -203,7 +207,7 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
 
 @pytest.mark.array_type(skip=ATS_SPARSE_DS)
 @pytest.mark.parametrize("func", [stats.min, stats.max])
-def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None, func: StatFunNoDtype) -> None:
+def test_min_max(array_type: ArrayType[ArrayNoSparseDS], axis: Literal[0, 1] | None, func: StatFunNoDtype) -> None:
     rng = np.random.default_rng(0)
     np_arr = rng.random((100, 100))
     arr = array_type(np_arr)
 
@@ -237,11 +241,13 @@ def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1]
 
 @pytest.mark.array_type(skip=ATS_SPARSE_DS)
-def test_mean(request: pytest.FixtureRequest, array_type: ArrayType[Array], axis: Literal[0, 1] | None, np_arr: NDArray[DTypeIn], ndim: Literal[1, 2]) -> None:
+def test_mean(
+    request: pytest.FixtureRequest, array_type: ArrayType[ArrayNoSparseDS], axis: Literal[0, 1] | None, np_arr: NDArray[DTypeIn], ndim: Literal[1, 2]
+) -> None:
     request.applymarker(_xfail_if_old_scipy(array_type, ndim))
     arr = array_type(np_arr)
 
-    result = stats.mean(arr, axis=axis)  # type: ignore[arg-type]  # https://github.com/python/mypy/issues/16777
+    result = stats.mean(arr, axis=axis)
     if isinstance(result, types.DaskArray):
         result = result.compute()
     if isinstance(result, types.CupyArray | types.CupyCSMatrix):
@@ -254,7 +260,7 @@ def test_mean(
 @pytest.mark.array_type(skip=Flags.Disk)
 def test_mean_var(
     request: pytest.FixtureRequest,
-    array_type: ArrayType[CpuArray | GpuArray | types.DaskArray],
+    array_type: ArrayType[CpuArray | GpuArray, None] | ArrayType[types.DaskArray[CpuArray | GpuArray], CpuArray | GpuArray],
     axis: Literal[0, 1] | None,
     np_arr: NDArray[DTypeIn],
     ndim: Literal[1, 2],
@@ -264,14 +270,16 @@ def test_mean_var(
     mean, var = stats.mean_var(arr, axis=axis, correction=1)
 
     if isinstance(mean, types.DaskArray) and isinstance(var, types.DaskArray):
-        mean, var = mean.compute(), var.compute()  # type: ignore[assignment]
+        mean, var = mean.compute(), var.compute()
     if isinstance(mean, types.CupyArray) and isinstance(var, types.CupyArray):
         mean, var = mean.get(), var.get()
+    assert isinstance(mean, np.ndarray | np.floating)
+    assert isinstance(var, np.ndarray | np.floating)
 
     mean_expected = np.mean(np_arr, axis=axis)
     var_expected = np.var(np_arr, axis=axis, ddof=1)
     np.testing.assert_array_equal(mean, mean_expected)
-    np.testing.assert_array_almost_equal(var, var_expected)  # type: ignore[arg-type]
+    np.testing.assert_array_almost_equal(var, var_expected)  # ty:ignore[invalid-argument-type]: can’t derive element type
 
 
 @pytest.mark.skipif(not find_spec("sklearn"), reason="sklearn not installed")
@@ -312,7 +320,9 @@ def test_mean_var_sparse_32(array_type: ArrayType[types.CSArray], subtests: pyte
 
 @pytest.mark.array_type({at for at in SUPPORTED_TYPES if at.flags & Flags.Sparse and at.flags & Flags.Dask})
-def test_mean_var_pbmc_dask(array_type: ArrayType[types.DaskArray], pbmc64k_reduced_raw: sps.csr_array[np.float32]) -> None:
+def test_mean_var_pbmc_dask(
+    array_type: ArrayType[types.DaskArray[CSBase | CupyCSMatrix], CSBase | CupyCSMatrix], pbmc64k_reduced_raw: sps.csr_array[np.float32]
+) -> None:
     """Test float32 precision for bigger data.
 
     This test is flaky for sparse-in-dask for some reason.
@@ -321,6 +331,7 @@ def test_mean_var_pbmc_dask(
     arr = array_type(mat)
 
     mean_mat, var_mat = stats.mean_var(mat, axis=0, correction=1)
+    # partial reproducer in case it’s not fixed in the next release: https://play.ty.dev/9eb8530f-fadc-4019-863f-ebf3096c0f3c
     mean_arr, var_arr = (to_np_dense_checked(a, 0, arr) for a in stats.mean_var(arr, axis=0, correction=1))
 
     rtol = 1.0e-5 if array_type.flags & Flags.Gpu else 1.0e-7
@@ -364,7 +375,7 @@ def test_is_constant(
 
 @pytest.mark.array_type(Flags.Dask, skip=ATS_CUPY_SPARSE)
-def test_dask_constant_blocks(dask_viz: Callable[[object], None], array_type: ArrayType[types.DaskArray, Any]) -> None:
+def test_dask_constant_blocks(dask_viz: Callable[[object], None], array_type: ArrayType[types.DaskArray]) -> None:
     """Tests if is_constant works if each chunk is individually constant."""
     x_np = np.repeat(np.repeat(np.arange(4, dtype=np.float64).reshape(2, 2), 2, axis=0), 2, axis=1)
     x = array_type(x_np)
@@ -373,7 +384,7 @@ def test_dask_constant_blocks(dask_viz: Callable[[object], None], array_type: Ar
 
     result = stats.is_constant(x, axis=None)
     dask_viz(result)
-    assert result.compute() is False  # type: ignore[comparison-overlap]
+    assert result.compute() is False
 
 
 @pytest.mark.benchmark
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index a78578a..204f96a 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -51,8 +51,8 @@ def test_conv_other(array_type: ArrayType, other_array_type: ArrayType) -> None:
 
 @pytest.mark.array_type(skip=Flags.Dask | Flags.Disk | Flags.Gpu)
 def test_conv_extra(
-    array_type: ArrayType[NDArray[np.number[Any]] | types.CSBase],
-    coo_matrix_type: ArrayType[types.COOBase | types.CupyCOOMatrix],
+    array_type: ArrayType[NDArray[np.number[Any]] | types.CSBase, None],
+    coo_matrix_type: ArrayType[types.COOBase | types.CupyCOOMatrix, None],
 ) -> None:
     src_arr = array_type(np.arange(12).reshape(3, 4), dtype=np.float32)
     arr = coo_matrix_type(src_arr)
diff --git a/tests/test_to_dense.py b/tests/test_to_dense.py
index 1d75969..a98926b 100644
--- a/tests/test_to_dense.py
+++ b/tests/test_to_dense.py
@@ -14,12 +14,12 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable
-    from typing import Literal
+    from typing import Literal, TypeAlias
 
     from fast_array_utils.typing import CpuArray, DiskArray, GpuArray
     from testing.fast_array_utils import ArrayType
 
-    type Array = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray
+    Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray  # noqa: UP040 https://github.com/astral-sh/ty/issues/2488
     type ExtendedArray = Array | types.COOBase | types.CupyCOOMatrix
 
@@ -52,7 +52,7 @@ def test_to_dense(array_type: ArrayType[Array], *, order: Literal["K", "C", "F"]
 
 @pytest.mark.parametrize("to_cpu_memory", [True, False], ids=["to_cpu_memory", "not_to_cpu_memory"])
 @pytest.mark.parametrize("order", argvalues=["K", "C", "F"])  # “A” behaves like “K”
-def test_to_dense_extra(coo_matrix_type: ArrayType[types.COOBase | types.CupyCOOMatrix], *, order: Literal["K", "C", "F"], to_cpu_memory: bool) -> None:
+def test_to_dense_extra(coo_matrix_type: ArrayType[types.COOBase | types.CupyCOOMatrix, None], *, order: Literal["K", "C", "F"], to_cpu_memory: bool) -> None:
     src_mtx = coo_matrix_type([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
 
     with WARNS_NUMBA if not find_spec("numba") else nullcontext():
@@ -63,21 +63,21 @@ def test_to_dense_extra(
     assert_expected_order(src_mtx, arr, order=order)
 
 
-def assert_expected_cls(orig: ExtendedArray, converted: Array, *, to_cpu_memory: bool) -> None:
-    match (to_cpu_memory, orig):
+def assert_expected_cls(orig: ExtendedArray | np.number, converted: Array | np.number, *, to_cpu_memory: bool) -> None:
+    match to_cpu_memory, orig:
         case False, types.DaskArray():
             assert isinstance(converted, types.DaskArray)
-            assert_expected_cls(orig.compute(), converted.compute(), to_cpu_memory=to_cpu_memory)
+            assert_expected_cls(orig.compute(), converted.compute(), to_cpu_memory=to_cpu_memory)  # ty:ignore[possibly-missing-attribute]: https://github.com/astral-sh/ty/issues/561
         case False, types.CupyArray() | types.CupySpMatrix():
             assert isinstance(converted, types.CupyArray)
         case _:
             assert isinstance(converted, np.ndarray)
 
 
-def assert_expected_order(orig: ExtendedArray, converted: Array, *, order: Literal["K", "C", "F"]) -> None:
+def assert_expected_order(orig: ExtendedArray | np.number, converted: Array | np.number, *, order: Literal["K", "C", "F"]) -> None:
     match converted:
-        case types.CupyArray() | np.ndarray():
-            orders = {order_exp: converted.flags[f"{order_exp}_CONTIGUOUS"] for order_exp in (get_orders(orig) if order == "K" else {order})}  # type: ignore[index]
+        case types.CupyArray() | np.ndarray() | np.number():
+            orders = {order_exp: converted.flags[order_exp] for order_exp in (get_orders(orig) if order == "K" else (order,))}
             assert any(orders.values()), orders
         case types.DaskArray():
             assert_expected_order(orig, converted.compute(), order=order)
@@ -85,14 +85,14 @@ def assert_expected_order(
             pytest.fail(f"Unsupported array type: {type(converted)}")
 
 
-def get_orders(orig: ExtendedArray) -> Iterable[Literal["C", "F"]]:
+def get_orders(orig: ExtendedArray | np.number) -> Iterable[Literal["C", "F"]]:
     """Get the orders of an array.
 
     Numpy arrays with at most one axis of a length >1 are valid in both orders.
     So are COO sparse matrices/arrays.
     """
     match orig:
-        case np.ndarray() | types.CupyArray():
+        case np.ndarray() | types.CupyArray() | np.number():
             if orig.flags.c_contiguous:
                 yield "C"
             if orig.flags.f_contiguous:
diff --git a/typings/cupy/_core/core.pyi b/typings/cupy/_core/core.pyi
index 2995181..04827f4 100644
--- a/typings/cupy/_core/core.pyi
+++ b/typings/cupy/_core/core.pyi
@@ -34,7 +34,7 @@ class ndarray:
     def __getitem__(  # never returns scalars
         self, index: int | slice | EllipsisType | tuple[int | slice | EllipsisType | None, ...]
     ) -> Self: ...
-    def __eq__(self, value: object) -> ndarray: ...  # type: ignore[override]
+    def __eq__(self, value: object) -> ndarray: ...  # ty:ignore[invalid-method-override]
     def __power__(self, other: int) -> Self: ...
 
     # methods
diff --git a/typings/dask/array/core.pyi b/typings/dask/array/core.pyi
index 7f65dd2..ca2b43d 100644
--- a/typings/dask/array/core.pyi
+++ b/typings/dask/array/core.pyi
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: MPL-2.0
 # pyright: reportIncompatibleMethodOverride=false
 from collections.abc import Callable, Sequence
-from typing import Any, Literal, Never, override
+from typing import Any, Concatenate, Generic, Literal, TypeVar
 
 import cupy
 import cupyx.scipy.sparse
@@ -12,8 +12,9 @@ from numpy.typing import DTypeLike, NDArray
 from ..utils import SerializableLock
 
 type _Chunks = tuple[int, ...] | tuple[tuple[int, ...], ...]
-type _Array = (
+type _Chunk = (
     NDArray[Any]
+    | np.number[Any]
     | scipy.sparse.csr_array
     | scipy.sparse.csc_array
     | scipy.sparse.csr_matrix
@@ -23,30 +24,32 @@
     | cupyx.scipy.sparse.csc_matrix
 )
 
-class BlockView:
+# https://github.com/astral-sh/ty/issues/2104
+C = TypeVar("C", bound=_Chunk, default=_Chunk)  # noqa: PYI001
+
+class BlockView(Generic[C]):
     size: int
     shape: tuple[int, ...]
-    def __getitem__(self, index: object) -> Array: ...
-    def ravel(self) -> list[Array]: ...
+    def __getitem__(self, index: object) -> Array[C]: ...
+    def ravel(self) -> list[Array[C]]: ...
 
-class Array:
+class Array(Generic[C]):
     # array methods and attrs
     ndim: int
     shape: tuple[int, ...]
     dtype: np.dtype[Any]
-    @override
-    def __eq__(self, value: object, /) -> Array: ...  # type: ignore[override]
+    def __eq__(self, value: object, /) -> Array: ...  # ty:ignore[invalid-method-override]
     def __getitem__(self, index: object) -> Array: ...
     def all(self) -> Array: ...
 
     # dask methods and attrs
-    _meta: _Array
-    blocks: BlockView
+    _meta: C
+    blocks: BlockView[C]
     chunks: tuple[tuple[int, ...], ...]
     chunksize: tuple[int, ...]
-    def compute(self) -> _Array: ...
+    def compute(self) -> C: ...
     def visualize(
         self,
         filename: str = "mydask",
@@ -60,11 +63,10 @@
         verbose: bool = False,
         engine: str = "ipycytoscape",
     ) -> object: ...
-    def map_blocks(
+    def map_blocks[C2: _Chunk, **P](
         self,
-        # TODO(flying-sheep): make this generic, _Array the default  # noqa: TD003
-        func: Callable[[object], object],
-        *args: Never,
+        func: Callable[Concatenate[C, P], C2],
+        *args: P.args,
         name: str | None = None,
         token: str | None = None,
         dtype: DTypeLike | None = None,
@@ -72,24 +74,23 @@
         drop_axis: Sequence[int] | int | None = None,
         new_axis: Sequence[int] | int | None = None,
         enforce_ndim: bool = False,
-        meta: _Array | None = None,
-        **kwargs: object,
-    ) -> Array: ...
+        meta: C2 | None = None,
+        **kwargs: P.kwargs,
+    ) -> Array[C2]: ...
 
-def from_array(
-    x: _Array,
+def from_array[C: _Chunk](
+    x: C,
     chunks: _Chunks | str | Literal["auto"] = "auto",  # noqa: PYI051
     name: str | None = None,
     lock: bool | SerializableLock = False,
     asarray: bool | None = None,
     fancy: bool = True,
     getitem: object = None,  # undocumented
-    meta: _Array | None = None,
+    meta: C | None = None,
     inline_array: bool = False,
-) -> Array: ...
-def map_blocks(
-    # TODO(flying-sheep): make this generic, _Array the default  # noqa: TD003
-    func: Callable[[object], object],
+) -> Array[C]: ...
+def map_blocks[C: _Chunk](
+    func: Callable[[object], C],
     *args: Array,
     name: str | None = None,
     token: str | None = None,
@@ -100,4 +101,4 @@
     enforce_ndim: bool = False,
     meta: object | None = None,
     **kwargs: object,
-) -> Array: ...
+) -> Array[C]: ...
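The stub rewrite above makes `Array` generic over its chunk type `C`, replacing the old catch-all `_Array` union (renamed to `_Chunk` and reused as the TypeVar bound and default). A minimal sketch of how the annotations are meant to flow under these stubs; this snippet is not part of the diff, and the inferred types in the comments assume ty applies the `default=_Chunk`:

```python
from typing import Any

import dask.array as da
import numpy as np
from numpy.typing import NDArray

# from_array binds C to the chunk type of its input …
x = da.from_array(np.zeros((4, 4)), chunks=(2, 2))  # Array[NDArray[Any]]

# … so compute() returns that chunk type instead of the old _Array union:
dense: NDArray[Any] = x.compute()

# map_blocks threads the callback through Callable[Concatenate[C, P], C2],
# so the result is Array[C2] with C2 inferred from the callback's return:
halved = x.map_blocks(lambda chunk: chunk / 2)
dense2: NDArray[Any] = halved.compute()
```

This is what lets the library code above drop its `type: ignore[type-var,arg-type]` comments around `map_blocks` calls.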
diff --git a/typings/dask/array/overlap.pyi b/typings/dask/array/overlap.pyi
index ce1a65a..ff65826 100644
--- a/typings/dask/array/overlap.pyi
+++ b/typings/dask/array/overlap.pyi
@@ -2,14 +2,14 @@
 from collections.abc import Callable
 from typing import Literal
 
-from .core import Array, _Array
+from .core import Array, _Chunk
 
 type _Depth = int | tuple[int, ...] | dict[int, _Depth]
 type _Boundary = Literal["reflect", "periodic", "nearest", "none"] | int
 type _Boundaries = _Boundary | tuple[_Boundary, ...] | dict[int, _Boundary]
 
 def map_overlap(
-    func: Callable[[_Array], _Array],
+    func: Callable[[_Chunk], _Chunk],
     *args: Array,
     depth: _Depth | list[_Depth] | None = None,
     boundary: _Boundaries | list[_Boundaries] | None = None,
diff --git a/typings/dask/array/reductions.pyi b/typings/dask/array/reductions.pyi
index 0007725..29697ca 100644
--- a/typings/dask/array/reductions.pyi
+++ b/typings/dask/array/reductions.pyi
@@ -5,16 +5,16 @@
 from typing import Any, Protocol, overload
 
 from numpy.typing import ArrayLike, DTypeLike, NDArray
 
-from .core import Array, _Array
+from .core import Array, _Chunk as _ChunkT  # aliased: the callback protocol below is also named `_Chunk`
 
 class _Chunk(Protocol):
     @overload
-    def __call__(self, x_chunk: _Array, /, *, weights_chunk: NDArray[Any] | None = None, axis: tuple[int, ...], keepdims: bool, **kwargs: object) -> _Array: ...
+    def __call__(self, x_chunk: _ChunkT, /, *, weights_chunk: NDArray[Any] | None = None, axis: tuple[int, ...], keepdims: bool, **kwargs: object) -> _ChunkT: ...
     @overload
-    def __call__(self, x_chunk: _Array, /, *, axis: tuple[int, ...], keepdims: bool, **kwargs: object) -> _Array: ...
+    def __call__(self, x_chunk: _ChunkT, /, *, axis: tuple[int, ...], keepdims: bool, **kwargs: object) -> _ChunkT: ...
 
 class _CB(Protocol):
-    def __call__(self, x_chunk: _Array, /, *, axis: tuple[int, ...], keepdims: bool, **kwargs: object) -> _Array: ...
+    def __call__(self, x_chunk: _ChunkT, /, *, axis: tuple[int, ...], keepdims: bool, **kwargs: object) -> _ChunkT: ...
 
 def reduction(
     x: Array,
@@ -30,6 +30,6 @@ def reduction(
     out: Array | None = None,
     concatenate: bool = True,
     output_size: int = 1,
-    meta: _Array | None = None,
+    meta: _ChunkT | None = None,
     weights: ArrayLike | None = None,
 ) -> Array: ...
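Note that importing the renamed `_Chunk` unaliased would collide with the local `_Chunk` callback protocol, hence the `_ChunkT` alias above. For reference, a small usage sketch matching the stubbed `reduction` signature: the chunk and aggregate callbacks take the block positionally plus keyword-only `axis`/`keepdims`, exactly the shape the callback protocols describe. `chunk_sum` is hypothetical, not from the diff:

```python
from typing import Any

import dask.array as da
import numpy as np
from numpy.typing import NDArray


def chunk_sum(x_chunk: NDArray[Any], /, *, axis: tuple[int, ...], keepdims: bool, **kwargs: object) -> NDArray[Any]:
    # Per-block reduction; dask may pass extras like computing_meta via **kwargs.
    return np.sum(x_chunk, axis=axis, keepdims=keepdims)


x = da.from_array(np.arange(16.0).reshape(4, 4), chunks=(2, 2))
total = da.reduction(x, chunk=chunk_sum, aggregate=chunk_sum, dtype=np.float64)
print(total.compute())  # 120.0
```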