diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..a34a1b6d7 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -233,6 +233,7 @@ Available Datasets datasets/pyhealth.datasets.DREAMTDataset datasets/pyhealth.datasets.SHHSDataset datasets/pyhealth.datasets.SleepEDFDataset + datasets/pyhealth.datasets.IBISleepDataset datasets/pyhealth.datasets.EHRShotDataset datasets/pyhealth.datasets.Support2Dataset datasets/pyhealth.datasets.BMDHSDataset diff --git a/docs/api/datasets/pyhealth.datasets.IBISleepDataset.rst b/docs/api/datasets/pyhealth.datasets.IBISleepDataset.rst new file mode 100644 index 000000000..01cf79013 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.IBISleepDataset.rst @@ -0,0 +1,15 @@ +pyhealth.datasets.IBISleepDataset +================================= + +Dataset for IBI-based sleep staging from DREAMT, SHHS, and MESA recordings. +Each subject's overnight recording is stored as a pre-processed NPZ file containing +a 25 Hz inter-beat-interval (IBI) time series and per-sample sleep stage labels. + +See ``examples/preprocess_dreamt_to_ibi.py``, ``examples/preprocess_shhs_to_ibi.py``, and +``examples/preprocess_mesa_to_ibi.py`` for scripts that convert raw EDF recordings to the +NPZ format expected by this dataset. + +.. 
autoclass:: pyhealth.datasets.IBISleepDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..025972183 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -189,6 +189,7 @@ API Reference models/pyhealth.models.JambaEHR models/pyhealth.models.ContraWR models/pyhealth.models.SparcNet + models/pyhealth.models.WatchSleepNet models/pyhealth.models.StageNet models/pyhealth.models.StageAttentionNet models/pyhealth.models.AdaCare diff --git a/docs/api/models/pyhealth.models.WatchSleepNet.rst b/docs/api/models/pyhealth.models.WatchSleepNet.rst new file mode 100644 index 000000000..3a0ee83c2 --- /dev/null +++ b/docs/api/models/pyhealth.models.WatchSleepNet.rst @@ -0,0 +1,15 @@ +pyhealth.models.WatchSleepNet +============================= + +WatchSleepNet: a ResNet → TCN → BiLSTM → Attention architecture for +IBI-based sleep staging from consumer wearable devices. + +.. autoclass:: pyhealth.models.watchsleepnet.ResidualBlock + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.WatchSleepNet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..b2ac27cd8 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -222,6 +222,7 @@ Available Tasks Sleep Staging (SleepEDF) Temple University EEG Tasks Sleep Staging v2 + Sleep Staging IBI Benchmark EHRShot ChestX-ray14 Binary Classification De-Identification NER diff --git a/docs/api/tasks/pyhealth.tasks.SleepStagingIBI.rst b/docs/api/tasks/pyhealth.tasks.SleepStagingIBI.rst new file mode 100644 index 000000000..38949ff04 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.SleepStagingIBI.rst @@ -0,0 +1,10 @@ +pyhealth.tasks.SleepStagingIBI +============================== + +Sleep staging task for IBI-based recordings (DREAMT, SHHS, MESA). +Supports 3-class (W / NREM / REM) and 5-class (W / N1 / N2 / N3 / REM) modes. + +.. 
autoclass:: pyhealth.tasks.sleep_staging_ibi.SleepStagingIBI + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/ibi_sleep_staging_ibi_watchsleepnet.py b/examples/ibi_sleep_staging_ibi_watchsleepnet.py new file mode 100644 index 000000000..9872edbc3 --- /dev/null +++ b/examples/ibi_sleep_staging_ibi_watchsleepnet.py @@ -0,0 +1,173 @@ +"""ibi_sleep_staging_ibi_watchsleepnet.py — WatchSleepNet two-phase transfer learning. + +Demonstrates pre-training on clinical IBI data (SHHS/MESA) then fine-tuning on +wearable IBI data (DREAMT), following the WatchSleepNet paper methodology. + +Quick start with synthetic data: + + python examples/ibi_sleep_staging_ibi_watchsleepnet.py --synthetic + +With real preprocessed data: + + python examples/ibi_sleep_staging_ibi_watchsleepnet.py \\ + --clinical_root ~/watchsleepnet_data/shhs \\ + --wearable_root ~/watchsleepnet_data/dreamt + +To produce the NPZ directories from raw recordings, run the preprocessing scripts: + + python examples/preprocess_shhs_to_ibi.py \\ + --src_dir /data/shhs/polysomnography \\ + --dst_dir ~/watchsleepnet_data/shhs \\ + --harmonized_csv /data/shhs/shhs-harmonized-dataset.csv + python examples/preprocess_mesa_to_ibi.py \\ + --src_dir /data/mesa/polysomnography \\ + --dst_dir ~/watchsleepnet_data/mesa \\ + --harmonized_csv /data/mesa/mesa-sleep-harmonized-dataset.csv + python examples/preprocess_dreamt_to_ibi.py \\ + --src_dir /data/dreamt/raw \\ + --dst_dir ~/watchsleepnet_data/dreamt \\ + --participant_info /data/dreamt/participant_info.csv +""" +from __future__ import annotations + +import argparse +import os +import tempfile + +import numpy as np +import torch + +from pyhealth.datasets import IBISleepDataset, get_dataloader +from pyhealth.datasets.splitter import split_by_patient +from pyhealth.models import WatchSleepNet +from pyhealth.tasks import SleepStagingIBI +from pyhealth.trainer import Trainer + + +def _make_synthetic_data( + root: str, + n_subjects: int, + 
epochs_per_subject: int, + seed: int = 0, +) -> None: + rng = np.random.default_rng(seed) + os.makedirs(root, exist_ok=True) + n = epochs_per_subject * 750 + for i in range(n_subjects): + np.savez( + os.path.join(root, f"S{i:04d}.npz"), + data=rng.random(n).astype(np.float32) * 0.5 + 0.6, + stages=rng.integers(0, 5, size=n).astype(np.int32), + fs=np.int64(25), + ahi=np.float32(rng.uniform(0.0, 30.0)), + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="WatchSleepNet two-phase transfer learning" + ) + parser.add_argument( + "--clinical_root", + help="Directory of preprocessed clinical NPZ files (SHHS/MESA)", + ) + parser.add_argument( + "--wearable_root", + help="Directory of preprocessed wearable NPZ files (DREAMT)", + ) + parser.add_argument( + "--synthetic", + action="store_true", + help="Generate synthetic data and run on it", + ) + parser.add_argument( + "--pretrain_epochs", + type=int, + default=30, + help="Epochs for phase 1 clinical pre-training", + ) + parser.add_argument( + "--finetune_epochs", + type=int, + default=30, + help="Epochs for phase 2 wearable fine-tuning", + ) + parser.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + ) + args = parser.parse_args() + + if args.synthetic: + _tmpdir = tempfile.mkdtemp(prefix="watchsleepnet_") + clinical_root = os.path.join(_tmpdir, "clinical") + wearable_root = os.path.join(_tmpdir, "wearable") + print(f"Generating synthetic data in {_tmpdir}") + _make_synthetic_data(clinical_root, n_subjects=40, epochs_per_subject=20) + _make_synthetic_data( + wearable_root, n_subjects=20, epochs_per_subject=15, seed=1 + ) + else: + if not args.clinical_root or not args.wearable_root: + parser.error( + "--clinical_root and --wearable_root are required (or use --synthetic)" + ) + clinical_root = os.path.expanduser(args.clinical_root) + wearable_root = os.path.expanduser(args.wearable_root) + for path, name in [ + (clinical_root, "--clinical_root"), + 
(wearable_root, "--wearable_root"), + ]: + if not os.path.isdir(path): + parser.error(f"{name}: directory not found: {path}") + + # step 1: load datasets + clinical_ds = IBISleepDataset(root=clinical_root, source="shhs") + wearable_ds = IBISleepDataset(root=wearable_root, source="dreamt") + + # step 2: set task + clinical_samples = clinical_ds.set_task(SleepStagingIBI(num_classes=5)) + wearable_samples = wearable_ds.set_task(SleepStagingIBI(num_classes=3)) + + # step 3: define model and dataloaders for phase 1 + model = WatchSleepNet(num_classes=5) + + train_clin, val_clin, _ = split_by_patient(clinical_samples, [0.7, 0.15, 0.15]) + train_loader = get_dataloader(train_clin, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_clin, batch_size=32, shuffle=False) + + # step 4: phase 1 — pre-train on clinical data + trainer = Trainer(model=model, device=args.device) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=args.pretrain_epochs, + optimizer_params={"lr": 1e-3}, + monitor="accuracy", + ) + + # step 5: phase 2 — replace head and fine-tune on wearable data + model_ft = WatchSleepNet(num_classes=3) + backbone_state = { + k: v for k, v in model.state_dict().items() if not k.startswith("fc.") + } + model_ft.load_state_dict(backbone_state, strict=False) + + train_wear, val_wear, test_wear = split_by_patient( + wearable_samples, [0.6, 0.2, 0.2] + ) + train_loader = get_dataloader(train_wear, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_wear, batch_size=32, shuffle=False) + test_loader = get_dataloader(test_wear, batch_size=32, shuffle=False) + + trainer_ft = Trainer(model=model_ft, device=args.device) + trainer_ft.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=args.finetune_epochs, + optimizer_params={"lr": 1e-4}, + monitor="accuracy", + ) + + # step 6: evaluate + print(trainer_ft.evaluate(test_loader)) diff --git a/examples/preprocess_dreamt_to_ibi.py 
b/examples/preprocess_dreamt_to_ibi.py new file mode 100644 index 000000000..e39385f58 --- /dev/null +++ b/examples/preprocess_dreamt_to_ibi.py @@ -0,0 +1,186 @@ +"""preprocess_dreamt_to_ibi.py — Convert raw DREAMT PPG recordings to NPZ files. + +This is a standalone CLI script in ``examples/``, **not** part of the PyHealth API. +After running this script, the ``dst_dir`` can be passed directly as the ``root`` +argument to :class:`pyhealth.datasets.IBISleepDataset`. + +Required extra install (not in PyHealth core):: + + pip install neurokit2 + +Usage:: + + python examples/preprocess_dreamt_to_ibi.py \\ + --src_dir /path/to/DREAMT/raw \\ + --dst_dir /path/to/output/npz \\ + --participant_info /path/to/DREAMT/participant_info.csv + +DREAMT raw directory layout expected:: + + / + _PSG_df_updated.csv # 100 Hz BVP + stage columns + ... + +The ``participant_info.csv`` must contain at minimum: + - a subject-ID column (``Participant_ID`` or first column) + - an ``AHI`` column + +Output NPZ schema (one file per subject, saved as ``.npz``):: + + data : float32 (N,) IBI time series at 25 Hz + stages : int32 (N,) 0=W, 1=N1, 2=N2, 3=N3, 4=REM (sample-level) + fs : int64 () Always 25 + ahi : float32 () Apnea-Hypopnea Index (NaN if unavailable) +""" + +from __future__ import annotations + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) + +_STAGE_MAP = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "R": 4} +_DREAMT_FS = 100 # Hz of raw BVP signal +_TARGET_FS = 25 +_STRIDE = _DREAMT_FS // _TARGET_FS # 4 +_IBI_OUTLIER_S = 2.0 # zero out intervals >= this + + +def _extract_ibi_dreamt(bvp: np.ndarray, fs: int = _DREAMT_FS) -> np.ndarray: + """Return per-sample IBI array at *fs* Hz using neurokit2 PPG processing.""" + try: + import neurokit2 as nk # noqa: PLC0415 + except ImportError as exc: + raise ImportError( + "neurokit2 is required for DREAMT preprocessing. 
" + "Install it with: pip install neurokit2" + ) from exc + + signals, info = nk.ppg_process(bvp, sampling_rate=fs) + peaks = info["PPG_Peaks"] # sample indices of systolic peaks + + ibi = np.zeros(len(bvp), dtype=np.float32) + for i in range(1, len(peaks)): + interval_s = (peaks[i] - peaks[i - 1]) / fs + if interval_s >= _IBI_OUTLIER_S: + interval_s = 0.0 + ibi[peaks[i - 1] : peaks[i]] = interval_s + return ibi + + +def _encode_stages(stage_series: pd.Series) -> np.ndarray: + """Map DREAMT string labels → int32 (unknown → -1).""" + return stage_series.map(_STAGE_MAP).fillna(-1).astype(np.int32).to_numpy() + + +def _process_subject( + src_dir: str, + dst_dir: str, + sid: str, + ahi: float, +) -> bool: + """Process one subject. Returns True on success, False on skip/error.""" + out_path = Path(dst_dir) / f"{sid}.npz" + if out_path.exists(): + logger.info("Skipping %s — NPZ already exists", sid) + return True + + csv_candidates = list(Path(src_dir).glob(f"{sid}_PSG_df_updated.csv")) + if not csv_candidates: + logger.warning("No PSG CSV found for subject %s — skipping", sid) + return False + csv_path = csv_candidates[0] + + try: + df = pd.read_csv(csv_path) + except Exception as exc: # noqa: BLE001 + logger.warning("Cannot read %s: %s — skipping", csv_path, exc) + return False + + if "BVP" not in df.columns: + logger.warning("No BVP column in %s — skipping", csv_path) + return False + + stage_col = next( + (c for c in ("stage", "Stage", "sleep_stage", "Sleep_Stage") if c in df.columns), + None, + ) + if stage_col is None: + logger.warning("No stage column in %s — skipping", csv_path) + return False + + try: + ibi_100hz = _extract_ibi_dreamt(df["BVP"].to_numpy(dtype=np.float64)) + except Exception as exc: # noqa: BLE001 + logger.warning("IBI extraction failed for %s: %s — skipping", sid, exc) + return False + + # Stride-4 downsample: 100 Hz → 25 Hz + data_25hz = ibi_100hz[::_STRIDE].astype(np.float32) + stages_25hz = 
_encode_stages(df[stage_col])[::_STRIDE].astype(np.int32) + + # Align lengths + n = min(len(data_25hz), len(stages_25hz)) + data_25hz = data_25hz[:n] + stages_25hz = stages_25hz[:n] + + np.savez( + out_path, + data=data_25hz, + stages=stages_25hz, + fs=np.int64(_TARGET_FS), + ahi=np.float32(ahi), + ) + logger.info("Saved %s (%d samples, AHI=%.1f)", out_path, n, ahi) + return True + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert raw DREAMT PPG recordings to NPZ files for IBISleepDataset." + ) + parser.add_argument("--src_dir", required=True, help="Directory with raw DREAMT CSV files") + parser.add_argument("--dst_dir", required=True, help="Output directory for NPZ files (= IBISleepDataset root)") + parser.add_argument("--participant_info", required=True, help="Path to participant_info.csv with AHI column") + parser.add_argument( + "--log_level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging verbosity", + ) + parser.add_argument("--limit", type=int, default=None, help="Process at most N subjects") + args = parser.parse_args() + + logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s") + os.makedirs(args.dst_dir, exist_ok=True) + + info_df = pd.read_csv(args.participant_info) + # Accept 'Participant_ID' or first column as subject ID + id_col = "Participant_ID" if "Participant_ID" in info_df.columns else info_df.columns[0] + ahi_col = next((c for c in info_df.columns if c.upper() == "AHI"), None) + + if args.limit is not None: + info_df = info_df.head(args.limit) + + success = fail = skip = 0 + for _, row in info_df.iterrows(): + sid = str(row[id_col]) + ahi = float(row[ahi_col]) if ahi_col is not None else float("nan") + result = _process_subject(args.src_dir, args.dst_dir, sid, ahi) + if result: + success += 1 + else: + fail += 1 + + logger.info("Done. 
success=%d failed/skipped=%d", success, fail) + + +if __name__ == "__main__": + main() diff --git a/examples/preprocess_mesa_to_ibi.py b/examples/preprocess_mesa_to_ibi.py new file mode 100644 index 000000000..e2e843c1f --- /dev/null +++ b/examples/preprocess_mesa_to_ibi.py @@ -0,0 +1,253 @@ +"""preprocess_mesa_to_ibi.py — Convert raw MESA PPG recordings to NPZ files. + +This is a standalone CLI script in ``examples/``, **not** part of the PyHealth API. +After running this script, the ``dst_dir`` can be passed directly as the ``root`` +argument to :class:`pyhealth.datasets.IBISleepDataset` with ``source="mesa"``. + +Required extra installs (not in PyHealth core):: + + pip install neurokit2 mne + +``mne`` is already installed in most PyHealth environments; ``neurokit2`` is the +only additional dependency. + +Usage:: + + python examples/preprocess_mesa_to_ibi.py \\ + --src_dir /path/to/MESA/polysomnography \\ + --dst_dir /path/to/output/npz \\ + --harmonized_csv /path/to/mesa-sleep-harmonized-dataset.csv + +MESA raw directory layout expected (standard NSRR download):: + + / + edfs/ + mesa-sleep-.edf + annotations-events-profusion/ + mesa-sleep--profusion.xml + +The harmonized CSV must contain at minimum: + - ``mesaid`` column (integer) + - an AHI column (script tries several common names) + +Output NPZ schema (one file per subject, e.g. 
``mesa-sleep-00001.npz``):: + + data : float32 (N,) IBI time series at 25 Hz + stages : int32 (N,) 0=W, 1=N1, 2=N2, 3=N3, 4=REM (sample-level) + fs : int64 () Always 25 + ahi : float32 () Apnea-Hypopnea Index (NaN if unavailable) +""" + +from __future__ import annotations + +import argparse +import logging +import os +import xml.etree.ElementTree as ET +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd +from scipy.signal import resample_poly + +logger = logging.getLogger(__name__) + +_TARGET_FS = 25 + +# Profusion XML stage codes → unified 5-class scheme (4=N4→N3, 5=REM→4) +_STAGE_MAP = {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 4} + +_AHI_CANDIDATES = [ + "nsrr_ahi_hp3r_aasm15", "nsrr_ahi_hp3u", "nsrr_ahi_hp4u_aasm15", + "ahi_a0h3a", "ahi_a0h4a", "ahi_a0h3", "ahi_a0h4", "AHI", "ahi", +] +_PLETH_CANDIDATES = ["Pleth", "PLETH", "SpO2", "PPG"] + + +def _parse_profusion_xml(xml_path: Path) -> np.ndarray: + """Parse MESA profusion XML → per-epoch stage array (30 s epochs).""" + try: + tree = ET.parse(xml_path) + except Exception as exc: # noqa: BLE001 + raise ValueError(f"Cannot parse {xml_path}: {exc}") from exc + + root = tree.getroot() + stages = [] + for elem in root.iter("SleepStage"): + try: + raw = int(elem.text) + except (TypeError, ValueError): + raw = -1 + stages.append(_STAGE_MAP.get(raw, -1)) + return np.array(stages, dtype=np.int32) + + +def _extract_ibi_ppg(ppg: np.ndarray, fs: float) -> np.ndarray: + """Return per-sample IBI array at *fs* Hz using neurokit2 PPG processing.""" + try: + import neurokit2 as nk # noqa: PLC0415 + except ImportError as exc: + raise ImportError( + "neurokit2 is required for MESA preprocessing. 
" + "Install it with: pip install neurokit2" + ) from exc + + signals, info = nk.ppg_process(ppg, sampling_rate=fs) + peaks = info["PPG_Peaks"] + + ibi = np.zeros(len(ppg), dtype=np.float32) + for i in range(1, len(peaks)): + interval_s = (peaks[i] - peaks[i - 1]) / fs + ibi[peaks[i - 1] : peaks[i]] = interval_s + return ibi + + +def _process_recording( + edf_path: Path, + xml_path: Optional[Path], + dst_dir: str, + ahi: float, +) -> bool: + """Process one MESA recording. Returns True on success.""" + import mne # noqa: PLC0415 + + sid = edf_path.stem # e.g. mesa-sleep-00001 + out_path = Path(dst_dir) / f"{sid}.npz" + if out_path.exists(): + logger.info("Skipping %s — NPZ already exists", sid) + return True + + try: + raw = mne.io.read_raw_edf(str(edf_path), preload=True, verbose=False) + except Exception as exc: # noqa: BLE001 + logger.warning("Cannot read EDF %s: %s — skipping", edf_path, exc) + return False + + pleth_ch = next((c for c in raw.ch_names if any(p in c for p in _PLETH_CANDIDATES)), None) + if pleth_ch is None: + logger.warning("No Pleth/PPG channel in %s — skipping", edf_path) + return False + + ppg_data, _ = raw[pleth_ch] + ppg = ppg_data[0] + fs_orig = raw.info["sfreq"] + + try: + ibi_orig = _extract_ibi_ppg(ppg, fs_orig) + except Exception as exc: # noqa: BLE001 + logger.warning("IBI extraction failed for %s: %s — skipping", sid, exc) + return False + + # Resample to TARGET_FS + gcd = int(np.gcd(int(_TARGET_FS), int(fs_orig))) + up = _TARGET_FS // gcd + down = int(fs_orig) // gcd + data_25hz = resample_poly(ibi_orig, up, down).astype(np.float32) + + n = len(data_25hz) + + # Parse stage annotations from profusion XML + if xml_path is not None and xml_path.exists(): + try: + epoch_stages = _parse_profusion_xml(xml_path) + except ValueError as exc: + logger.warning("%s — stage parse failed: %s; stages set to -1", sid, exc) + epoch_stages = np.full(n // (_TARGET_FS * 30) + 1, -1, dtype=np.int32) + else: + logger.warning("No profusion XML for %s — 
stages set to -1", sid) + epoch_stages = np.full(n // (_TARGET_FS * 30) + 1, -1, dtype=np.int32) + + # Expand epoch-level stages to sample-level at 25 Hz (30 s × 25 = 750 samples/epoch) + samples_per_epoch = _TARGET_FS * 30 + stages_25hz = np.full(n, -1, dtype=np.int32) + for ep_idx, stage in enumerate(epoch_stages): + start = ep_idx * samples_per_epoch + end = start + samples_per_epoch + if start >= n: + break + stages_25hz[start : min(end, n)] = stage + + np.savez( + out_path, + data=data_25hz, + stages=stages_25hz, + fs=np.int64(_TARGET_FS), + ahi=np.float32(ahi), + ) + logger.info("Saved %s (%d samples, AHI=%.1f)", out_path, n, ahi) + return True + + +def _find_ahi(info_df: pd.DataFrame, mesaid: int) -> float: + id_col = "mesaid" if "mesaid" in info_df.columns else "nsrrid" + row = info_df[info_df[id_col] == mesaid] + if row.empty: + return float("nan") + for col in _AHI_CANDIDATES: + if col in row.columns: + val = row.iloc[0][col] + if pd.notna(val): + return float(val) + return float("nan") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert raw MESA PPG recordings to NPZ files for IBISleepDataset." 
+ ) + parser.add_argument("--src_dir", required=True, help="MESA polysomnography root (contains edfs/ and annotations/)") + parser.add_argument("--dst_dir", required=True, help="Output directory for NPZ files (= IBISleepDataset root)") + parser.add_argument("--harmonized_csv", required=True, help="MESA harmonized dataset CSV with mesaid and AHI columns") + parser.add_argument("--workers", type=int, default=1, help="Number of parallel worker processes") + parser.add_argument( + "--log_level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging verbosity", + ) + parser.add_argument("--limit", type=int, default=None, help="Process at most N recordings") + args = parser.parse_args() + + logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s") + os.makedirs(args.dst_dir, exist_ok=True) + + info_df = pd.read_csv(args.harmonized_csv) + + src = Path(args.src_dir) + edf_paths = sorted((src / "edfs").glob("mesa-sleep-*.edf")) + if args.limit is not None: + edf_paths = list(edf_paths)[: args.limit] + + def _args_for(edf_path: Path): + sid = edf_path.stem # mesa-sleep-00001 + try: + mesaid = int(sid.split("-")[-1]) + except ValueError: + mesaid = -1 + xml_path = src / "annotations-events-profusion" / f"{sid}-profusion.xml" + ahi = _find_ahi(info_df, mesaid) + return edf_path, xml_path if xml_path.exists() else None, args.dst_dir, ahi + + success = fail = 0 + if args.workers > 1: + with ProcessPoolExecutor(max_workers=args.workers) as pool: + futures = {pool.submit(_process_recording, *_args_for(p)): p for p in edf_paths} + for future in as_completed(futures): + if future.result(): + success += 1 + else: + fail += 1 + else: + for edf_path in edf_paths: + if _process_recording(*_args_for(edf_path)): + success += 1 + else: + fail += 1 + + logger.info("Done. 
success=%d failed/skipped=%d", success, fail) + + +if __name__ == "__main__": + main() diff --git a/examples/preprocess_shhs_to_ibi.py b/examples/preprocess_shhs_to_ibi.py new file mode 100644 index 000000000..d20a5e76f --- /dev/null +++ b/examples/preprocess_shhs_to_ibi.py @@ -0,0 +1,268 @@ +"""preprocess_shhs_to_ibi.py — Convert raw SHHS ECG recordings to NPZ files. + +This is a standalone CLI script in ``examples/``, **not** part of the PyHealth API. +After running this script, the ``dst_dir`` can be passed directly as the ``root`` +argument to :class:`pyhealth.datasets.IBISleepDataset` with ``source="shhs"``. + +Required extra installs (not in PyHealth core):: + + pip install biosppy mne + +``mne`` is already installed in most PyHealth environments; ``biosppy`` is the +only additional dependency. + +Usage:: + + python examples/preprocess_shhs_to_ibi.py \\ + --src_dir /path/to/SHHS/polysomnography \\ + --dst_dir /path/to/output/npz \\ + --harmonized_csv /path/to/shhs-harmonized-dataset.csv + +SHHS raw directory layout expected (standard NSRR download):: + + / + edfs/ + shhs1/ + shhs1-.edf + shhs2/ + shhs2-.edf + annotations-events-profusion/ + shhs1/ + shhs1--profusion.xml + shhs2/ + shhs2--profusion.xml + +The harmonized CSV must contain at minimum: + - ``nsrrid`` column (integer) + - ``ahi_a0h3a`` (or similar AHI column; script tries several common names) + +Output NPZ schema (one file per recording, e.g. ``shhs1-200001.npz``):: + + data : float32 (N,) IBI time series at 25 Hz + stages : int32 (N,) 0=W, 1=N1, 2=N2, 3=N3, 4=REM (sample-level) + fs : int64 () Always 25 + ahi : float32 () Apnea-Hypopnea Index (NaN if unavailable) + +Stage remapping applied: + - SHHS annotation 4 (N4) → 3 (merged into N3) + - SHHS annotation 5 (REM) → 4 + +Excluded subjects: shhs1-204822 (known bad recording). 
+""" + +from __future__ import annotations + +import argparse +import logging +import os +import xml.etree.ElementTree as ET +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd +from scipy.signal import resample_poly + +logger = logging.getLogger(__name__) + +_TARGET_FS = 25 +_EXCLUDED = {"shhs1-204822"} + +# SHHS profusion → unified stage map (4→3, 5→4, all others kept or -1) +_STAGE_MAP = {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 4} +# 0=W, 1=N1, 2=N2, 3=N3, 4=N4→N3, 5=REM→4 + +_AHI_CANDIDATES = ["ahi_a0h3a", "ahi_a0h4a", "ahi_a0h3", "ahi_a0h4", "AHI", "ahi"] + + +def _extract_ibi_ecg(ecg: np.ndarray, fs: float) -> np.ndarray: + """Return per-sample IBI array at *fs* Hz using biosppy ECG processing.""" + try: + from biosppy.signals.ecg import ecg as bsp_ecg # noqa: PLC0415 + except ImportError as exc: + raise ImportError( + "biosppy is required for SHHS preprocessing. " + "Install it with: pip install biosppy" + ) from exc + + out = bsp_ecg(signal=ecg, sampling_rate=fs, show=False) + rpeaks = out["rpeaks"] # sample indices + + ibi = np.zeros(len(ecg), dtype=np.float32) + for i in range(1, len(rpeaks)): + interval_s = (rpeaks[i] - rpeaks[i - 1]) / fs + ibi[rpeaks[i - 1] : rpeaks[i]] = interval_s + return ibi + + +def _parse_profusion_xml(xml_path: Path) -> np.ndarray: + """Parse SHHS profusion XML → per-epoch stage array (30 s epochs).""" + try: + tree = ET.parse(xml_path) + except Exception as exc: # noqa: BLE001 + raise ValueError(f"Cannot parse {xml_path}: {exc}") from exc + + root = tree.getroot() + stages = [] + for elem in root.iter("SleepStage"): + try: + raw = int(elem.text) + except (TypeError, ValueError): + raw = -1 + stages.append(_STAGE_MAP.get(raw, -1)) + return np.array(stages, dtype=np.int32) + + +def _process_recording( + edf_path: Path, + xml_path: Optional[Path], + dst_dir: str, + ahi: float, +) -> bool: + """Process one SHHS recording. 
Returns True on success.""" + import mne # noqa: PLC0415 + + sid = edf_path.stem # e.g. shhs1-200001 + if sid in _EXCLUDED: + logger.info("Skipping excluded subject %s", sid) + return True + + out_path = Path(dst_dir) / f"{sid}.npz" + if out_path.exists(): + logger.info("Skipping %s — NPZ already exists", sid) + return True + + try: + raw = mne.io.read_raw_edf(str(edf_path), preload=True, verbose=False) + except Exception as exc: # noqa: BLE001 + logger.warning("Cannot read EDF %s: %s — skipping", edf_path, exc) + return False + + ecg_candidates = [c for c in raw.ch_names if "ECG" in c.upper()] + if not ecg_candidates: + logger.warning("No ECG channel in %s — skipping", edf_path) + return False + + ecg_ch = ecg_candidates[0] + ecg_data, times = raw[ecg_ch] + ecg = ecg_data[0] + fs_orig = raw.info["sfreq"] + + try: + ibi_orig = _extract_ibi_ecg(ecg, fs_orig) + except Exception as exc: # noqa: BLE001 + logger.warning("IBI extraction failed for %s: %s — skipping", sid, exc) + return False + + # Resample IBI to TARGET_FS using rational resampling + gcd = int(np.gcd(int(_TARGET_FS), int(fs_orig))) + up = _TARGET_FS // gcd + down = int(fs_orig) // gcd + data_25hz = resample_poly(ibi_orig, up, down).astype(np.float32) + + # Parse stage annotations (per 30-s epoch → expand to samples) + if xml_path is not None and xml_path.exists(): + try: + epoch_stages = _parse_profusion_xml(xml_path) + except ValueError as exc: + logger.warning("%s — skipping stage parse: %s", sid, exc) + epoch_stages = np.full(len(data_25hz) // (_TARGET_FS * 30) + 1, -1, dtype=np.int32) + else: + logger.warning("No profusion XML for %s — stages set to -1", sid) + epoch_stages = np.full(len(data_25hz) // (_TARGET_FS * 30) + 1, -1, dtype=np.int32) + + # Expand epoch-level stages to sample-level at 25 Hz (30 s × 25 = 750 samples/epoch) + samples_per_epoch = _TARGET_FS * 30 + n_samples = len(data_25hz) + stages_25hz = np.full(n_samples, -1, dtype=np.int32) + for ep_idx, stage in enumerate(epoch_stages): + 
start = ep_idx * samples_per_epoch + end = start + samples_per_epoch + if start >= n_samples: + break + stages_25hz[start : min(end, n_samples)] = stage + + n = min(len(data_25hz), len(stages_25hz)) + np.savez( + out_path, + data=data_25hz[:n], + stages=stages_25hz[:n], + fs=np.int64(_TARGET_FS), + ahi=np.float32(ahi), + ) + logger.info("Saved %s (%d samples)", out_path, n) + return True + + +def _find_ahi(info_df: pd.DataFrame, nsrrid: int) -> float: + row = info_df[info_df["nsrrid"] == nsrrid] + if row.empty: + return float("nan") + for col in _AHI_CANDIDATES: + if col in row.columns: + val = row.iloc[0][col] + if pd.notna(val): + return float(val) + return float("nan") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert raw SHHS ECG recordings to NPZ files for IBISleepDataset." + ) + parser.add_argument("--src_dir", required=True, help="SHHS polysomnography root (contains edfs/ and annotations-events-profusion/)") + parser.add_argument("--dst_dir", required=True, help="Output directory for NPZ files (= IBISleepDataset root)") + parser.add_argument("--harmonized_csv", required=True, help="SHHS harmonized dataset CSV with nsrrid and AHI columns") + parser.add_argument("--workers", type=int, default=1, help="Number of parallel worker processes") + parser.add_argument( + "--log_level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging verbosity", + ) + parser.add_argument("--limit", type=int, default=None, help="Process at most N recordings") + args = parser.parse_args() + + logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s") + os.makedirs(args.dst_dir, exist_ok=True) + + info_df = pd.read_csv(args.harmonized_csv) + + src = Path(args.src_dir) + edf_paths = sorted( + list((src / "edfs" / "shhs1").glob("*.edf")) + + list((src / "edfs" / "shhs2").glob("*.edf")) + ) + if args.limit is not None: + edf_paths = edf_paths[: args.limit] + + def _args_for(edf_path: 
Path): + sid = edf_path.stem + visit = "shhs1" if sid.startswith("shhs1") else "shhs2" + nsrrid = int(sid.split("-")[1]) if "-" in sid else -1 + xml_path = src / "annotations-events-profusion" / visit / f"{sid}-profusion.xml" + ahi = _find_ahi(info_df, nsrrid) + return edf_path, xml_path if xml_path.exists() else None, args.dst_dir, ahi + + success = fail = 0 + if args.workers > 1: + with ProcessPoolExecutor(max_workers=args.workers) as pool: + futures = {pool.submit(_process_recording, *_args_for(p)): p for p in edf_paths} + for future in as_completed(futures): + if future.result(): + success += 1 + else: + fail += 1 + else: + for edf_path in edf_paths: + if _process_recording(*_args_for(edf_path)): + success += 1 + else: + fail += 1 + + logger.info("Done. success=%d failed/skipped=%d", success, fail) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..1433c5a28 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -64,6 +64,7 @@ def __init__(self, *args, **kwargs): from .physionet_deid import PhysioNetDeIDDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset +from .ibi_sleep import IBISleepDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset diff --git a/pyhealth/datasets/configs/ibi_sleep.yaml b/pyhealth/datasets/configs/ibi_sleep.yaml new file mode 100644 index 000000000..4c39adc52 --- /dev/null +++ b/pyhealth/datasets/configs/ibi_sleep.yaml @@ -0,0 +1,9 @@ +version: "1.0" +tables: + ibi_sleep: + file_path: "ibi_sleep-metadata.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "npz_path" + - "ahi" diff --git a/pyhealth/datasets/ibi_sleep.py b/pyhealth/datasets/ibi_sleep.py new file mode 100644 index 000000000..df560362c --- /dev/null +++ b/pyhealth/datasets/ibi_sleep.py @@ -0,0 +1,140 @@ +"""IBISleepDataset: 
"""IBISleepDataset: PyHealth dataset for IBI-based sleep staging.

Loads preprocessed NPZ files produced by the preprocess_dreamt_to_ibi.py,
preprocess_shhs_to_ibi.py, or preprocess_mesa_to_ibi.py scripts in examples/.
Pass the dst_dir from those scripts as the root argument here.
"""

