Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,5 @@ repos:
rev: v1.19.1
hooks:
- id: mypy
additional_dependencies: ["types-attrs"]
files: ^src/|^tests/
args: ["--config-file", "pyproject.toml"]
4 changes: 2 additions & 2 deletions src/mritk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
mixed,
r1,
segmentation,
stats,
statistics,
utils,
)

Expand All @@ -34,5 +34,5 @@
"mixed",
"hybrid",
"r1",
"stats",
"statistics",
]
6 changes: 3 additions & 3 deletions src/mritk/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rich.logging import RichHandler
from rich_argparse import RichHelpFormatter

from . import concentration, datasets, hybrid, info, looklocker, mixed, napari, r1, show, stats
from . import concentration, datasets, hybrid, info, looklocker, mixed, napari, r1, show, statistics


def version_info():
Expand Down Expand Up @@ -67,7 +67,7 @@ def setup_parser():
info_parser.add_argument("--json", action="store_true", help="Output information in JSON format")

stats_parser = subparsers.add_parser("stats", help="Compute MRI statistics", formatter_class=parser.formatter_class)
stats.add_arguments(stats_parser)
statistics.cli.add_arguments(stats_parser)

show_parser = subparsers.add_parser("show", help="Show MRI data in a terminal", formatter_class=parser.formatter_class)
show.add_arguments(show_parser)
Expand Down Expand Up @@ -132,7 +132,7 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No
file = args.pop("file")
info.nifty_info(file, json_output=args.pop("json"))
elif command == "stats":
stats.dispatch(args)
statistics.cli.dispatch(args)
elif command == "show":
show.dispatch(args)
elif command == "napari":
Expand Down
147 changes: 95 additions & 52 deletions src/mritk/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# Copyright (C) 2026 Simula Research Laboratory


import re
from pathlib import Path
from typing import Optional

Expand All @@ -14,6 +13,76 @@
import numpy.typing as npt


def load_mri_data(path: Path | str, dtype: npt.DTypeLike | None = None, orient: bool = True) -> tuple[np.ndarray, np.ndarray]:
    """Load MRI data from a file and return the data array and affine matrix.

    Args:
        path: Path to the MRI file ('.nii', '.nii.gz', '.mgz' or '.mgh').
        dtype: Optional data type for the returned array. When None, the
            dtype produced by nibabel's ``get_fdata`` is kept unchanged.
        orient: Whether to reorient the data to the canonical orientation.

    Returns:
        Tuple of (data, affine) arrays.

    Raises:
        ValueError: If the file suffix is not one of the supported formats.
        RuntimeError: If the loaded image carries no affine matrix.
    """
    filepath = Path(path)
    suffix = check_suffix(filepath)
    if suffix in (".nii", ".nii.gz"):
        mri = nibabel.nifti1.load(filepath)
    elif suffix in (".mgz", ".mgh"):
        mri = nibabel.freesurfer.mghformat.load(filepath)
    else:
        # List every accepted suffix, not just two of them.
        raise ValueError(f"Invalid suffix '{suffix}' in {filepath}, should be one of '.nii', '.nii.gz', '.mgz' or '.mgh'")

    affine = mri.affine
    if affine is None:
        raise RuntimeError("MRI does not contain an affine")

    # "unchanged" preserves nibabel's caching/dtype behaviour; only cast when
    # an explicit dtype was requested.
    if dtype is not None:
        data = np.asarray(mri.get_fdata("unchanged"), dtype=dtype)
    else:
        data = np.asarray(mri.get_fdata("unchanged"))

    if orient:
        data, affine = data_reorientation(data, affine)

    return data, affine