import logging
import os
from typing import Literal, Optional

import numpy as np
import pandas as pd

from pyhealth.datasets import BaseDataset
from pyhealth.tasks.sleep_staging_ibi import SleepStagingIBI

logger = logging.getLogger(__name__)


class IBISleepDataset(BaseDataset):
    """Dataset for IBI-based sleep staging from DREAMT, SHHS, or MESA.

    Each preprocessed NPZ file holds one subject's IBI time series, the
    per-sample sleep stage labels, the sampling rate, and the AHI. Run one
    of the preprocessing scripts in examples/ to produce the files before
    constructing this dataset.

    Args:
        root: Directory containing ``*.npz`` files and where
            ``ibi_sleep-metadata.csv`` will be written.
        source: Dataset origin — one of ``"dreamt"``, ``"shhs"``, or
            ``"mesa"``. Documentation context only; loading behavior is
            identical for all three.
        dataset_name: Optional name override. Defaults to the class name.
        config_path: Path to YAML schema config. Defaults to
            ``pyhealth/datasets/configs/ibi_sleep.yaml``.
        dev: If ``True``, limits to the first 1000 patients (inherited from
            ``BaseDataset``).

    Raises:
        FileNotFoundError: If ``root`` does not exist or contains no
            readable ``.npz`` files.

    Examples:
        >>> from pyhealth.datasets import IBISleepDataset
        >>> dataset = IBISleepDataset(
        ...     root="/path/to/dreamt_npz",
        ...     source="dreamt",
        ... )
        >>> sample_ds = dataset.set_task()
    """

    def __init__(
        self,
        root: str,
        source: Literal["dreamt", "shhs", "mesa"],
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        dev: bool = False,
        **kwargs,
    ) -> None:
        if config_path is None:
            config_path = os.path.join(
                os.path.dirname(__file__), "configs", "ibi_sleep.yaml"
            )

        # Generate the metadata CSV on first use; reuse it on later runs.
        if not os.path.exists(os.path.join(root, "ibi_sleep-metadata.csv")):
            self.prepare_metadata(root)

        self.source = source
        super().__init__(
            root=root,
            tables=["ibi_sleep"],
            dataset_name=dataset_name or "IBISleepDataset",
            config_path=config_path,
            dev=dev,
            **kwargs,
        )

    def prepare_metadata(self, root: str) -> None:
        """Scan ``root`` for NPZ files and write ``ibi_sleep-metadata.csv``.

        Args:
            root: Directory to scan for ``*.npz`` files.

        Raises:
            FileNotFoundError: If no readable ``.npz`` files are found.
        """
        records = []
        for fname in sorted(os.listdir(root)):
            if not fname.endswith(".npz"):
                continue
            full_path = os.path.join(root, fname)
            try:
                archive = np.load(full_path, allow_pickle=False)
            except Exception as exc:
                logger.warning("Skipping unreadable NPZ file %s: %s", full_path, exc)
                continue
            records.append(
                {
                    # Subject id is the NPZ file name without its extension.
                    "patient_id": os.path.splitext(fname)[0],
                    "npz_path": os.path.abspath(full_path),
                    "ahi": float(archive["ahi"]) if "ahi" in archive else float("nan"),
                }
            )

        if not records:
            raise FileNotFoundError(
                f"No readable .npz files found in '{root}'. "
                "Run one of the preprocess_*.py scripts in examples/ first."
            )

        pd.DataFrame(records, columns=["patient_id", "npz_path", "ahi"]).to_csv(
            os.path.join(root, "ibi_sleep-metadata.csv"), index=False
        )
        logger.info("Wrote ibi_sleep-metadata.csv with %d subjects.", len(records))

    @property
    def default_task(self) -> SleepStagingIBI:
        """Returns the default task for this dataset.

        Returns:
            SleepStagingIBI: Default task instance with ``num_classes=3``.
        """
        return SleepStagingIBI()
class ResidualBlock(nn.Module):
    """1D residual block: two kernel-5 convolutions plus a skip connection,
    with a learnable 1x1 projection on the shortcut whenever the shape
    (channels or temporal length) changes.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Stride of the first convolution (temporal downsampling).
            Default ``1``.

    Examples:
        >>> block = ResidualBlock(64, 128, stride=2)
        >>> x = torch.randn(4, 64, 100)
        >>> block(x).shape
        torch.Size([4, 128, 50])
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
    ) -> None:
        super().__init__()
        # Main path: conv-bn-relu, conv-bn (second conv never downsamples).
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=5,
            stride=stride,
            padding=2,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

        # The shortcut must match the main path's output shape; project with
        # a strided 1x1 conv only when channels or length differ.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape ``(B, in_channels, L)``.

        Returns:
            Output tensor of shape ``(B, out_channels, L // stride)``.
        """
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)
class WatchSleepNet(BaseModel):
    """WatchSleepNet for IBI-based sleep staging.

    Architecture: 4-block ResNet → dilated TCN → BiLSTM →
    Multi-head Attention → global average pooling → Linear classifier.

    Args:
        dataset: Optional ``SampleDataset`` (passed to ``BaseModel``).
        num_classes: Number of output sleep stage classes. Default ``5``.
        hidden_dim: Feature dimension throughout the network. Must equal
            ``2 * lstm_hidden``. Default ``256``.
        lstm_hidden: Hidden size per direction in the BiLSTM.
            Default ``128``.
        attn_heads: Number of attention heads. Default ``8``.
        **kwargs: Accepted for API compatibility; currently unused.

    Raises:
        ValueError: If ``2 * lstm_hidden != hidden_dim``.

    Examples:
        >>> model = WatchSleepNet(num_classes=3)
        >>> signal = torch.randn(4, 750)
        >>> out = model(signal=signal)
        >>> out["y_prob"].shape
        torch.Size([4, 3])
    """

    def __init__(
        self,
        dataset=None,
        num_classes: int = 5,
        hidden_dim: int = 256,
        lstm_hidden: int = 128,
        attn_heads: int = 8,
        **kwargs,
    ) -> None:
        # The BiLSTM concatenates both directions, so its output width is
        # 2 * lstm_hidden; attention and the classifier expect hidden_dim.
        if 2 * lstm_hidden != hidden_dim:
            raise ValueError(
                f"BiLSTM output size constraint violated: "
                f"2 * lstm_hidden ({2 * lstm_hidden}) != hidden_dim ({hidden_dim}). "
                "Set lstm_hidden = hidden_dim // 2."
            )
        super().__init__(dataset)
        self.mode = "multiclass"

        self.num_classes = num_classes
        self.hidden_dim = hidden_dim

        # Four stride-2 residual stages: 750 samples -> 47 feature frames.
        widths = (1, 64, 128, 256, hidden_dim)
        self.resnet = nn.Sequential(
            *(
                ResidualBlock(c_in, c_out, stride=2)
                for c_in, c_out in zip(widths, widths[1:])
            )
        )

        # Dilated temporal convolutional layer (length-preserving:
        # padding 4 offsets dilation 2 * (kernel 5 - 1)).
        self.tcn = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, dilation=2, padding=4),
            nn.ReLU(),
        )

        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=lstm_hidden,
            bidirectional=True,
            batch_first=True,
        )

        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=attn_heads,
            batch_first=True,
        )

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(
        self,
        signal: Tensor,
        label: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """Forward pass.

        Args:
            signal: IBI epoch batch of shape ``(B, 750)`` float32.
            label: Ground-truth class indices of shape ``(B,)`` int64.
                Optional. When provided, loss is computed.
            **kwargs: Ignored (allows dict-unpacking from DataLoader batches).

        Returns:
            Dict with keys:

            - ``"loss"``: Scalar CrossEntropyLoss, or ``0.0`` if no label.
            - ``"y_prob"``: Softmax probabilities ``(B, num_classes)``.
            - ``"y_true"``: ``label`` passed through, or ``None``.

        Raises:
            ValueError: If ``signal.shape[-1] != 750``.
        """
        if signal.shape[-1] != 750:
            raise ValueError(
                f"Expected signal length 750, got {signal.shape[-1]}."
            )

        feats = self.resnet(signal.unsqueeze(1))  # (B, 1, 750) -> (B, hidden_dim, 47)
        feats = self.tcn(feats)                   # (B, hidden_dim, 47)
        feats = feats.transpose(1, 2)             # (B, 47, hidden_dim)
        feats, _ = self.lstm(feats)               # (B, 47, hidden_dim)
        feats, _ = self.attention(feats, feats, feats)  # self-attention over frames
        pooled = feats.mean(dim=1)                # global average pooling -> (B, hidden_dim)
        logits = self.fc(pooled)                  # (B, num_classes)

        y_prob = F.softmax(logits, dim=-1)
        if label is None:
            loss = torch.tensor(0.0, device=signal.device)
        else:
            loss = F.cross_entropy(logits, label.long())

        return {"loss": loss, "y_prob": y_prob, "y_true": label}
_SAMPLES_PER_EPOCH: int = 750
_MAX_EPOCHS: int = 1_100
_LABEL_MAP_3CLASS: Dict[int, int] = {0: 0, 1: 1, 2: 1, 3: 1, 4: 2}
_LABEL_MAP_5CLASS: Dict[int, int] = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}