def save_mri_data(data: np.ndarray, affine: np.ndarray, path: Path | str, intent_code: Optional[int] = None):
    """Save MRI data to a file.

    Args:
        data: The MRI data array to save.
        affine: The affine transformation matrix associated with the data.
        path: Destination path ('.nii', '.nii.gz', '.mgz' or '.mgh').
        intent_code: Optional intent code to store in the image header.

    Raises:
        ValueError: If the file suffix is not one of the supported formats.
    """
    save_path = Path(path)
    suffix = check_suffix(save_path)
    if suffix in (".nii", ".nii.gz"):
        image = nibabel.nifti1.Nifti1Image(data, affine)
        save_module = nibabel.nifti1
    elif suffix in (".mgz", ".mgh"):
        image = nibabel.freesurfer.mghformat.MGHImage(data, affine)
        save_module = nibabel.freesurfer.mghformat
    else:
        # List every accepted suffix, not just two of them.
        raise ValueError(f"Invalid suffix '{suffix}' in {save_path}, should be one of '.nii', '.nii.gz', '.mgz' or '.mgh'")

    # Intent handling is identical for both formats; do it once.
    if intent_code is not None:
        image.header.set_intent(intent_code)
    save_module.save(image, save_path)


def check_suffix(filepath: Path) -> str:
    """Return the file-format suffix of *filepath*.

    A gzipped NIfTI file is reported as the compound suffix '.nii.gz';
    any other path simply yields its final suffix (possibly empty).

    Args:
        filepath: Path whose format suffix should be determined.

    Returns:
        '.nii.gz' for paths ending in '.nii.gz', otherwise ``filepath.suffix``.
    """
    parts = filepath.suffixes
    # Only '.nii.gz' counts as a compound suffix; a bare '.gz' stays '.gz'.
    if len(parts) >= 2 and parts[-1] == ".gz" and parts[-2] == ".nii":
        return ".nii.gz"
    return filepath.suffix


class MRIData:
def __init__(self, data: np.ndarray, affine: np.ndarray):
self.data = data
Expand All @@ -23,52 +92,29 @@ def __init__(self, data: np.ndarray, affine: np.ndarray):
def shape(self) -> tuple[int, ...]:
return self.data.shape

def get_data(self):
    """Return the underlying image data array (``self.data``)."""
    return self.data

def get_metadata(self):
    """Return the affine matrix associated with the data (``self.affine``)."""
    return self.affine

@property
def voxel_ml_volume(self) -> float:
    """Volume of a single voxel in milliliters.

    The absolute determinant of the 3x3 linear part of the affine gives
    the voxel volume in mm^3; dividing by 1000 converts mm^3 to ml.
    """
    linear_part = self.affine[:3, :3]
    mm3_per_voxel = abs(np.linalg.det(linear_part))
    return mm3_per_voxel / 1000.0

@classmethod
def from_file(cls, path: Path | str, dtype: npt.DTypeLike | None = None, orient: bool = True) -> "MRIData":
    """Construct an MRIData object from an MRI file on disk.

    Args:
        path: Path to the MRI file ('.nii', '.nii.gz', '.mgz' or '.mgh').
        dtype: Optional data type for the loaded array.
        orient: Whether to reorient the data to the canonical orientation.

    Returns:
        A new MRIData instance holding the loaded data and affine.
    """
    # Delegate parsing/validation to the module-level loader so the logic
    # is shared with callers that do not need an MRIData wrapper.
    data, affine = load_mri_data(path, dtype=dtype, orient=orient)
    return cls(data=data, affine=affine)

def save(self, path: Path | str, dtype: npt.DTypeLike | None = None, intent_code: Optional[int] = None):
    """Save the MRI data to a file.

    Args:
        path: Destination path ('.nii', '.nii.gz', '.mgz' or '.mgh').
        dtype: Optional data type to cast the array to before saving.
            Defaults to the array's current dtype.
        intent_code: Optional intent code to store in the image header.
    """
    if dtype is None:
        dtype = self.data.dtype
    data = self.data.astype(dtype)
    # Delegate format detection and writing to the module-level saver.
    save_mri_data(data, self.affine, path, intent_code=intent_code)


def physical_to_voxel_indices(physical_coordinates: np.ndarray, affine: np.ndarray, round_coords: bool = True) -> np.ndarray:
Expand Down Expand Up @@ -147,7 +193,7 @@ def apply_affine(T: np.ndarray, X: np.ndarray) -> np.ndarray:
return A.dot(X.T).T + b


def data_reorientation(mri_data: MRIData) -> MRIData:
def data_reorientation(data: np.ndarray, affine: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Reorient the data array and affine matrix to the canonical orientation.

This function adjusts the data array layout (via transpositions and flips)
Expand All @@ -159,15 +205,16 @@ def data_reorientation(mri_data: MRIData) -> MRIData:
The physical coordinate system remains unchanged (e.g., RAS stays RAS).

Args:
mri_data: The input MRI data object containing the data array and affine.
data: The input MRI data array.
affine: The input affine transformation matrix.

Returns:
A new MRIData object with reoriented data and updated affine matrix.
A tuple containing the reoriented data array and updated affine matrix.
"""
A = mri_data.affine[:3, :3]
A = affine[:3, :3]
flips = np.sign(A[np.argmax(np.abs(A), axis=0), np.arange(3)]).astype(int)
permutes = np.argmax(np.abs(A), axis=0)
offsets = ((1 - flips) // 2) * (np.array(mri_data.data.shape[:3]) - 1)
offsets = ((1 - flips) // 2) * (np.array(data.shape[:3]) - 1)

# Index flip matrix
F = np.eye(4, dtype=int)
Expand All @@ -176,14 +223,10 @@ def data_reorientation(mri_data: MRIData) -> MRIData:

# Index permutation matrix
P = np.eye(4, dtype=int)[[*permutes, 3]]
affine = mri_data.affine @ F @ P
affine = affine @ F @ P
inverse_permutes = np.argmax(P[:3, :3].T, axis=1)
data = (
mri_data.data[:: flips[0], :: flips[1], :: flips[2], ...]
.transpose([*inverse_permutes, *list(range(3, mri_data.data.ndim))])
.copy()
)
return MRIData(data, affine)
data = data[:: flips[0], :: flips[1], :: flips[2], ...].transpose([*inverse_permutes, *list(range(3, data.ndim))]).copy()
return data, affine


def change_of_coordinates_map(orientation_in: str, orientation_out: str) -> np.ndarray:
Expand Down
5 changes: 2 additions & 3 deletions src/mritk/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ def extract_single_volume(D: np.ndarray, frame_fg) -> MRIData:

A_dcm = dicom_standard_affine(frame_fg)
C = change_of_coordinates_map("LPS", "RAS")
mri = data_reorientation(MRIData(volume, C @ A_dcm))

return mri
data, affine = data_reorientation(volume, C @ A_dcm)
return MRIData(data=data, affine=affine)


def mixed_t1map(
Expand Down
72 changes: 72 additions & 0 deletions src/mritk/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
from pathlib import Path
from urllib.request import urlretrieve

import numpy as np
import numpy.typing as npt
import pandas as pd

from .data import MRIData, load_mri_data

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -75,6 +79,74 @@
}


class Segmentation(MRIData):
    """A labelled MRI volume: integer label data plus an optional lookup table (LUT)."""

    def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | None = None):
        """Initialize the segmentation.

        Args:
            data: Label volume; cast to int after the base-class init.
            affine: Affine transformation matrix for the volume.
            lut: Optional lookup table indexed by label value. When omitted,
                a minimal table with a 'Label' column equal to the ROI values
                themselves is built.
        """
        super().__init__(data, affine)
        self.data = self.data.astype(int)
        # Distinct ROI values present in the volume; background (<= 0) excluded.
        self.rois = np.unique(self.data[self.data > 0])
        if lut is not None:
            self.lut = lut
        else:
            self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois)

        # Column holding human-readable names; fall back to the LUT's first
        # column when it has no 'Label' column.
        self._label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0]

    @property
    def num_rois(self) -> int:
        """Number of distinct (positive) ROI labels in the volume."""
        return len(self.rois)

    @property
    def roi_labels(self) -> np.ndarray:
        """Array of distinct (positive) ROI label values."""
        return self.rois

    def get_roi_labels(self, rois: npt.NDArray[np.int_] | None = None) -> pd.DataFrame:
        """Return a DataFrame with columns 'ROI' and the label-name column.

        Args:
            rois: ROI values to look up; defaults to all ROIs in the volume.

        Raises:
            ValueError: If any requested ROI is absent from the segmentation.
        """
        if rois is None:
            rois = self.rois

        if not np.isin(rois, self.rois).all():
            raise ValueError("Some of the provided ROIs are not present in the segmentation.")

        # Restrict the LUT to the requested labels; the LUT index carries the
        # ROI value and becomes the 'ROI' column after reset_index().
        return self.lut.loc[self.lut.index.isin(rois), [self._label_name]].rename_axis("ROI").reset_index()

    @classmethod
    def from_file(
        cls, filepath: Path | str, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None
    ) -> "Segmentation":
        """Load a segmentation volume and its lookup table from disk.

        Args:
            filepath: Path to the segmentation image.
            dtype: Optional dtype for the loaded array.
            orient: Whether to reorient the data to the canonical orientation.
            lut_path: Optional path to a lookup table; resolution and parsing
                are delegated to ``resolve_lut_path``/``read_lut`` (defined
                elsewhere in this module).
        """
        resolved_lut_path = resolve_lut_path(lut_path)
        lut = read_lut(resolved_lut_path)
        data, affine = load_mri_data(filepath, dtype=dtype, orient=orient)
        return cls(data=data, affine=affine, lut=lut)


class FreeSurferSegmentation(Segmentation):
    """Segmentation subclass for FreeSurfer outputs; adds no behavior of its own."""

    ...


class ExtendedFreeSurferSegmentation(FreeSurferSegmentation):
    """FreeSurfer segmentation with extended labels that also encode a tissue type.

    The code below treats ``label % 10000`` as the underlying FreeSurfer ROI
    and maps the magnitude ranges to tissue types (see get_tissue_type).
    """

    def get_roi_labels(self, rois: npt.NDArray[np.int_] | None = None) -> pd.DataFrame:
        """Return a DataFrame with 'ROI', the label-name column, and 'tissue_type'.

        Args:
            rois: ROI values to look up; defaults to all ROIs in the volume.
        """
        rois = self.rois if rois is None else rois

        # Resolve names for the underlying FreeSurfer labels (label % 10000).
        freesurfer_labels = super().get_roi_labels(rois % 10000).rename(columns={"ROI": "FreeSurfer_ROI"})

        # Outer merge keeps every extended ROI even when several share the
        # same underlying FreeSurfer label; the helper column is dropped after.
        tissue_type = self.get_tissue_type(rois)
        return freesurfer_labels.merge(
            tissue_type,
            left_on="FreeSurfer_ROI",
            right_on="FreeSurfer_ROI",
            how="outer",
        ).drop(columns=["FreeSurfer_ROI"])[["ROI", self._label_name, "tissue_type"]]

    def get_tissue_type(self, rois: npt.NDArray[np.int_] | None = None) -> pd.DataFrame:
        """Return a DataFrame mapping each ROI to its tissue type.

        Ranges used: < 10000 -> 'Parenchyma'; 10000-19999 -> 'CSF';
        >= 20000 -> 'Dura'.

        Args:
            rois: ROI values to classify; defaults to all ROIs in the volume.
        """
        rois = self.rois if rois is None else rois
        tissue_types = pd.Series(
            data=np.where(rois < 10000, "Parenchyma", np.where(rois < 20000, "CSF", "Dura")),
            index=rois,
            name="tissue_type",
        )
        ret = pd.DataFrame(tissue_types, columns=["tissue_type"]).rename_axis("ROI").reset_index()
        # Recover the underlying FreeSurfer label so callers can join on it.
        ret["FreeSurfer_ROI"] = ret["ROI"] % 10000
        return ret


def default_segmentation_groups() -> dict[str, list[int]]:
"""
Returns the default grouping of FreeSurfer labels into brain regions.
Expand Down
7 changes: 7 additions & 0 deletions src/mritk/statistics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2026 Jørgen Riseth (jnriseth@gmail.com)
# Copyright (C) 2026 Cécile Daversin-Catty (cecile@simula.no)
# Copyright (C) 2026 Simula Research Laboratory

from . import cli, compute_stats, stat_functions, utils

__all__ = ["utils", "compute_stats", "cli", "stat_functions"]
Loading
Loading