class SleepStagingIBI(BaseTask):
    """Multi-class sleep staging task for IBI signals from IBISleepDataset.

    Each 30-second epoch of the IBI time series (750 samples at 25 Hz) is
    mapped to a single sleep stage label. Supports 3-class and 5-class label
    spaces.

    Attributes:
        task_name (str): ``"SleepStagingIBI"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "multiclass"}``

    Args:
        num_classes: Label granularity. ``3`` → Wake/NREM/REM;
            ``5`` → W/N1/N2/N3/REM. Default ``3``.
        max_epochs: Maximum epochs to return per subject. Default ``1100``.

    Examples:
        >>> from pyhealth.tasks import SleepStagingIBI
        >>> task = SleepStagingIBI(num_classes=3)
        >>> task.task_name
        'SleepStagingIBI'
    """

    task_name: str = "SleepStagingIBI"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "multiclass"}

    def __init__(
        self,
        num_classes: Literal[3, 5] = 3,
        max_epochs: int = _MAX_EPOCHS,
    ) -> None:
        self.num_classes = num_classes
        self.max_epochs = max_epochs
        self._label_map = _LABEL_MAP_3CLASS if num_classes == 3 else _LABEL_MAP_5CLASS
        super().__init__()

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Convert a patient's IBI record into per-epoch sample dicts.

        Args:
            patient: A ``Patient`` object from ``IBISleepDataset``.
                Each event exposes ``event.npz_path`` and ``event.ahi``.

        Returns:
            List of sample dicts, one per valid epoch::

                {
                    "patient_id": str,
                    "signal": np.ndarray,  # float32, shape (750,)
                    "label": int,          # mapped sleep stage
                    "ahi": float,          # may be NaN
                }

            Recordings with fewer than 750 usable samples are skipped;
            an empty list is returned if no recording yields any epoch.

        Raises:
            ValueError: If ``fs != 25`` in an NPZ file.
        """
        pid = patient.patient_id
        samples: List[Dict[str, Any]] = []

        for event in patient.get_events():
            npz = np.load(event.npz_path, allow_pickle=False)
            signal_data = npz["data"].astype(np.float32)
            stages = npz["stages"].astype(np.int32)
            fs = int(npz["fs"])
            ahi_val = getattr(event, "ahi", None)
            ahi = float(ahi_val) if ahi_val is not None else float("nan")

            if fs != 25:
                raise ValueError(
                    f"Expected fs=25, got fs={fs} in {event.npz_path}"
                )

            # Use only the portion covered by BOTH arrays; the preprocess
            # scripts truncate them to equal length, but guard anyway so a
            # malformed file cannot cause an out-of-range stage lookup.
            n_usable = min(len(signal_data), len(stages))
            if n_usable < _SAMPLES_PER_EPOCH:
                # BUG FIX: this was `return []`, which silently discarded
                # samples already collected from the patient's OTHER
                # recordings. Skip only the short recording instead.
                continue

            n_epochs = n_usable // _SAMPLES_PER_EPOCH
            epochs = signal_data[: n_epochs * _SAMPLES_PER_EPOCH].reshape(
                n_epochs, _SAMPLES_PER_EPOCH
            )

            for i in range(n_epochs):
                if len(samples) >= self.max_epochs:
                    break
                # The label of an epoch is the stage at its first sample.
                raw_label = int(stages[i * _SAMPLES_PER_EPOCH])
                mapped = self._label_map.get(raw_label)
                if mapped is None:
                    # Unknown / artifact stage codes are dropped entirely.
                    continue
                samples.append(
                    {
                        "patient_id": pid,
                        "signal": epochs[i],
                        "label": mapped,
                        "ahi": ahi,
                    }
                )

        return samples
unittest +from pathlib import Path + +import numpy as np +import pandas as pd + +from pyhealth.datasets import IBISleepDataset +from pyhealth.tasks import SleepStagingIBI + +requires_py312 = unittest.skipIf( + sys.version_info < (3, 12), + reason="BaseDataset.set_task uses itertools.batched (Python 3.12+)", +) + + +def _write_npz( + directory: str, + name: str, + n_epochs: int = 3, + ahi: float = 5.0, + fs: int = 25, + include_ahi_key: bool = True, +) -> str: + n = n_epochs * 750 + rng = np.random.default_rng(abs(hash(name)) % (2**31)) + data = rng.random(n).astype(np.float32) + stages = rng.integers(0, 5, size=n).astype(np.int32) + path = os.path.join(directory, f"{name}.npz") + kwargs = dict(data=data, stages=stages, fs=np.int64(fs)) + if include_ahi_key: + kwargs["ahi"] = np.float32(ahi) + np.savez(path, **kwargs) + return path + + +def _make_dataset( + tmp_dir: str, n: int = 3, source: str = "dreamt", **kwargs +) -> IBISleepDataset: + for i in range(n): + _write_npz(tmp_dir, f"S{i:03d}") + return IBISleepDataset(root=tmp_dir, source=source, **kwargs) + + +def _read_meta(directory: str) -> pd.DataFrame: + return pd.read_csv(os.path.join(directory, "ibi_sleep-metadata.csv")) + + +class TestIBISleepDataset(unittest.TestCase): + """Tests for IBISleepDataset. + + Tests that need BaseDataset's lazy patient-loading (unique_patient_ids, + get_patient) share a single dataset instance via setUpClass so the ~2s + load cost is paid once. Tests that only verify metadata CSV output call + prepare_metadata() indirectly through the constructor and read the CSV + directly, avoiding the patient-load entirely. 
+ """ + + @classmethod + def setUpClass(cls): + cls._cls_tmpdir = tempfile.TemporaryDirectory() + cls._cls_path = cls._cls_tmpdir.name + for i in range(3): + _write_npz(cls._cls_path, f"S{i:03d}") + cls._ds = IBISleepDataset(root=cls._cls_path, source="dreamt") + cls._pids = list(cls._ds.unique_patient_ids) + + @classmethod + def tearDownClass(cls): + cls._cls_tmpdir.cleanup() + + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.tmp_path = Path(self._tmpdir.name) + + def tearDown(self): + self._tmpdir.cleanup() + + def test_load_dreamt_source(self): + self.assertEqual(self._ds.source, "dreamt") + self.assertTrue((Path(self._cls_path) / "ibi_sleep-metadata.csv").exists()) + + def test_load_shhs_source(self): + _write_npz(str(self.tmp_path), "shhs1-200001") + _write_npz(str(self.tmp_path), "shhs2-300001") + ds = IBISleepDataset(root=str(self.tmp_path), source="shhs") + self.assertEqual(ds.source, "shhs") + self.assertEqual(len(_read_meta(str(self.tmp_path))), 2) + + def test_load_mesa_source(self): + for i in range(4): + _write_npz(str(self.tmp_path), f"mesa-sleep-{i:05d}") + ds = IBISleepDataset(root=str(self.tmp_path), source="mesa") + self.assertEqual(ds.source, "mesa") + self.assertEqual(len(_read_meta(str(self.tmp_path))), 4) + + def test_patient_ids(self): + names = ["Alpha", "Beta", "Gamma"] + for name in names: + _write_npz(str(self.tmp_path), name) + IBISleepDataset(root=str(self.tmp_path), source="dreamt") + self.assertEqual(set(_read_meta(str(self.tmp_path))["patient_id"]), set(names)) + + def test_getitem_keys(self): + patient = self._ds.get_patient(self._pids[0]) + events = patient.get_events() + self.assertGreaterEqual(len(events), 1) + event = events[0] + self.assertTrue(hasattr(event, "npz_path")) + self.assertTrue(hasattr(event, "ahi")) + + def test_ahi_nan_passes_through(self): + _write_npz(str(self.tmp_path), "nan_subject", ahi=float("nan")) + IBISleepDataset(root=str(self.tmp_path), source="dreamt") + df = 
_read_meta(str(self.tmp_path)) + self.assertTrue(df.loc[df["patient_id"] == "nan_subject", "ahi"].isna().all()) + + def test_dev_mode(self): + ds = _make_dataset(str(self.tmp_path), n=5, dev=True) + self.assertTrue(ds.dev) + + def test_empty_dir_raises(self): + with self.assertRaises(FileNotFoundError): + IBISleepDataset(root=str(self.tmp_path), source="dreamt") + + def test_missing_dir_raises(self): + with self.assertRaises((FileNotFoundError, OSError)): + IBISleepDataset(root="/nonexistent/path/xyz", source="dreamt") + + def test_corrupt_npz_skipped(self): + _write_npz(str(self.tmp_path), "good_subject") + Path(os.path.join(str(self.tmp_path), "corrupt_subject.npz")).write_bytes( + b"not a valid npz file" + ) + IBISleepDataset(root=str(self.tmp_path), source="dreamt") + df = _read_meta(str(self.tmp_path)) + self.assertEqual(len(df), 1) + self.assertIn("good_subject", df["patient_id"].values) + + def test_missing_ahi_key_stores_nan(self): + _write_npz(str(self.tmp_path), "no_ahi", include_ahi_key=False) + IBISleepDataset(root=str(self.tmp_path), source="dreamt") + df = _read_meta(str(self.tmp_path)) + self.assertTrue(df.loc[df["patient_id"] == "no_ahi", "ahi"].isna().all()) + + def test_default_task(self): + self.assertIsInstance(self._ds.default_task, SleepStagingIBI) + + +class TestIBISleepDatasetSetTask(unittest.TestCase): + + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + self.tmp_path = Path(self._tmpdir.name) + + def tearDown(self): + self._tmpdir.cleanup() + + @requires_py312 + def test_set_task_3class(self): + ds = _make_dataset(str(self.tmp_path), n=2) + sample_ds = ds.set_task() + labels = [sample_ds[i]["label"] for i in range(len(sample_ds))] + self.assertTrue(all(lbl in {0, 1, 2} for lbl in labels)) + + @requires_py312 + def test_set_task_5class(self): + ds = _make_dataset(str(self.tmp_path), n=2) + sample_ds = ds.set_task(SleepStagingIBI(num_classes=5)) + labels = [sample_ds[i]["label"] for i in range(len(sample_ds))] + 
"""Unit tests for SleepStagingIBI task."""

import math
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock

import numpy as np

from pyhealth.tasks import SleepStagingIBI
from pyhealth.tasks.sleep_staging_ibi import _MAX_EPOCHS, _SAMPLES_PER_EPOCH


def _make_event(npz_path: str, ahi: float = 5.0):
    """Mock event exposing the attributes SleepStagingIBI reads."""
    event = MagicMock()
    event.npz_path = npz_path
    event.ahi = ahi
    return event


def _make_patient(events, pid: str = "S001"):
    """Mock patient wrapping a list of mock events."""
    patient = MagicMock()
    patient.patient_id = pid
    patient.get_events.return_value = events
    return patient


def _make_npz(
    directory: str,
    name: str = "S001",
    n_samples: int = 1500,
    stages=None,
    fs: int = 25,
    ahi: float = 3.0,
) -> str:
    """Write a synthetic NPZ recording and return its path.

    BUG FIX: the cycling fill loop previously ran unconditionally, clobbering
    a caller-supplied ``stages`` array with ``i % 5`` values; the default
    pattern is now only generated when ``stages`` is None.
    """
    if stages is None:
        # Default: stages cycle 0,1,2,3,4,0,... over all samples.
        stages = (np.arange(n_samples) % 5).astype(np.int32)
    data = np.random.default_rng(0).random(n_samples).astype(np.float32)
    path = str(Path(directory) / f"{name}.npz")
    np.savez(
        path,
        data=data,
        stages=stages.astype(np.int32),
        fs=np.int64(fs),
        ahi=np.float32(ahi),
    )
    return path


class TestSleepStagingIBI(unittest.TestCase):

    def setUp(self):
        self._tmpdir = tempfile.TemporaryDirectory()
        self.tmp_path = Path(self._tmpdir.name)

    def tearDown(self):
        self._tmpdir.cleanup()

    def test_task_name(self):
        self.assertEqual(SleepStagingIBI.task_name, "SleepStagingIBI")
self.assertEqual(SleepStagingIBI.task_name, "SleepStagingIBI") + + def test_input_output_schema(self): + self.assertEqual(SleepStagingIBI.input_schema["signal"], "tensor") + self.assertEqual(SleepStagingIBI.output_schema["label"], "multiclass") + + def test_3class_mapping(self): + data_arr = np.zeros(5 * _SAMPLES_PER_EPOCH, dtype=np.float32) + stages_arr = np.zeros(5 * _SAMPLES_PER_EPOCH, dtype=np.int32) + for i, s in enumerate([0, 1, 2, 3, 4]): + stages_arr[i * _SAMPLES_PER_EPOCH:(i + 1) * _SAMPLES_PER_EPOCH] = s + path = str(self.tmp_path / "s.npz") + np.savez( + path, data=data_arr, stages=stages_arr, fs=np.int64(25), ahi=np.float32(0.0) + ) + + task = SleepStagingIBI(num_classes=3) + patient = _make_patient([_make_event(path)]) + labels = [s["label"] for s in task(patient)] + self.assertEqual(labels, [0, 1, 1, 1, 2]) + + def test_5class_mapping(self): + data_arr = np.zeros(5 * _SAMPLES_PER_EPOCH, dtype=np.float32) + stages_arr = np.zeros(5 * _SAMPLES_PER_EPOCH, dtype=np.int32) + for i, s in enumerate([0, 1, 2, 3, 4]): + stages_arr[i * _SAMPLES_PER_EPOCH:(i + 1) * _SAMPLES_PER_EPOCH] = s + path = str(self.tmp_path / "s.npz") + np.savez( + path, data=data_arr, stages=stages_arr, fs=np.int64(25), ahi=np.float32(0.0) + ) + + task = SleepStagingIBI(num_classes=5) + patient = _make_patient([_make_event(path)]) + self.assertEqual([s["label"] for s in task(patient)], [0, 1, 2, 3, 4]) + + def test_invalid_label_skipped(self): + stages_arr = np.full(_SAMPLES_PER_EPOCH, -1, dtype=np.int32) + data_arr = np.zeros(_SAMPLES_PER_EPOCH, dtype=np.float32) + path = str(self.tmp_path / "s.npz") + np.savez( + path, data=data_arr, stages=stages_arr, fs=np.int64(25), ahi=np.float32(0.0) + ) + + task = SleepStagingIBI(num_classes=3) + patient = _make_patient([_make_event(path)]) + self.assertEqual(task(patient), []) + + def test_max_epochs_cap(self): + n = (_MAX_EPOCHS + 50) * _SAMPLES_PER_EPOCH + data_arr = np.zeros(n, dtype=np.float32) + stages_arr = np.zeros(n, dtype=np.int32) + 
path = str(self.tmp_path / "s.npz") + np.savez( + path, data=data_arr, stages=stages_arr, fs=np.int64(25), ahi=np.float32(0.0) + ) + + task = SleepStagingIBI(num_classes=3) + patient = _make_patient([_make_event(path)]) + self.assertEqual(len(task(patient)), _MAX_EPOCHS) + + def test_empty_on_short_record(self): + data_arr = np.zeros(100, dtype=np.float32) + stages_arr = np.zeros(100, dtype=np.int32) + path = str(self.tmp_path / "s.npz") + np.savez( + path, data=data_arr, stages=stages_arr, fs=np.int64(25), ahi=np.float32(0.0) + ) + + task = SleepStagingIBI() + patient = _make_patient([_make_event(path)]) + self.assertEqual(task(patient), []) + + def test_signal_shape(self): + path = _make_npz(str(self.tmp_path), n_samples=3 * _SAMPLES_PER_EPOCH) + task = SleepStagingIBI() + patient = _make_patient([_make_event(path)]) + samples = task(patient) + for s in samples: + self.assertEqual(s["signal"].shape, (_SAMPLES_PER_EPOCH,)) + + def test_ahi_nan_passthrough(self): + data_arr = np.zeros(_SAMPLES_PER_EPOCH, dtype=np.float32) + stages_arr = np.zeros(_SAMPLES_PER_EPOCH, dtype=np.int32) + path = str(self.tmp_path / "s.npz") + np.savez( + path, + data=data_arr, + stages=stages_arr, + fs=np.int64(25), + ahi=np.float32(float("nan")), + ) + + task = SleepStagingIBI() + event = _make_event(path, ahi=float("nan")) + patient = _make_patient([event]) + samples = task(patient) + self.assertEqual(len(samples), 1) + self.assertTrue(math.isnan(samples[0]["ahi"])) + + def test_wrong_fs_raises(self): + data_arr = np.zeros(_SAMPLES_PER_EPOCH, dtype=np.float32) + stages_arr = np.zeros(_SAMPLES_PER_EPOCH, dtype=np.int32) + path = str(self.tmp_path / "s.npz") + np.savez( + path, data=data_arr, stages=stages_arr, fs=np.int64(50), ahi=np.float32(0.0) + ) + + task = SleepStagingIBI() + patient = _make_patient([_make_event(path)]) + with self.assertRaisesRegex(ValueError, "fs=50"): + task(patient) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_watchsleepnet.py 
"""Unit tests for WatchSleepNet model."""

import unittest

import torch

from pyhealth.models.watchsleepnet import WatchSleepNet


class TestWatchSleepNet(unittest.TestCase):

    @staticmethod
    def _batch(batch_size=4, length=750):
        # Random IBI batch of the expected epoch length.
        return torch.randn(batch_size, length)

    def test_instantiation_defaults(self):
        self.assertIsNotNone(WatchSleepNet())

    def test_instantiation_3class(self):
        out = WatchSleepNet(num_classes=3)(signal=self._batch(2))
        self.assertEqual(out["y_prob"].shape, (2, 3))

    def test_forward_shape(self):
        out = WatchSleepNet(num_classes=5)(signal=self._batch(4))
        self.assertEqual(out["y_prob"].shape, (4, 5))

    def test_forward_with_label(self):
        net = WatchSleepNet(num_classes=3)
        out = net(signal=self._batch(4), label=torch.randint(0, 3, (4,)))
        self.assertEqual(out["loss"].ndim, 0)
        out["loss"].backward()

    def test_forward_without_label(self):
        out = WatchSleepNet(num_classes=3)(signal=self._batch(4))
        self.assertEqual(float(out["loss"]), 0.0)
        self.assertIn("y_true", out)
        self.assertIsNone(out["y_true"])

    def test_wrong_input_length(self):
        with self.assertRaisesRegex(ValueError, "750"):
            WatchSleepNet()(signal=torch.randn(4, 500))

    def test_invalid_lstm_hidden(self):
        with self.assertRaisesRegex(ValueError, "BiLSTM"):
            WatchSleepNet(hidden_dim=256, lstm_hidden=100)

    def test_gradients_flow(self):
        net = WatchSleepNet(num_classes=3)
        net.train()
        out = net(signal=self._batch(4), label=torch.randint(0, 3, (4,)))
        out["loss"].backward()
        for name, param in net.named_parameters():
            if param.requires_grad and param.numel() > 0:
                self.assertIsNotNone(param.grad, f"No grad for {name}")

    def test_output_dict_keys(self):
        out = WatchSleepNet()(signal=self._batch(2))
        self.assertEqual(set(out.keys()), {"loss", "y_prob", "y_true"})

    def test_3class_5class_num_params(self):
        model3 = WatchSleepNet(num_classes=3)
        model5 = WatchSleepNet(num_classes=5)
        params3 = sum(p.numel() for p in model3.parameters())
        params5 = sum(p.numel() for p in model5.parameters())
        # Only the final linear layer depends on num_classes.
        expected_delta = (5 - 3) * model3.hidden_dim + (5 - 3)
        self.assertEqual(params5 - params3, expected_delta)


if __name__ == "__main__":
    unittest.main()