From 41ff181887d2e443977648d8217456a875f71be2 Mon Sep 17 00:00:00 2001 From: hexasix Date: Mon, 6 Apr 2026 15:24:30 +0800 Subject: [PATCH] ECG delineation task (Park et al. CHIL 2025) Replicates the ECG wave delineation benchmark from: Park et al., "Benchmarking ECG Delineation using Deep Neural Network-based Semantic Segmentation Models," CHIL 2025. Changes: - Datasets: LUDB and QTDB - Task: ecg_delineation_ludb_fn task - Models: ecg_code Links: Benchmarking ECG Delineation: https://raw.githubusercontent.com/mlresearch/v287/main/assets/park25a/park25a.pdf Self-trained Model for ECG Complex Delineation: https://arxiv.org/pdf/2406.02711 --- docs/api/datasets.rst | 1 + .../pyhealth.datasets.LUDBDataset.rst | 11 + docs/api/tasks.rst | 1 + .../tasks/pyhealth.tasks.ecg_delineation.rst | 6 + examples/LUDB_ECGDelineationLUDB_ECGCODE.py | 523 ++++++++++++++++++ examples/LUDB_ECGDelineationLUDB_RNN.py | 415 ++++++++++++++ examples/ecg_visualization.py | 384 +++++++++++++ examples/ludb_ecg_delineation_unet1d.py | 261 +++++++++ pyhealth/datasets/__init__.py | 10 +- pyhealth/datasets/configs/ludb.yaml | 27 + pyhealth/datasets/configs/qtdb.yaml | 20 + pyhealth/datasets/ludb.py | 468 ++++++++++++++++ pyhealth/datasets/qtdb.py | 287 ++++++++++ pyhealth/models/__init__.py | 24 +- pyhealth/models/ecg_code.py | 336 +++++++++++ pyhealth/tasks/__init__.py | 18 +- pyhealth/tasks/ecg_delineation.py | 405 ++++++++++++++ tests/test_ecg_code.py | 166 ++++++ tests/test_ecg_delineation.py | 260 +++++++++ tests/test_ludb.py | 397 +++++++++++++ 20 files changed, 3997 insertions(+), 23 deletions(-) create mode 100644 docs/api/datasets/pyhealth.datasets.LUDBDataset.rst create mode 100644 docs/api/tasks/pyhealth.tasks.ecg_delineation.rst create mode 100644 examples/LUDB_ECGDelineationLUDB_ECGCODE.py create mode 100644 examples/LUDB_ECGDelineationLUDB_RNN.py create mode 100644 examples/ecg_visualization.py create mode 100644 examples/ludb_ecg_delineation_unet1d.py create mode 100644 
pyhealth/datasets/configs/ludb.yaml create mode 100644 pyhealth/datasets/configs/qtdb.yaml create mode 100644 pyhealth/datasets/ludb.py create mode 100644 pyhealth/datasets/qtdb.py create mode 100644 pyhealth/models/ecg_code.py create mode 100644 pyhealth/tasks/ecg_delineation.py create mode 100644 tests/test_ecg_code.py create mode 100644 tests/test_ecg_delineation.py create mode 100644 tests/test_ludb.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..c69f3af03 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -226,6 +226,7 @@ Available Datasets datasets/pyhealth.datasets.MIMIC4Dataset datasets/pyhealth.datasets.MedicalTranscriptionsDataset datasets/pyhealth.datasets.CardiologyDataset + datasets/pyhealth.datasets.LUDBDataset datasets/pyhealth.datasets.eICUDataset datasets/pyhealth.datasets.ISRUCDataset datasets/pyhealth.datasets.MIMICExtractDataset diff --git a/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst b/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst new file mode 100644 index 000000000..faddb92d2 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.LUDBDataset +============================= + +The Lobachevsky University Database (LUDB) contains 200 12-lead ECG recordings at 500 Hz from healthy volunteers and patients with cardiovascular diseases, with manual cardiologist annotations of P wave, QRS complex, and T wave boundaries. Refer to the `dataset page <https://physionet.org/content/ludb/1.0.1/>`_ for more information. + +.. autoclass:: pyhealth.datasets.LUDBDataset + :members: + :undoc-members: + :show-inheritance: + +.. 
autofunction:: pyhealth.datasets.get_stratified_ludb_split diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..729203c00 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -209,6 +209,7 @@ Available Tasks In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding Cardiology Detection + ECG Delineation (LUDB) COVID-19 CXR Classification DKA Prediction (MIMIC-IV) Drug Recommendation diff --git a/docs/api/tasks/pyhealth.tasks.ecg_delineation.rst b/docs/api/tasks/pyhealth.tasks.ecg_delineation.rst new file mode 100644 index 000000000..656449a1b --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ecg_delineation.rst @@ -0,0 +1,6 @@ +pyhealth.tasks.ecg_delineation +=============================== + +ECG wave delineation task for the LUDB dataset. Segments each 10-second ECG lead signal into background (0), P wave (1), QRS complex (2), and T wave (3) regions. Designed for replication of Park et al., "Benchmarking ECG Delineation using Deep Neural Network-based Semantic Segmentation Models," CHIL 2025. + +.. autofunction:: pyhealth.tasks.ecg_delineation_ludb_fn diff --git a/examples/LUDB_ECGDelineationLUDB_ECGCODE.py b/examples/LUDB_ECGDelineationLUDB_ECGCODE.py new file mode 100644 index 000000000..9b91b192e --- /dev/null +++ b/examples/LUDB_ECGDelineationLUDB_ECGCODE.py @@ -0,0 +1,523 @@ +""" +ECG-CODE training example for ECG delineation on LUDB or QTDB. 
+ +This script demonstrates how to: +1) load LUDB or QTDB with the modern BaseDataset API +2) build ECG delineation samples with masks +3) train ECG-CODE with a manual epoch loop +4) track epoch-wise loss, accuracy, and f1_micro +5) plot loss/accuracy/f1_micro over epochs + +Notes +----- +ECG-CODE produces interval-level outputs shaped [B, N, 3, 3]: +- axis -2 (size 3): wave classes (P, QRS, T) +- axis -1 (size 3): (confidence, start, end) + +For accuracy/f1_micro tracking in this example, we evaluate only interval-level +presence confidence: +- y_true: target confidence channel (0/1) +- y_pred: predicted confidence channel thresholded by --pred-conf-threshold +""" + +from __future__ import annotations + +import argparse +import os +import random +from typing import Dict, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.optim import Adam +from tqdm import tqdm + +from pyhealth.datasets import LUDBDataset, QTDBDataset, get_dataloader, split_by_patient +from pyhealth.models import ECGCODE +from pyhealth.tasks import ECGDelineationLUDB, ECGDelineationQTDB + + +def set_seed(seed: int = 42) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Train ECG-CODE on ECG delineation samples (LUDB or QTDB) with " + "epoch-wise loss/accuracy/f1_micro tracking and plotting." 
+ ) + ) + + # Dataset/task args + parser.add_argument( + "--dataset", + type=str, + default="ludb", + choices=["ludb", "qtdb"], + help="Dataset backend to use.", + ) + parser.add_argument("--root", type=str, required=True, help="Dataset root path.") + parser.add_argument( + "--download", + action="store_true", + help="Download dataset from PhysioNet if local files are missing.", + ) + parser.add_argument("--dev", action="store_true", help="Enable dev mode.") + parser.add_argument( + "--cache-dir", + type=str, + default="", + help="Optional cache directory for dataset/task artifacts.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=1, + help="Number of workers for task processing.", + ) + + # Delineation sample construction + pulse_group = parser.add_mutually_exclusive_group() + pulse_group.add_argument( + "--split-by-pulse", + dest="split_by_pulse", + action="store_true", + help="Use pulse-level windows.", + ) + pulse_group.add_argument( + "--no-split-by-pulse", + dest="split_by_pulse", + action="store_false", + help="Use full-record samples.", + ) + parser.set_defaults(split_by_pulse=True) + + parser.add_argument( + "--pulse-window", + type=int, + default=250, + help=( + "Half-window around QRS peak in samples when pulse split is enabled " + "(250 -> 500-sample pulse)." + ), + ) + parser.add_argument( + "--keep-incomplete-pulses", + action="store_true", + help=( + "If set, do NOT filter pulse samples missing P/QRS/T annotations " + "(effective only with --split-by-pulse)." 
+ ), + ) + + # Model args + parser.add_argument( + "--width-mult", + type=float, + default=1.0, + help="Width multiplier.", + ) + parser.add_argument( + "--interval-size", + type=int, + default=16, + help="Interval size used by ECG-CODE for interval-level predictions.", + ) + parser.add_argument( + "--conf-tolerance", + type=float, + default=0.25, + help="Confidence tolerance threshold in ECG-CODE loss.", + ) + parser.add_argument( + "--se-tolerance", + type=float, + default=0.15, + help="Start/end tolerance threshold in ECG-CODE loss.", + ) + + # Training args + parser.add_argument("--batch-size", type=int, default=64, help="Batch size.") + parser.add_argument("--epochs", type=int, default=10, help="Training epochs.") + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate for Adam optimizer.", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help='Device override, e.g. "cuda:0" or "cpu". Default: auto.', + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + + # Metric/plot args + parser.add_argument( + "--pred-conf-threshold", + type=float, + default=0.5, + help="Threshold on predicted interval confidence for metric computation.", + ) + parser.add_argument( + "--plot-path", + type=str, + default="", + help=( + "Output path for epoch metrics plot " + "(default: outputs/ecg_code/_interval_epoch_metrics.png)." 
+ ), + ) + parser.add_argument( + "--plot-title", + type=str, + default="", + help="Optional custom title for the epoch metrics plot.", + ) + parser.add_argument( + "--plot-dpi", + type=int, + default=150, + help="DPI for the saved plot.", + ) + + args = parser.parse_args() + + if args.pulse_window <= 0: + parser.error("--pulse-window must be positive.") + if args.interval_size <= 0: + parser.error("--interval-size must be positive.") + if args.conf_tolerance < 0: + parser.error("--conf-tolerance must be non-negative.") + if args.se_tolerance < 0: + parser.error("--se-tolerance must be non-negative.") + if args.epochs <= 0: + parser.error("--epochs must be positive.") + if args.lr <= 0: + parser.error("--lr must be positive.") + if not (0.0 <= args.pred_conf_threshold <= 1.0): + parser.error("--pred-conf-threshold must be within [0, 1].") + if args.plot_dpi <= 0: + parser.error("--plot-dpi must be positive.") + + return args + + +def build_dataset_and_task( + args: argparse.Namespace, +) -> Tuple[ + Union[LUDBDataset, QTDBDataset], + Union[ECGDelineationLUDB, ECGDelineationQTDB], +]: + dataset_kwargs = { + "root": args.root, + "dev": args.dev, + "num_workers": args.num_workers, + "download": args.download, + } + if args.cache_dir: + dataset_kwargs["cache_dir"] = args.cache_dir + + common_task_kwargs = { + "split_by_pulse": args.split_by_pulse, + "pulse_window": args.pulse_window, + "filter_incomplete_pulses": not args.keep_incomplete_pulses, + } + + if args.dataset == "ludb": + base_dataset = LUDBDataset(**dataset_kwargs) + task = ECGDelineationLUDB(**common_task_kwargs) + else: + base_dataset = QTDBDataset(**dataset_kwargs) + task = ECGDelineationQTDB(**common_task_kwargs) + + return base_dataset, task + + +def _batch_binary_stats_from_interval_conf( + y_true_interval: np.ndarray, + y_prob_interval: np.ndarray, + pred_conf_threshold: float, +) -> Dict[str, float]: + """ + Computes TP/FP/FN and accuracy counts from interval confidence channel. 
+ + Inputs are expected as [B, N, 3, 3], where channel index 0 is confidence. + """ + if y_true_interval.ndim != 4 or y_true_interval.shape[-2:] != (3, 3): + raise ValueError( + f"Expected y_true shape [B, N, 3, 3], got {y_true_interval.shape}." + ) + if y_prob_interval.ndim != 4 or y_prob_interval.shape[-2:] != (3, 3): + raise ValueError( + f"Expected y_prob shape [B, N, 3, 3], got {y_prob_interval.shape}." + ) + + y_true_conf = y_true_interval[..., 0] >= 0.5 + y_pred_conf = y_prob_interval[..., 0] >= pred_conf_threshold + + tp = float(np.logical_and(y_pred_conf, y_true_conf).sum()) + fp = float(np.logical_and(y_pred_conf, np.logical_not(y_true_conf)).sum()) + fn = float(np.logical_and(np.logical_not(y_pred_conf), y_true_conf).sum()) + + correct = float((y_pred_conf == y_true_conf).sum()) + total = float(y_true_conf.size) + + return {"tp": tp, "fp": fp, "fn": fn, "correct": correct, "total": total} + + +def _finalize_binary_metrics( + stats: Dict[str, float], eps: float = 1e-8 +) -> Dict[str, float]: + accuracy = stats["correct"] / (stats["total"] + eps) + f1_micro = (2.0 * stats["tp"]) / ( + 2.0 * stats["tp"] + stats["fp"] + stats["fn"] + eps + ) + return {"accuracy": float(accuracy), "f1_micro": float(f1_micro)} + + +def evaluate_epoch( + model: ECGCODE, + dataloader, + pred_conf_threshold: float, +) -> Dict[str, float]: + model.eval() + + loss_all = [] + agg = {"tp": 0.0, "fp": 0.0, "fn": 0.0, "correct": 0.0, "total": 0.0} + + with torch.no_grad(): + for data in tqdm(dataloader, desc="Eval", leave=False): + output = model(**data) + loss_all.append(float(output["loss"].item())) + + y_true = output["y_true"].detach().cpu().numpy() + y_prob = output["y_prob"].detach().cpu().numpy() + batch_stats = _batch_binary_stats_from_interval_conf( + y_true_interval=y_true, + y_prob_interval=y_prob, + pred_conf_threshold=pred_conf_threshold, + ) + for k in agg: + agg[k] += batch_stats[k] + + metrics = _finalize_binary_metrics(agg) + metrics["loss"] = float(np.mean(loss_all)) 
if len(loss_all) > 0 else float("nan") + return metrics + + +def plot_epoch_metrics( + history: Dict[str, list], + output_path: str, + title: str, + dpi: int = 150, +) -> None: + epochs = np.arange(1, len(history["train_loss"]) + 1) + + fig, axes = plt.subplots(1, 3, figsize=(15, 4.5)) + + # Loss + axes[0].plot(epochs, history["train_loss"], marker="o", label="Train") + axes[0].plot(epochs, history["val_loss"], marker="o", label="Val") + axes[0].set_title("Loss") + axes[0].set_xlabel("Epoch") + axes[0].set_ylabel("Loss") + axes[0].grid(alpha=0.3) + axes[0].legend() + + # Accuracy + axes[1].plot(epochs, history["train_accuracy"], marker="o", label="Train") + axes[1].plot(epochs, history["val_accuracy"], marker="o", label="Val") + axes[1].set_title("Accuracy") + axes[1].set_xlabel("Epoch") + axes[1].set_ylabel("Score") + axes[1].set_ylim(0.0, 1.0) + axes[1].grid(alpha=0.3) + axes[1].legend() + + # F1 Micro + axes[2].plot(epochs, history["train_f1_micro"], marker="o", label="Train") + axes[2].plot(epochs, history["val_f1_micro"], marker="o", label="Val") + axes[2].set_title("F1 Micro") + axes[2].set_xlabel("Epoch") + axes[2].set_ylabel("Score") + axes[2].set_ylim(0.0, 1.0) + axes[2].grid(alpha=0.3) + axes[2].legend() + + fig.suptitle(title) + fig.tight_layout() + + parent_dir = os.path.dirname(output_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + + fig.savefig(output_path, dpi=dpi) + plt.close(fig) + print(f"Saved epoch-metrics plot: {output_path}") + + +def main() -> None: + args = parse_args() + set_seed(args.seed) + + device = ( + args.device + if args.device is not None + else ("cuda" if torch.cuda.is_available() else "cpu") + ) + + # 1) Build base dataset + delineation task + base_dataset, task = build_dataset_and_task(args) + base_dataset.stats() + if hasattr(base_dataset, "info"): + base_dataset.info() + + # 2) Build sample dataset + sample_dataset = base_dataset.set_task(task, num_workers=args.num_workers) + print(f"Number of delineation 
samples: {len(sample_dataset)}") + if len(sample_dataset) == 0: + if args.split_by_pulse: + raise RuntimeError( + "No samples were generated. Try --keep-incomplete-pulses, " + "increase --pulse-window, or use --no-split-by-pulse." + ) + raise RuntimeError( + "No samples were generated for full-record delineation. " + "Check dataset root and annotations." + ) + + # 3) Split + dataloaders + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + train_loader = get_dataloader( + train_dataset, batch_size=args.batch_size, shuffle=True + ) + val_loader = get_dataloader(val_dataset, batch_size=args.batch_size, shuffle=False) + test_loader = get_dataloader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + # 4) Model + model = ECGCODE( + dataset=sample_dataset, + signal_key="signal", + mask_key="mask", + width_mult=args.width_mult, + interval_size=args.interval_size, + conf_tolerance=args.conf_tolerance, + se_tolerance=args.se_tolerance, + ) + model.to(device) + + # 5) Optimizer + optimizer = Adam(model.parameters(), lr=args.lr) + + # 6) Train loop with epoch-wise metrics + history: Dict[str, list] = { + "train_loss": [], + "train_accuracy": [], + "train_f1_micro": [], + "val_loss": [], + "val_accuracy": [], + "val_f1_micro": [], + } + + for epoch in range(1, args.epochs + 1): + model.train() + + epoch_loss = [] + agg = {"tp": 0.0, "fp": 0.0, "fn": 0.0, "correct": 0.0, "total": 0.0} + + for data in tqdm( + train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=False + ): + optimizer.zero_grad() + output = model(**data) + loss = output["loss"] + loss.backward() + optimizer.step() + + epoch_loss.append(float(loss.item())) + y_true = output["y_true"].detach().cpu().numpy() + y_prob = output["y_prob"].detach().cpu().numpy() + + batch_stats = _batch_binary_stats_from_interval_conf( + y_true_interval=y_true, + y_prob_interval=y_prob, + pred_conf_threshold=args.pred_conf_threshold, + ) + for k in agg: + agg[k] += 
batch_stats[k] + + train_metrics = _finalize_binary_metrics(agg) + train_metrics["loss"] = ( + float(np.mean(epoch_loss)) if len(epoch_loss) > 0 else float("nan") + ) + + val_metrics = evaluate_epoch( + model=model, + dataloader=val_loader, + pred_conf_threshold=args.pred_conf_threshold, + ) + + history["train_loss"].append(train_metrics["loss"]) + history["train_accuracy"].append(train_metrics["accuracy"]) + history["train_f1_micro"].append(train_metrics["f1_micro"]) + history["val_loss"].append(val_metrics["loss"]) + history["val_accuracy"].append(val_metrics["accuracy"]) + history["val_f1_micro"].append(val_metrics["f1_micro"]) + + print( + f"Epoch {epoch:03d} | " + f"train_loss={train_metrics['loss']:.6f}, " + f"train_acc={train_metrics['accuracy']:.6f}, " + f"train_f1_micro={train_metrics['f1_micro']:.6f} | " + f"val_loss={val_metrics['loss']:.6f}, " + f"val_acc={val_metrics['accuracy']:.6f}, " + f"val_f1_micro={val_metrics['f1_micro']:.6f}" + ) + + # 7) Final test metrics + test_metrics = evaluate_epoch( + model=model, + dataloader=test_loader, + pred_conf_threshold=args.pred_conf_threshold, + ) + print("Final test metrics:") + print(f"loss: {test_metrics['loss']:.6f}") + print(f"accuracy: {test_metrics['accuracy']:.6f}") + print(f"f1_micro: {test_metrics['f1_micro']:.6f}") + + # 8) Plot epoch curves + plot_path = args.plot_path or ( + f"outputs/ecg_code/{args.dataset}_interval{args.interval_size}_epoch_metrics.png" + ) + plot_title = args.plot_title or f"ECG-CODE Epoch Metrics ({args.dataset.upper()})" + plot_epoch_metrics( + history=history, + output_path=plot_path, + title=plot_title, + dpi=args.plot_dpi, + ) + + +# python examples/LUDB_ECGDelineationLUDB_ECGCODE.py \ +# --dataset ludb \ +# --root /home/$USER/.cache/pyhealth/datasets/physionet.org/files/ludb/1.0.1/data \ +# --dev \ +# --epochs 10 \ +# --lr 1e-3 \ +# --pred-conf-threshold 0.5 \ +# --plot-path ./ludb_ecgcode_epoch_metrics.png + +if __name__ == "__main__": + main() diff --git 
a/examples/LUDB_ECGDelineationLUDB_RNN.py b/examples/LUDB_ECGDelineationLUDB_RNN.py new file mode 100644 index 000000000..2b51d8932 --- /dev/null +++ b/examples/LUDB_ECGDelineationLUDB_RNN.py @@ -0,0 +1,415 @@ +""" +ECG RNN training example for LUDB or QTDB (pulse-split classification). + +This example shows how to: +1) load LUDB or QTDB with the modern BaseDataset API +2) apply a pulse-splitting delineation task +3) train a PyHealth RNN model on pulse windows +4) track epoch-wise loss, accuracy, and f1_micro +5) plot loss/accuracy/f1_micro curves over epochs + +Task setup +---------- +- Input: pulse ECG window ("signal"), shape (1, T) +- Label: dominant wave class in the pulse mask ("label"), multiclass in {0,1,2,3} + 0=background, 1=P, 2=QRS, 3=T +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, Dict, List, Type, Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.optim import Adam +from tqdm import tqdm + +from pyhealth.datasets import LUDBDataset, QTDBDataset, get_dataloader, split_by_patient +from pyhealth.metrics import multiclass_metrics_fn +from pyhealth.models import RNN +from pyhealth.tasks import BaseTask, ECGDelineationLUDB, ECGDelineationQTDB + + +class PulseRNNTask(BaseTask): + """Pulse-level ECG classification wrapper for RNN training.""" + + task_name: str = "ecg_pulse_rnn_classification" + input_schema: Dict[str, Union[str, Type]] = {"signal": "tensor"} + output_schema: Dict[str, Union[str, Type]] = {"label": "multiclass"} + + def __init__(self, base_task: BaseTask) -> None: + super().__init__() + self.base_task = base_task + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + base_samples = self.base_task(patient) + + samples: List[Dict[str, Any]] = [] + for s in base_samples: + samples.append( + { + "patient_id": s["patient_id"], + "record_id": s["record_id"], + "lead": s.get("lead"), + "signal": s["signal"], + "label": int(s["label"]), + } 
+ ) + return samples + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Train RNN on ECG delineation pulse windows (LUDB or QTDB) " + "with epoch-wise loss/accuracy/f1_micro tracking and plotting." + ) + ) + parser.add_argument( + "--dataset", + type=str, + default="ludb", + choices=["ludb", "qtdb"], + help="Dataset backend to use.", + ) + parser.add_argument("--root", type=str, required=True, help="Dataset root path.") + parser.add_argument( + "--download", + action="store_true", + help="Download dataset from PhysioNet if local files are missing.", + ) + parser.add_argument("--dev", action="store_true", help="Enable dev mode.") + parser.add_argument( + "--cache-dir", + type=str, + default="", + help="Optional cache directory for dataset/task artifacts.", + ) + parser.add_argument( + "--pulse-window", + type=int, + default=250, + help="Half-window around QRS peak in samples (250 => 500-sample pulse).", + ) + parser.add_argument( + "--keep-incomplete-pulses", + action="store_true", + help="If set, do NOT filter pulses missing P or T annotations.", + ) + parser.add_argument("--batch-size", type=int, default=64, help="Batch size.") + parser.add_argument("--epochs", type=int, default=10, help="Training epochs.") + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate for Adam optimizer.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=1, + help="Number of workers for task processing.", + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=128, + help="RNN embedding dimension.", + ) + parser.add_argument( + "--hidden-dim", + type=int, + default=128, + help="RNN hidden dimension.", + ) + parser.add_argument( + "--rnn-type", + type=str, + default="GRU", + choices=["RNN", "LSTM", "GRU"], + help="Recurrent cell type.", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help='Device override, e.g. "cuda:0" or "cpu". 
Default: auto.', + ) + parser.add_argument( + "--plot-path", + type=str, + default="", + help=( + "Optional output path for the epoch-metrics plot " + "(default: output/_rnn_epoch_metrics.png)." + ), + ) + parser.add_argument( + "--plot-title", + type=str, + default="", + help="Optional custom title for the epoch-metrics plot.", + ) + + args = parser.parse_args() + if args.epochs <= 0: + parser.error("--epochs must be positive.") + if args.lr <= 0: + parser.error("--lr must be positive.") + if args.pulse_window <= 0: + parser.error("--pulse-window must be positive.") + return args + + +def build_dataset_and_task(args: argparse.Namespace): + dataset_kwargs = { + "root": args.root, + "dev": args.dev, + "num_workers": args.num_workers, + "download": args.download, + } + if args.cache_dir: + dataset_kwargs["cache_dir"] = args.cache_dir + + if args.dataset == "ludb": + base_dataset = LUDBDataset(**dataset_kwargs) + base_task = ECGDelineationLUDB( + split_by_pulse=True, + pulse_window=args.pulse_window, + filter_incomplete_pulses=not args.keep_incomplete_pulses, + ) + else: + base_dataset = QTDBDataset(**dataset_kwargs) + base_task = ECGDelineationQTDB( + split_by_pulse=True, + pulse_window=args.pulse_window, + filter_incomplete_pulses=not args.keep_incomplete_pulses, + ) + + task = PulseRNNTask(base_task) + return base_dataset, task + + +def _compute_multiclass_scores( + y_true_all: np.ndarray, + y_prob_all: np.ndarray, +) -> Dict[str, float]: + scores = multiclass_metrics_fn( + y_true=y_true_all, + y_prob=y_prob_all, + metrics=["accuracy", "f1_micro"], + ) + return { + "accuracy": float(scores["accuracy"]), + "f1_micro": float(scores["f1_micro"]), + } + + +def evaluate_epoch( + model: RNN, + dataloader, +) -> Dict[str, float]: + model.eval() + loss_all: List[float] = [] + y_true_all: List[np.ndarray] = [] + y_prob_all: List[np.ndarray] = [] + + with torch.no_grad(): + for data in tqdm(dataloader, desc="Validation/Test", leave=False): + output = model(**data) + 
loss_all.append(float(output["loss"].item())) + y_true_all.append(output["y_true"].detach().cpu().numpy()) + y_prob_all.append(output["y_prob"].detach().cpu().numpy()) + + y_true = np.concatenate(y_true_all, axis=0) + y_prob = np.concatenate(y_prob_all, axis=0) + cls_scores = _compute_multiclass_scores(y_true, y_prob) + + return { + "loss": float(np.mean(loss_all)), + "accuracy": cls_scores["accuracy"], + "f1_micro": cls_scores["f1_micro"], + } + + +def plot_epoch_metrics( + history: Dict[str, List[float]], + output_path: str, + title: str, +) -> None: + epochs = np.arange(1, len(history["train_loss"]) + 1) + + fig, axes = plt.subplots(1, 3, figsize=(15, 4.5)) + + # Loss + axes[0].plot(epochs, history["train_loss"], marker="o", label="Train") + axes[0].plot(epochs, history["val_loss"], marker="o", label="Val") + axes[0].set_title("Loss") + axes[0].set_xlabel("Epoch") + axes[0].set_ylabel("Loss") + axes[0].grid(alpha=0.3) + axes[0].legend() + + # Accuracy + axes[1].plot(epochs, history["train_accuracy"], marker="o", label="Train") + axes[1].plot(epochs, history["val_accuracy"], marker="o", label="Val") + axes[1].set_title("Accuracy") + axes[1].set_xlabel("Epoch") + axes[1].set_ylabel("Score") + axes[1].set_ylim(0.0, 1.0) + axes[1].grid(alpha=0.3) + axes[1].legend() + + # F1 Micro + axes[2].plot(epochs, history["train_f1_micro"], marker="o", label="Train") + axes[2].plot(epochs, history["val_f1_micro"], marker="o", label="Val") + axes[2].set_title("F1 Micro") + axes[2].set_xlabel("Epoch") + axes[2].set_ylabel("Score") + axes[2].set_ylim(0.0, 1.0) + axes[2].grid(alpha=0.3) + axes[2].legend() + + fig.suptitle(title) + fig.tight_layout() + + out = Path(output_path) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out, dpi=200) + plt.close(fig) + print(f"Saved epoch-metrics plot to: {out}") + + +def main() -> None: + args = parse_args() + device = ( + args.device + if args.device is not None + else ("cuda" if torch.cuda.is_available() else "cpu") + ) + + # 
1) Base dataset + delineation task + base_dataset, task = build_dataset_and_task(args) + base_dataset.stats() + if hasattr(base_dataset, "info"): + base_dataset.info() + + # 2) Build pulse-level sample dataset + sample_dataset = base_dataset.set_task(task, num_workers=args.num_workers) + print(f"Number of pulse samples: {len(sample_dataset)}") + if len(sample_dataset) == 0: + raise RuntimeError( + "No samples were generated. Try disabling pulse filtering " + "with --keep-incomplete-pulses or using a different pulse window." + ) + + # 3) Split + dataloaders + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + train_loader = get_dataloader( + train_dataset, batch_size=args.batch_size, shuffle=True + ) + val_loader = get_dataloader(val_dataset, batch_size=args.batch_size, shuffle=False) + test_loader = get_dataloader( + test_dataset, batch_size=args.batch_size, shuffle=False + ) + + # 4) Model + model = RNN( + dataset=sample_dataset, + embedding_dim=args.embedding_dim, + hidden_dim=args.hidden_dim, + rnn_type=args.rnn_type, + num_layers=1, + dropout=0.1, + bidirectional=False, + ) + model.to(device) + + # 5) Optimizer + optimizer = Adam(model.parameters(), lr=args.lr) + + # 6) Train loop with epoch-wise metrics + history: Dict[str, List[float]] = { + "train_loss": [], + "train_accuracy": [], + "train_f1_micro": [], + "val_loss": [], + "val_accuracy": [], + "val_f1_micro": [], + } + + for epoch in range(1, args.epochs + 1): + model.train() + epoch_loss: List[float] = [] + epoch_y_true: List[np.ndarray] = [] + epoch_y_prob: List[np.ndarray] = [] + + for data in tqdm( + train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=False + ): + optimizer.zero_grad() + output = model(**data) + loss = output["loss"] + loss.backward() + optimizer.step() + + epoch_loss.append(float(loss.item())) + epoch_y_true.append(output["y_true"].detach().cpu().numpy()) + epoch_y_prob.append(output["y_prob"].detach().cpu().numpy()) + + 
train_y_true = np.concatenate(epoch_y_true, axis=0) + train_y_prob = np.concatenate(epoch_y_prob, axis=0) + train_cls_scores = _compute_multiclass_scores(train_y_true, train_y_prob) + + train_scores = { + "loss": float(np.mean(epoch_loss)), + "accuracy": train_cls_scores["accuracy"], + "f1_micro": train_cls_scores["f1_micro"], + } + val_scores = evaluate_epoch(model, val_loader) + + history["train_loss"].append(train_scores["loss"]) + history["train_accuracy"].append(train_scores["accuracy"]) + history["train_f1_micro"].append(train_scores["f1_micro"]) + history["val_loss"].append(val_scores["loss"]) + history["val_accuracy"].append(val_scores["accuracy"]) + history["val_f1_micro"].append(val_scores["f1_micro"]) + + print( + f"Epoch {epoch:03d} | " + f"train_loss={train_scores['loss']:.6f}, " + f"train_acc={train_scores['accuracy']:.6f}, " + f"train_f1_micro={train_scores['f1_micro']:.6f} | " + f"val_loss={val_scores['loss']:.6f}, " + f"val_acc={val_scores['accuracy']:.6f}, " + f"val_f1_micro={val_scores['f1_micro']:.6f}" + ) + + # 7) Final test metrics + test_scores = evaluate_epoch(model, test_loader) + print("Final test metrics:") + print(f"loss: {test_scores['loss']:.6f}") + print(f"accuracy: {test_scores['accuracy']:.6f}") + print(f"f1_micro: {test_scores['f1_micro']:.6f}") + + # 8) Plot curves + plot_path = args.plot_path or f"output/{args.dataset}_rnn_epoch_metrics.png" + plot_title = args.plot_title or f"RNN Epoch Metrics ({args.dataset.upper()})" + plot_epoch_metrics(history, output_path=plot_path, title=plot_title) + + +# python examples/LUDB_ECGDelineationLUDB_RNN.py \ +# --dataset ludb \ +# --root /home/$USER/.cache/pyhealth/datasets/physionet.org/files/ludb/1.0.1/data \ +# --epochs 10 \ +# --lr 1e-3 \ +# --plot-path ./ludb_rnn_epoch_metrics.png +# +# +if __name__ == "__main__": + main() diff --git a/examples/ecg_visualization.py b/examples/ecg_visualization.py new file mode 100644 index 000000000..7c50ef39e --- /dev/null +++ 
b/examples/ecg_visualization.py @@ -0,0 +1,384 @@ +""" +Reusable ECG visualization example using task outputs with pulse/non-pulse comparison. + +This script is designed to be extensible across ECG datasets. It currently includes +LUDB integration and compares: + +1) Non-pulse mode (full window per lead) +2) Pulse mode (all pulse-centered windows per lead) + +The plotting utilities are dataset-agnostic as long as task samples provide: + - "signal": 1D or (C, T) array/tensor + - "mask": 1D segmentation labels + - optional metadata keys like "patient_id", "record_id", "lead" + +Usage: + python examples/ecg_visualization.py \ + --dataset ludb \ + --root /path/to/physionet.org/files/ludb/1.0.1/data \ + --patient-id 1 \ + --lead i \ + --pulse-window 250 \ + --filter-incomplete-pulses \ + --dev \ + --save-path ./ecg_compare_all_pulses.png + +Requirements: + pip install matplotlib + (and dataset/task dependencies, e.g. wfdb for LUDB delineation) +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np + +from pyhealth.datasets import LUDBDataset +from pyhealth.tasks import ECGDelineationLUDB + +# --------------------------------------------------------------------- +# Generic utilities (dataset/task-agnostic) +# --------------------------------------------------------------------- + +CLASS_NAMES = { + 0: "background", + 1: "P", + 2: "QRS", + 3: "T", +} + +CLASS_COLORS = { + 0: "#9E9E9E", + 1: "#4CAF50", + 2: "#F44336", + 3: "#2196F3", +} + + +def _to_numpy(x: Any) -> np.ndarray: + """Convert tensor/array/list-like object to np.ndarray safely.""" + if x is None: + raise ValueError("Expected array-like input, got None.") + + # Torch-like tensor support without hard dependency + if hasattr(x, "detach") and hasattr(x, "cpu"): + x = x.detach().cpu().numpy() + + return np.asarray(x) + + +def _signal_1d(signal: Any) -> np.ndarray: + """Normalize 
signal to shape (T,) for visualization.""" + arr = _to_numpy(signal) + + if arr.ndim == 1: + return arr.astype(np.float32) + + if arr.ndim == 2: + # expected ECG delineation shape is (1, T), but handle generic multi-channel + if arr.shape[0] == 1: + return arr[0].astype(np.float32) + return arr[0].astype(np.float32) + + raise ValueError(f"Unsupported signal shape: {arr.shape}") + + +def _mask_1d(mask: Any) -> np.ndarray: + """Normalize mask to shape (T,) with integer labels.""" + arr = _to_numpy(mask) + if arr.ndim != 1: + arr = arr.reshape(-1) + return arr.astype(np.int64) + + +def _collect_samples( + sample_dataset: Any, + patient_id: str, + lead: Optional[str] = None, +) -> List[Dict[str, Any]]: + """ + Collect all samples for a patient (optionally filtered by lead). + + Uses `patient_to_index` if present (SampleDataset), otherwise scans dataset. + """ + pid = str(patient_id) + candidates: List[Dict[str, Any]] = [] + + if hasattr(sample_dataset, "patient_to_index"): + indices = sample_dataset.patient_to_index.get(pid, []) + for idx in indices: + s = sample_dataset[idx] + if lead is not None and str(s.get("lead", "")).lower() != lead.lower(): + continue + if "signal" in s and "mask" in s: + candidates.append(s) + else: + # fallback scan + for idx in range(len(sample_dataset)): + s = sample_dataset[idx] + if str(s.get("patient_id", "")) != pid: + continue + if lead is not None and str(s.get("lead", "")).lower() != lead.lower(): + continue + if "signal" in s and "mask" in s: + candidates.append(s) + + return candidates + + +def _plot_signal_with_mask( + ax: plt.Axes, + signal: np.ndarray, + mask: np.ndarray, + title: str, + alpha: float = 0.25, +) -> None: + """Plot ECG signal and color-overlay segmentation mask classes.""" + t = np.arange(len(signal)) + y_min, y_max = float(signal.min()), float(signal.max()) + if np.isclose(y_min, y_max): + y_min -= 1e-3 + y_max += 1e-3 + + ax.plot(t, signal, linewidth=1.0, color="black") + + # overlay class regions + 
def _plot_signal_with_mask(
    ax: "plt.Axes",
    signal: np.ndarray,
    mask: np.ndarray,
    title: str,
    alpha: float = 0.25,
) -> None:
    """Draw one ECG trace on ``ax`` and shade each segmentation class."""
    xs = np.arange(len(signal))
    lo, hi = float(signal.min()), float(signal.max())
    if np.isclose(lo, hi):
        # Degenerate flat signal: widen the band so shaded fills stay visible.
        lo -= 1e-3
        hi += 1e-3

    ax.plot(xs, signal, linewidth=1.0, color="black")

    # One translucent band per class present in the mask.
    for cls_id in sorted(int(c) for c in np.unique(mask)):
        region = mask == cls_id
        if not np.any(region):
            continue
        ax.fill_between(
            xs,
            lo,
            hi,
            where=region,
            color=CLASS_COLORS.get(cls_id, "#BDBDBD"),
            alpha=alpha,
            step="mid",
            label=f"{cls_id}:{CLASS_NAMES.get(cls_id, 'unknown')}",
        )

    ax.set_title(title)
    ax.set_xlabel("Sample index")
    ax.set_ylabel("Amplitude")
    ax.grid(alpha=0.2)

    # Collapse duplicate legend entries (keeps the last handle per label).
    handles, labels = ax.get_legend_handles_labels()
    unique = dict(zip(labels, handles))
    if unique:
        ax.legend(unique.values(), unique.keys(), fontsize=8, loc="upper right")


def _panel_label(prefix: str, sample: Dict[str, Any], n_samples: int) -> str:
    """Compose the per-panel title shared by raw and pulse panels."""
    return (
        f"{prefix} | patient={sample.get('patient_id')} "
        f"| lead={sample.get('lead')} | record_id={sample.get('record_id')} "
        f"| T={n_samples}"
    )


def visualize_comparison_all_pulses(
    raw_sample: Dict[str, Any],
    pulse_samples: List[Dict[str, Any]],
    save_path: Optional[Path] = None,
    max_pulses: int = 0,
) -> None:
    """
    Create stacked comparison:
      - first row: non-pulse (full record) sample
      - next rows: all pulse samples (or first `max_pulses` if > 0)
    """
    if not pulse_samples:
        raise ValueError("No pulse samples to visualize.")

    shown = pulse_samples[:max_pulses] if max_pulses > 0 else pulse_samples

    nrows = 1 + len(shown)
    fig, axes = plt.subplots(
        nrows, 1, figsize=(14, max(4, 2.8 * nrows)), constrained_layout=True
    )
    if nrows == 1:
        axes = [axes]

    # Raw (non-pulse) panel on top.
    sig = _signal_1d(raw_sample["signal"])
    _plot_signal_with_mask(
        axes[0],
        sig,
        _mask_1d(raw_sample["mask"]),
        _panel_label("Non-pulse mode", raw_sample, len(sig)),
    )

    # One panel per pulse-centered window.
    for row, sample in enumerate(shown, start=1):
        sig = _signal_1d(sample["signal"])
        _plot_signal_with_mask(
            axes[row],
            sig,
            _mask_1d(sample["mask"]),
            _panel_label(f"Pulse {row}/{len(shown)}", sample, len(sig)),
        )

    fig.suptitle("ECG Delineation: Non-pulse vs All Pulse-aligned Splits", fontsize=14)

    if save_path is not None:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=180)
        print(f"Saved figure to: {save_path}")

    plt.show()


# ---------------------------------------------------------------------
# Dataset/task adapters
# ---------------------------------------------------------------------


def build_ludb_sample_datasets(
    root: str,
    dev: bool,
    num_workers: int,
    pulse_window: int,
    filter_incomplete_pulses: bool,
):
    """Build non-pulse and pulse-mode SampleDatasets for LUDB."""
    base = LUDBDataset(root=root, dev=dev, num_workers=num_workers)

    # Full-record task: incomplete-pulse filtering never applies here.
    full_task = ECGDelineationLUDB(
        split_by_pulse=False,
        pulse_window=pulse_window,
        filter_incomplete_pulses=False,
    )
    windowed_task = ECGDelineationLUDB(
        split_by_pulse=True,
        pulse_window=pulse_window,
        filter_incomplete_pulses=filter_incomplete_pulses,
    )

    return (
        base.set_task(full_task, num_workers=num_workers),
        base.set_task(windowed_task, num_workers=num_workers),
    )


# Future extension point:
# Add adapters for other ECG datasets (e.g., QTDB) with the same return interface.
DATASET_BUILDERS = {
    "ludb": build_ludb_sample_datasets,
}
# ---------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the visualization demo."""
    p = argparse.ArgumentParser(
        description=(
            "Reusable ECG visualization from task outputs "
            "(non-pulse vs all pulse splits)."
        )
    )
    p.add_argument(
        "--dataset",
        type=str,
        default="ludb",
        choices=sorted(DATASET_BUILDERS.keys()),
        help="ECG dataset adapter to use.",
    )
    p.add_argument("--root", type=str, required=True, help="Dataset root directory.")
    p.add_argument(
        "--patient-id",
        type=str,
        default="1",
        help="Patient ID to visualize (record ID for LUDB).",
    )
    p.add_argument(
        "--lead",
        type=str,
        default="i",
        help="Lead name to visualize (e.g., i, ii, v1).",
    )
    p.add_argument(
        "--pulse-window",
        type=int,
        default=250,
        help="Pulse half-window in samples for pulse mode.",
    )
    p.add_argument(
        "--filter-incomplete-pulses",
        action="store_true",
        help="If set, keep only pulse windows that contain P, QRS, and T labels.",
    )
    p.add_argument(
        "--max-pulses",
        type=int,
        default=0,
        help="Optional cap on number of pulse windows to plot (0 = show all).",
    )
    p.add_argument(
        "--num-workers",
        type=int,
        default=1,
        help="Workers for set_task processing.",
    )
    p.add_argument("--dev", action="store_true", help="Use dataset dev mode.")
    p.add_argument(
        "--save-path",
        type=str,
        default="",
        help="Optional output image path. If empty, no file is saved.",
    )
    return p.parse_args()


def main() -> None:
    """Entry point: build both task datasets, gather samples, and plot."""
    args = parse_args()

    build = DATASET_BUILDERS[args.dataset]
    raw_ds, pulse_ds = build(
        root=args.root,
        dev=args.dev,
        num_workers=args.num_workers,
        pulse_window=args.pulse_window,
        filter_incomplete_pulses=args.filter_incomplete_pulses,
    )

    raw_samples = _collect_samples(raw_ds, patient_id=args.patient_id, lead=args.lead)
    pulse_samples = _collect_samples(
        pulse_ds, patient_id=args.patient_id, lead=args.lead
    )

    if not raw_samples:
        raise ValueError(
            f"No non-pulse sample found for patient_id='{args.patient_id}', lead='{args.lead}'."
        )
    if not pulse_samples:
        raise ValueError(
            "No pulse samples found. Try disabling --filter-incomplete-pulses "
            "or selecting a different patient/lead."
        )

    visualize_comparison_all_pulses(
        raw_sample=raw_samples[0],
        pulse_samples=pulse_samples,
        save_path=Path(args.save_path) if args.save_path else None,
        max_pulses=args.max_pulses,
    )


if __name__ == "__main__":
    main()
+""" + +import pickle +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset + +# ── PyHealth imports ──────────────────────────────────────────────────────── +from pyhealth.datasets.ludb import LUDBDataset +from pyhealth.tasks.ecg_delineation import ecg_delineation_ludb_fn + +# ── Configuration ──────────────────────────────────────────────────────────── +DATA_ROOT = "/Users/delin/Documents/DL4H/physionet.org/files/ludb/1.0.1/data" +NUM_CLASSES = 4 # 0=background, 1=P, 2=QRS, 3=T +BATCH_SIZE = 16 +EPOCHS = 5 # keep short for demonstration +LR = 1e-3 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# ── Lightweight 1-D U-Net ──────────────────────────────────────────────────── + +class ConvBlock(nn.Module): + """Two Conv1d → BN → ReLU layers.""" + + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm1d(out_ch), + nn.ReLU(inplace=True), + nn.Conv1d(out_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm1d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class UNet1D(nn.Module): + """Encoder–decoder U-Net for 1-D sequence segmentation. + + Args: + in_channels: number of input channels (1 for single-lead ECG). + num_classes: number of output segmentation classes. + base_filters: number of filters in the first encoder block. + depth: number of encoder / decoder stages. 
+ """ + + def __init__( + self, + in_channels: int = 1, + num_classes: int = NUM_CLASSES, + base_filters: int = 32, + depth: int = 3, + ): + super().__init__() + self.depth = depth + + enc_channels = [in_channels] + [base_filters * (2 ** i) for i in range(depth)] + self.encoders = nn.ModuleList( + [ConvBlock(enc_channels[i], enc_channels[i + 1]) for i in range(depth)] + ) + self.pools = nn.ModuleList( + [nn.MaxPool1d(2) for _ in range(depth)] + ) + + self.bottleneck = ConvBlock(enc_channels[-1], enc_channels[-1] * 2) + + dec_channels = [enc_channels[-1] * 2] + list(reversed(enc_channels[1:])) + self.upconvs = nn.ModuleList( + [nn.ConvTranspose1d(dec_channels[i], dec_channels[i + 1], kernel_size=2, stride=2) + for i in range(depth)] + ) + self.decoders = nn.ModuleList( + [ConvBlock(dec_channels[i + 1] * 2, dec_channels[i + 1]) for i in range(depth)] + ) + + self.head = nn.Conv1d(dec_channels[-1], num_classes, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + skips = [] + for enc, pool in zip(self.encoders, self.pools): + x = enc(x) + skips.append(x) + x = pool(x) + x = self.bottleneck(x) + for upconv, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)): + x = upconv(x) + # handle odd-length mismatches + if x.shape[-1] != skip.shape[-1]: + x = nn.functional.pad(x, (0, skip.shape[-1] - x.shape[-1])) + x = dec(torch.cat([skip, x], dim=1)) + return self.head(x) + + +# ── Torch dataset wrapper ──────────────────────────────────────────────────── + +class EpochDataset(Dataset): + """Wraps SampleSignalDataset samples for direct use with DataLoader.""" + + def __init__(self, samples): + self.samples = samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + epoch = pickle.load(open(self.samples[idx]["epoch_path"], "rb")) + signal = torch.tensor(epoch["signal"], dtype=torch.float32) # (1, T) + label = torch.tensor(epoch["label"], dtype=torch.long) # (T,) + return signal, label + + +def pad_collate(batch): + 
"""Pad signals in a batch to the same length.""" + signals, labels = zip(*batch) + max_len = max(s.shape[-1] for s in signals) + signals = torch.stack([ + nn.functional.pad(s, (0, max_len - s.shape[-1])) for s in signals + ]) + labels = torch.stack([ + nn.functional.pad(l, (0, max_len - l.shape[0])) for l in labels + ]) + return signals, labels + + +# ── Training / evaluation helpers ─────────────────────────────────────────── + +def train_epoch(model, loader, optimizer, criterion): + model.train() + total_loss = 0.0 + for signals, labels in loader: + signals, labels = signals.to(DEVICE), labels.to(DEVICE) + optimizer.zero_grad() + logits = model(signals) # (B, C, T) + loss = criterion(logits, labels) + loss.backward() + optimizer.step() + total_loss += loss.item() * signals.size(0) + return total_loss / len(loader.dataset) + + +@torch.no_grad() +def evaluate(model, loader): + model.eval() + correct = total = 0 + for signals, labels in loader: + signals, labels = signals.to(DEVICE), labels.to(DEVICE) + preds = model(signals).argmax(dim=1) # (B, T) + correct += (preds == labels).sum().item() + total += labels.numel() + return correct / total + + +def run_ablation(name: str, task_fn_kwargs: dict, model_kwargs: dict): + """Build dataset, train for EPOCHS, and return final val accuracy.""" + print(f"\n{'=' * 60}") + print(f"Ablation: {name}") + print(f" task kwargs : {task_fn_kwargs}") + print(f" model kwargs: {model_kwargs}") + + # Build PyHealth SampleSignalDataset + ds = LUDBDataset(root=DATA_ROOT, dev=False, refresh_cache=True) + sample_ds = ds.set_task( + lambda rec: ecg_delineation_ludb_fn(rec, **task_fn_kwargs), + task_name="ecg_delineation", + ) + all_samples = sample_ds.samples + + # 8:1:1 patient split (records 1-160 / 161-180 / 181-200) + train_s = [s for s in all_samples if int(s["patient_id"]) <= 160] + val_s = [s for s in all_samples if 160 < int(s["patient_id"]) <= 180] + + train_loader = DataLoader(EpochDataset(train_s), batch_size=BATCH_SIZE, + 
shuffle=True, collate_fn=pad_collate) + val_loader = DataLoader(EpochDataset(val_s), batch_size=BATCH_SIZE, + shuffle=False, collate_fn=pad_collate) + + model = UNet1D(**model_kwargs).to(DEVICE) + optimizer = torch.optim.Adam(model.parameters(), lr=LR) + criterion = nn.CrossEntropyLoss() + + for epoch in range(1, EPOCHS + 1): + loss = train_epoch(model, train_loader, optimizer, criterion) + acc = evaluate(model, val_loader) + print(f" Epoch {epoch:2d}/{EPOCHS} — loss: {loss:.4f} val_acc: {acc:.4f}") + + val_acc = evaluate(model, val_loader) + print(f" Final val accuracy: {val_acc:.4f}") + return val_acc + + +# ── Ablation configurations ────────────────────────────────────────────────── + +ABLATIONS = [ + # (name, task_fn_kwargs, model_kwargs) + ( + "pulse-aligned / window=250 / depth=3", + {"use_pulse_aligned": True, "pulse_window": 250}, + {"depth": 3}, + ), + ( + "pulse-aligned / window=125 / depth=3", + {"use_pulse_aligned": True, "pulse_window": 125}, + {"depth": 3}, + ), + ( + "pulse-aligned / window=375 / depth=3", + {"use_pulse_aligned": True, "pulse_window": 375}, + {"depth": 3}, + ), + ( + "pulse-aligned / window=250 / depth=2", + {"use_pulse_aligned": True, "pulse_window": 250}, + {"depth": 2}, + ), + ( + "raw 10-sec window / depth=3", + {"use_pulse_aligned": False}, + {"depth": 3}, + ), +] + + +if __name__ == "__main__": + results = {} + for name, task_kw, model_kw in ABLATIONS: + results[name] = run_ablation(name, task_kw, model_kw) + + print("\n\n" + "=" * 60) + print("Ablation Summary") + print("=" * 60) + for name, acc in results.items(): + print(f" {name:<50s} val_acc = {acc:.4f}") diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..cd4e2fd11 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -47,26 +47,27 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset +from .bmd_hs import BMDHSDataset from .cardiology import CardiologyDataset from .chestxray14 
import ChestXray14Dataset from .clinvar import ClinVarDataset +from .collate import collate_temporal from .cosmic import COSMICDataset from .covid19_cxr import COVID19CXRDataset from .dreamt import DREAMTDataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset from .isruc import ISRUCDataset +from .ludb import LUDBDataset, get_stratified_ludb_split from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset +from .qtdb import QTDBDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset -from .bmd_hs import BMDHSDataset -from .support2 import Support2Dataset -from .tcga_prad import TCGAPRADDataset from .splitter import ( sample_balanced, split_by_patient, @@ -80,6 +81,8 @@ def __init__(self, *args, **kwargs): split_by_visit, split_by_visit_conformal, ) +from .support2 import Support2Dataset +from .tcga_prad import TCGAPRADDataset from .tuab import TUABDataset from .tuev import TUEVDataset from .utils import ( @@ -89,4 +92,3 @@ def __init__(self, *args, **kwargs): load_processors, save_processors, ) -from .collate import collate_temporal diff --git a/pyhealth/datasets/configs/ludb.yaml b/pyhealth/datasets/configs/ludb.yaml new file mode 100644 index 000000000..d870a7f22 --- /dev/null +++ b/pyhealth/datasets/configs/ludb.yaml @@ -0,0 +1,27 @@ +version: "1.0.0" +tables: + ludb: + file_path: "ludb-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "visit_id" + - "record_id" + - "signal_file" + - "fs" + - "n_samples" + - "lead_i" + - "lead_ii" + - "lead_iii" + - "lead_avr" + - "lead_avl" + - "lead_avf" + - "lead_v1" + - "lead_v2" + - "lead_v3" + - "lead_v4" + - "lead_v5" + - "lead_v6" + - "rhythm" + - "electric_axis" + - "no_p" 
diff --git a/pyhealth/datasets/configs/qtdb.yaml b/pyhealth/datasets/configs/qtdb.yaml new file mode 100644 index 000000000..c6648690b --- /dev/null +++ b/pyhealth/datasets/configs/qtdb.yaml @@ -0,0 +1,20 @@ +version: "1.0.0" + +tables: + qtdb: + file_path: "qtdb-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "visit_id" + - "record_id" + - "signal_file" + - "lead_0" + - "lead_1" + - "ann_atr" + - "ann_man" + - "ann_qt1" + - "ann_q1c" + - "ann_pu" + - "ann_pu0" + - "ann_pu1" diff --git a/pyhealth/datasets/ludb.py b/pyhealth/datasets/ludb.py new file mode 100644 index 000000000..d8012332d --- /dev/null +++ b/pyhealth/datasets/ludb.py @@ -0,0 +1,468 @@ +import os +import random +import subprocess +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd + +from .base_dataset import BaseDataset + +LEADS = ["i", "ii", "iii", "avr", "avl", "avf", "v1", "v2", "v3", "v4", "v5", "v6"] +NUM_RECORDS = 200 +DEV_NUM_RECORDS = 10 +FS = 500 +N_SAMPLES = 5000 +LUDB_PHYSIONET_URL = "https://physionet.org/files/ludb/1.0.1/" +DEFAULT_DOWNLOAD_ROOT = Path.home() / ".cache" / "pyhealth" / "datasets" + + +class LUDBDataset(BaseDataset): + """Base ECG dataset for the Lobachevsky University Database (LUDB). + + Modernized LUDB dataset built on :class:`pyhealth.datasets.BaseDataset`. + + Design: + - One table: ``ludb`` + - One row/event per LUDB record (record ID 1..200) + - ``patient_id`` is the record ID + - 12 lead annotation files are represented as 12 columns: + ``lead_i``, ``lead_ii``, ..., ``lead_v6`` + + Notes: + - Pulse-level splitting is intentionally handled in task logic + (e.g., ECG delineation task), not at base dataset level. + - Metadata CSV (``ludb-pyhealth.csv``) is generated automatically if + missing. + + Dataset: + https://physionet.org/content/ludb/1.0.1/ + + Args: + root: Local root path for LUDB. 
This can be: + - the LUDB ``data/`` directory containing ``*.hea``, ``*.dat``, + and ``*.`` files, or + - any parent directory to probe for existing LUDB files. + If ``download=True`` and local files are not found, LUDB is + downloaded to ``Path.home() / ".cache" / "pyhealth" / "datasets"``. + dataset_name: Optional dataset name. Defaults to ``"ludb"``. + config_path: Optional YAML config path. Defaults to + ``pyhealth/datasets/configs/ludb.yaml``. + dev: If True, uses only the first 10 records. + download: If True and LUDB files are not found locally, downloads LUDB + from PhysioNet using: + ``wget -r -N -c -np https://physionet.org/files/ludb/1.0.1/``. + refresh_cache: Deprecated compatibility argument from legacy signal API. + It is accepted for backward compatibility but ignored under + BaseDataset. + **kwargs: Forwarded to :class:`BaseDataset` (e.g., ``cache_dir``, + ``num_workers``). + + Examples: + >>> from pyhealth.datasets import LUDBDataset + >>> dataset = LUDBDataset(root="/path/to/physionet.org/files/ludb/1.0.1/data") + >>> dataset.stats() # BaseDataset-native dataset statistics + >>> dataset.info() # LUDB-specific summary (table/leads/paths) + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + dev: bool = False, + download: bool = False, + refresh_cache: Optional[bool] = None, + **kwargs, + ) -> None: + if refresh_cache is not None: + warnings.warn( + "`refresh_cache` is deprecated for LUDBDataset with BaseDataset and is ignored.", + DeprecationWarning, + stacklevel=2, + ) + + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "configs", "ludb.yaml" + ) + + data_root = self._resolve_data_root(root=root, download=download) + metadata_root = self._prepare_metadata(root=data_root) + + super().__init__( + root=metadata_root, + tables=["ludb"], + dataset_name=dataset_name or "ludb", + config_path=config_path, + dev=dev, + **kwargs, + ) + + # 
--------------------------------------------------------------------- + # Backward-compat convenience methods + # --------------------------------------------------------------------- + def stat(self) -> None: + """Backward-compatible alias for :meth:`stats`.""" + warnings.warn( + "`stat()` is deprecated; use `stats()` instead.", + DeprecationWarning, + stacklevel=2, + ) + self.stats() + + def info(self) -> None: + """Print LUDB summary for the modern BaseDataset/event-table layout.""" + print(f"Dataset: {self.dataset_name}") + print("Backend: BaseDataset") + print("Event table: ludb (one event per LUDB record)") + print(f"Root (metadata source): {self.root}") + print(f"Tables: {self.tables}") + print(f"Dev mode: {self.dev}") + print(f"Lead columns: {', '.join(f'lead_{lead}' for lead in LEADS)}") + print("Use `dataset.stats()` for patient/event counts.") + + def preprocess_ludb(self, df: Any) -> Any: + """Apply LUDB-specific preprocessing before event caching.""" + if not self.dev: + return df + try: + import importlib + + nw = importlib.import_module("narwhals") + except Exception: + return df + return df.filter(nw.col("patient_id").cast(nw.Int64) <= DEV_NUM_RECORDS) + + # --------------------------------------------------------------------- + # Metadata generation + # --------------------------------------------------------------------- + @staticmethod + def _candidate_data_roots(root: str) -> List[Path]: + """Return possible LUDB data directories for local or downloaded layouts.""" + root_path = Path(root).expanduser().resolve() + return [ + root_path, + root_path / "data", + root_path / "physionet.org" / "files" / "ludb" / "1.0.1" / "data", + root_path / "files" / "ludb" / "1.0.1" / "data", + ] + + @staticmethod + def _looks_like_ludb_data_dir(path: Path) -> bool: + """Check whether a directory looks like LUDB data/ (contains numeric .hea).""" + if not path.exists() or not path.is_dir(): + return False + try: + return any(p.suffix == ".hea" and p.stem.isdigit() 
for p in path.iterdir()) + except OSError: + return False + + @staticmethod + def _download_ludb(download_root: Path) -> None: + """Download LUDB from PhysioNet using wget recursive mirror command.""" + cmd = ["wget", "-r", "-N", "-c", "-np", LUDB_PHYSIONET_URL] + try: + subprocess.run(cmd, cwd=str(download_root), check=True) + except FileNotFoundError as e: + raise RuntimeError( + "wget is not installed or not in PATH. Please install wget and retry." + ) from e + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"LUDB download failed via wget command: {' '.join(cmd)}" + ) from e + + @classmethod + def _resolve_data_root(cls, root: str, download: bool) -> str: + """Resolve LUDB data directory, optionally downloading from PhysioNet.""" + for candidate in cls._candidate_data_roots(root): + if cls._looks_like_ludb_data_dir(candidate): + return str(candidate) + + if not download: + raise FileNotFoundError( + "LUDB data files not found. Provide a valid local LUDB data path " + "or set download=True to fetch from PhysioNet." + ) + + download_root = Path(DEFAULT_DOWNLOAD_ROOT).expanduser().resolve() + download_root.mkdir(parents=True, exist_ok=True) + cls._download_ludb(download_root) + + for candidate in cls._candidate_data_roots(str(download_root)): + if cls._looks_like_ludb_data_dir(candidate): + return str(candidate) + + raise FileNotFoundError( + "LUDB download completed but data directory could not be resolved." 
+ ) + + @staticmethod + def _resolve_ludb_csv(root: str) -> Optional[str]: + """Resolve path to PhysioNet ``ludb.csv`` if available.""" + root_path = Path(root) + candidates = [ + root_path / "ludb.csv", + root_path.parent / "ludb.csv", + ] + for path in candidates: + if path.exists(): + return str(path) + return None + + @staticmethod + def _discover_record_ids(root: str) -> List[int]: + """Discover record IDs from ``*.hea`` files in LUDB data directory.""" + ids: List[int] = [] + for filename in os.listdir(root): + if not filename.endswith(".hea"): + continue + stem = filename[:-4] + if stem.isdigit(): + ids.append(int(stem)) + return sorted(set(ids)) + + def _prepare_metadata(self, root: str) -> str: + """Build ``ludb-pyhealth.csv`` if it does not exist. + + Returns: + Directory path where ``ludb-pyhealth.csv`` is located. + """ + root_path = Path(root) + metadata_path = root_path / "ludb-pyhealth.csv" + if metadata_path.exists(): + return str(root_path) + + ludb_csv_path = self._resolve_ludb_csv(root) + + # Optional richer metadata from PhysioNet ludb.csv + meta_by_id: Dict[int, Dict[str, object]] = {} + if ludb_csv_path is not None: + raw_df = pd.read_csv(ludb_csv_path) + + # Derive no_p from Rhythms if available + if "Rhythms" in raw_df.columns: + afib_pattern = r"Atrial fibrillation|Atrial flutter" + raw_df["no_p"] = ( + raw_df["Rhythms"] + .astype(str) + .str.contains(afib_pattern, na=False) + .astype(int) + ) + raw_df["rhythm_norm"] = ( + raw_df["Rhythms"].astype(str).str.split("\n").str[0].str.strip() + ) + else: + raw_df["no_p"] = 0 + raw_df["rhythm_norm"] = None + + # Normalize electric axis if available + axis_col = "Electric axis of the heart" + if axis_col in raw_df.columns: + raw_df["axis_norm"] = ( + raw_df[axis_col] + .astype(str) + .str.split(":") + .str[-1] + .str.strip() + .replace({"nan": None}) + ) + else: + raw_df["axis_norm"] = None + + if "ID" in raw_df.columns: + subset = raw_df[ + ["ID", "rhythm_norm", "axis_norm", "no_p"] + 
].values.tolist() + + for rid_raw, rhythm_raw, axis_raw, no_p_raw in subset: + if pd.isna(rid_raw): + continue + + try: + rid = int(rid_raw) + except (TypeError, ValueError): + continue + + try: + no_p = 0 if pd.isna(no_p_raw) else int(no_p_raw) + except (TypeError, ValueError): + no_p = 0 + + meta_by_id[rid] = { + "rhythm": rhythm_raw, + "electric_axis": axis_raw, + "no_p": no_p, + } + + record_ids = self._discover_record_ids(root) + if not record_ids: + # Fallback when files are not visible yet + record_ids = list(range(1, NUM_RECORDS + 1)) + + rows: List[Dict[str, object]] = [] + for rid in record_ids: + pid = str(rid) + + row: Dict[str, object] = { + "patient_id": pid, + "visit_id": "ecg", + "record_id": pid, + # Absolute WFDB base record path, without extension + "signal_file": str(root_path / pid), + "fs": FS, + "n_samples": N_SAMPLES, + "rhythm": None, + "electric_axis": None, + "no_p": 0, + } + + # 12 lead columns + for lead in LEADS: + col = f"lead_{lead}" + # store absolute annotation path + row[col] = str(root_path / f"{pid}.{lead}") + + # enrich from ludb.csv if present + if rid in meta_by_id: + row.update(meta_by_id[rid]) + + rows.append(row) + + df = pd.DataFrame(rows) + df.sort_values(["patient_id"], inplace=True) + df.reset_index(drop=True, inplace=True) + + try: + df.to_csv(metadata_path, index=False) + return str(root_path) + except (PermissionError, OSError): + cache_root = Path.home() / ".cache" / "pyhealth" / "ludb" + cache_root.mkdir(parents=True, exist_ok=True) + cache_metadata_path = cache_root / "ludb-pyhealth.csv" + df.to_csv(cache_metadata_path, index=False) + return str(cache_root) + + @property + def default_task(self): + """No default task is enforced for LUDB.""" + return None + + +def get_stratified_ludb_split( + ludb_csv: str, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + seed: int = 42, +) -> Tuple[List[int], List[int], List[int]]: + """Return stratified train/val/test LUDB record-ID splits. 
def get_stratified_ludb_split(
    ludb_csv: str,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[int], List[int], List[int]]:
    """Return stratified train/val/test LUDB record-ID splits.

    Stratification key:
        ``(no_p, rhythm_norm, axis_norm)``

    Rules:
        - Groups with count >= 10 are split proportionally.
        - Smaller groups are pooled and split proportionally together.

    Args:
        ludb_csv: Path to PhysioNet ``ludb.csv``.
        train_ratio: Train fraction.
        val_ratio: Validation fraction.
        seed: Random seed.

    Returns:
        (train_ids, val_ids, test_ids): sorted record IDs (1-indexed).
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to <= 1."
        )

    rng = random.Random(seed)
    frame = pd.read_csv(ludb_csv)

    for required in ("ID", "Rhythms", "Electric axis of the heart"):
        if required not in frame.columns:
            raise ValueError(f"Expected column '{required}' in ludb.csv")

    # no_p=1 if AFib/flutter appears in rhythms
    afib_pattern = r"Atrial fibrillation|Atrial flutter"
    frame["no_p"] = frame["Rhythms"].str.contains(afib_pattern, na=False).astype(int)

    # primary rhythm: first line of the possibly multi-line field
    frame["rhythm_norm"] = frame["Rhythms"].str.split("\n").str[0].str.strip()

    def _normalize_axis(val: str) -> str:
        # "Electric axis of the heart: normal" -> "normal"
        if pd.isna(val):
            return "unknown"
        return str(val).split(":")[-1].strip()

    frame["axis_norm"] = frame["Electric axis of the heart"].apply(_normalize_axis)

    keys = list(zip(frame["no_p"], frame["rhythm_norm"], frame["axis_norm"]))
    record_ids = frame["ID"].astype(int).tolist()

    groups: Dict[tuple, List[int]] = defaultdict(list)
    for rid, key in zip(record_ids, keys):
        groups[key].append(rid)

    train_ids: List[int] = []
    val_ids: List[int] = []
    test_ids: List[int] = []

    threshold = 10  # groups smaller than this are pooled and split together
    small_pool: List[int] = []

    def _take_slices(ids: List[int]) -> None:
        # Slice an already-shuffled id list into train/val; rest is test.
        n_train = round(len(ids) * train_ratio)
        n_val = round(len(ids) * val_ratio)
        train_ids.extend(ids[:n_train])
        val_ids.extend(ids[n_train : n_train + n_val])
        test_ids.extend(ids[n_train + n_val :])

    # NOTE: every group is shuffled BEFORE the size test (matching the
    # original RNG call order exactly, so splits are reproducible).
    for ids in groups.values():
        rng.shuffle(ids)
        if len(ids) >= threshold:
            _take_slices(ids)
        else:
            small_pool.extend(ids)

    rng.shuffle(small_pool)
    _take_slices(small_pool)

    return sorted(train_ids), sorted(val_ids), sorted(test_ids)


if __name__ == "__main__":
    CACHE_ROOT = (
        Path.home()
        / ".cache"
        / "pyhealth"
        / "datasets"
        / "physionet.org"
        / "files"
        / "ludb"
        / "1.0.1"
    )
    DATA_ROOT = os.environ.get("DATA_ROOT", str(CACHE_ROOT / "data"))
    LUDB_CSV = os.environ.get("LUDB_CSV", str(CACHE_ROOT / "ludb.csv"))

    dataset = LUDBDataset(root=DATA_ROOT, dev=True, download=True)
    dataset.stats()
    dataset.info()

    if os.path.exists(LUDB_CSV):
        train_ids, val_ids, test_ids = get_stratified_ludb_split(LUDB_CSV)
        print(
            f"Stratified split: train={len(train_ids)}, val={len(val_ids)}, test={len(test_ids)}"
        )
class QTDBDataset(BaseDataset):
    """Base ECG dataset for the PhysioNet QT Database (QTDB).

    QTDB contains 105 two-lead 15-minute ECG recordings with multiple
    annotation files (reference beat annotations and waveform delineation
    outputs from manual/automatic methods).

    Dataset:
        https://physionet.org/content/qtdb/1.0.0/

    Design:
        - One table: ``qtdb``
        - One row/event per record
        - ``patient_id`` is the record ID (e.g., ``sel100``)

    Args:
        root: Local path to QTDB data directory, or any parent directory to
            probe. If ``download=True`` and data is not found locally, QTDB is
            downloaded to ``Path.home() / ".cache" / "pyhealth" / "datasets"``.
        dataset_name: Optional dataset name. Defaults to ``"qtdb"``.
        config_path: Optional YAML config path. Defaults to
            ``pyhealth/datasets/configs/qtdb.yaml``.
        dev: If True, keep only the first 10 records.
        download: If True and local data is missing, download QTDB using:
            ``wget -r -N -c -np https://physionet.org/files/qtdb/1.0.0/``.
        refresh_cache: Deprecated legacy argument; accepted for compatibility
            but ignored.
        **kwargs: Forwarded to :class:`BaseDataset` (e.g., ``cache_dir``,
            ``num_workers``).

    Examples:
        >>> from pyhealth.datasets import QTDBDataset
        >>> ds = QTDBDataset(root="~/.cache/pyhealth/datasets", download=True, dev=True)
        >>> ds.stats()
        >>> ds.info()
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        dev: bool = False,
        download: bool = False,
        refresh_cache: Optional[bool] = None,
        **kwargs,
    ) -> None:
        if refresh_cache is not None:
            warnings.warn(
                "`refresh_cache` is deprecated for QTDBDataset with BaseDataset and is ignored.",
                DeprecationWarning,
                stacklevel=2,
            )

        if config_path is None:
            # Default to the config shipped with the package.
            config_path = os.path.join(
                os.path.dirname(__file__), "configs", "qtdb.yaml"
            )

        # Locate (or download) the raw data, then materialize the event-table
        # CSV that BaseDataset consumes.
        data_root = self._resolve_data_root(root=root, download=download)
        metadata_root = self._prepare_metadata(root=data_root)

        super().__init__(
            root=metadata_root,
            tables=["qtdb"],
            dataset_name=dataset_name or "qtdb",
            config_path=config_path,
            dev=dev,
            **kwargs,
        )

    # ---------------------------------------------------------------------
    # Backward-compat convenience methods
    # ---------------------------------------------------------------------
    def stat(self) -> None:
        """Backward-compatible alias for :meth:`stats`."""
        warnings.warn(
            "`stat()` is deprecated; use `stats()` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.stats()

    def info(self) -> None:
        """Print QTDB summary for the BaseDataset/event-table layout."""
        print(f"Dataset: {self.dataset_name}")
        print("Backend: BaseDataset")
        print("Event table: qtdb (one event per QTDB record)")
        print(f"Root (metadata source): {self.root}")
        print(f"Tables: {self.tables}")
        print(f"Dev mode: {self.dev}")
        print("Lead columns: lead_0, lead_1")
        print("Use `dataset.stats()` for patient/event counts.")

    def preprocess_qtdb(self, df: Any) -> Any:
        """Apply QTDB-specific preprocessing before event caching.

        In dev mode, keeps only the first DEV_NUM_RECORDS rows after a
        deterministic lexical sort by patient_id.

        NOTE(review): uses the polars-style ``sort``/``head`` API that the
        BaseDataset pipeline appears to supply — confirm the frame type.
        """
        if not self.dev:
            return df
        return df.sort("patient_id").head(DEV_NUM_RECORDS)

    # ---------------------------------------------------------------------
    # Data root + download helpers
    # ---------------------------------------------------------------------
    @staticmethod
    def _candidate_data_roots(root: str) -> List[Path]:
        """Return possible QTDB data directories for local/downloaded layouts."""
        root_path = Path(root).expanduser().resolve()
        return [
            root_path,
            root_path / "qtdb" / "1.0.0",
            root_path / "physionet.org" / "files" / "qtdb" / "1.0.0",
            root_path / "files" / "qtdb" / "1.0.0",
        ]

    @staticmethod
    def _looks_like_qtdb_data_dir(path: Path) -> bool:
        """Heuristically check whether a directory looks like a QTDB data dir."""
        if not path.exists() or not path.is_dir():
            return False

        if (path / "RECORDS").exists():
            return True

        try:
            # Any WFDB header file is sufficient evidence. (A name ending in
            # ".hea-" cannot also end in ".hea", so no extra exclusion is
            # needed here.)
            return any(
                p.is_file() and p.name.endswith(".hea") for p in path.iterdir()
            )
        except OSError:
            return False

    @staticmethod
    def _download_qtdb(download_root: Path) -> None:
        """Download QTDB from PhysioNet using a wget recursive mirror command.

        Raises:
            RuntimeError: If wget is unavailable or the download fails.
        """
        cmd = ["wget", "-r", "-N", "-c", "-np", QTDB_PHYSIONET_URL]
        try:
            subprocess.run(cmd, cwd=str(download_root), check=True)
        except FileNotFoundError as e:
            raise RuntimeError(
                "wget is not installed or not in PATH. Please install wget and retry."
            ) from e
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                f"QTDB download failed via wget command: {' '.join(cmd)}"
            ) from e

    @classmethod
    def _resolve_data_root(cls, root: str, download: bool) -> str:
        """Resolve the QTDB data directory, optionally downloading from PhysioNet.

        Raises:
            FileNotFoundError: If data cannot be found locally and
                ``download`` is False, or if the download did not produce a
                recognizable data directory.
        """
        for candidate in cls._candidate_data_roots(root):
            if cls._looks_like_qtdb_data_dir(candidate):
                return str(candidate)

        if not download:
            raise FileNotFoundError(
                "QTDB data files not found. Provide a valid local QTDB path "
                "or set download=True to fetch from PhysioNet."
            )

        download_root = DEFAULT_DOWNLOAD_ROOT.resolve()
        download_root.mkdir(parents=True, exist_ok=True)
        cls._download_qtdb(download_root)

        # Re-probe: wget mirrors into a "physionet.org/files/..." subtree.
        for candidate in cls._candidate_data_roots(str(download_root)):
            if cls._looks_like_qtdb_data_dir(candidate):
                return str(candidate)

        raise FileNotFoundError(
            "QTDB download completed but data directory could not be resolved."
        )

    # ---------------------------------------------------------------------
    # Metadata helpers
    # ---------------------------------------------------------------------
    @staticmethod
    def _read_records_list(root_path: Path) -> List[str]:
        """Read record IDs from the RECORDS file, falling back to .hea discovery."""
        records_file = root_path / "RECORDS"
        if records_file.exists():
            records = [
                line.strip()
                for line in records_file.read_text(encoding="utf-8").splitlines()
                if line.strip()
            ]
            if records:
                return records

        # Fallback: derive record IDs from header files. glob("*.hea") already
        # guarantees the suffix, so no further name filtering is required.
        return sorted({f.stem for f in root_path.glob("*.hea")})

    @staticmethod
    def _ann_path_or_none(base: Path, ext: str) -> Optional[str]:
        """Return ``<base>.<ext>`` if that annotation file exists, else None.

        The path is built by string concatenation rather than
        ``Path.with_suffix``: ``with_suffix`` *replaces* an existing dotted
        segment, which would corrupt record names that contain a dot.
        """
        ann = base.parent / f"{base.name}.{ext}"
        return str(ann) if ann.exists() else None

    def _prepare_metadata(self, root: str) -> str:
        """Build ``qtdb-pyhealth.csv`` if missing and return the metadata root.

        If the data directory is read-only, the CSV is written to a per-user
        cache directory instead and that directory is returned.
        """
        root_path = Path(root).expanduser().resolve()
        metadata_path = root_path / "qtdb-pyhealth.csv"
        if metadata_path.exists():
            return str(root_path)

        records = self._read_records_list(root_path)
        if not records:
            # Conservative fallback if directory probing happens before the
            # actual files exist.
            records = [f"record_{i:03d}" for i in range(NUM_RECORDS)]

        rows: List[Dict[str, object]] = []
        for rec in records:
            base = root_path / rec
            rows.append(
                {
                    "patient_id": rec,
                    "visit_id": "ecg",
                    "record_id": rec,
                    # Absolute WFDB base record path (without extension).
                    "signal_file": str(base),
                    # QTDB has two leads per record.
                    "lead_0": "0",
                    "lead_1": "1",
                    # Commonly used QTDB annotation files (None if absent).
                    "ann_atr": self._ann_path_or_none(base, "atr"),
                    "ann_man": self._ann_path_or_none(base, "man"),
                    "ann_qt1": self._ann_path_or_none(base, "qt1"),
                    "ann_q1c": self._ann_path_or_none(base, "q1c"),
                    "ann_pu": self._ann_path_or_none(base, "pu"),
                    "ann_pu0": self._ann_path_or_none(base, "pu0"),
                    "ann_pu1": self._ann_path_or_none(base, "pu1"),
                }
            )

        df = pd.DataFrame(rows)
        df.sort_values(["patient_id"], inplace=True)
        df.reset_index(drop=True, inplace=True)

        try:
            df.to_csv(metadata_path, index=False)
            return str(root_path)
        except (PermissionError, OSError):
            # Data directory is not writable; fall back to a user cache.
            cache_root = Path.home() / ".cache" / "pyhealth" / "qtdb"
            cache_root.mkdir(parents=True, exist_ok=True)
            cache_metadata_path = cache_root / "qtdb-pyhealth.csv"
            df.to_csv(cache_metadata_path, index=False)
            return str(cache_root)

    @property
    def default_task(self):
        """No default task is enforced for QTDB."""
        return None
class ConvBNAct1d(nn.Module):
    """Conv1d + BatchNorm1d + activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # "Same" padding for odd kernel sizes.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        # ReLU6 is the MobileNet default nonlinearity.
        self.act = nn.ReLU6(inplace=True) if activation is None else activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)


class InvertedResidual1d(nn.Module):
    """MobileNetV2-style inverted residual block for 1D signals."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: int,
    ) -> None:
        super().__init__()
        if stride not in (1, 2):
            raise ValueError(f"stride must be 1 or 2, got {stride}")

        hidden_dim = int(round(in_channels * expand_ratio))
        # Residual shortcut is valid only when the shape is preserved.
        self.use_res_connect = stride == 1 and in_channels == out_channels

        stages = []
        if expand_ratio != 1:
            # Pointwise expansion.
            stages.append(
                ConvBNAct1d(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    stride=1,
                )
            )
        # Depthwise convolution (one group per channel).
        stages.append(
            ConvBNAct1d(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                groups=hidden_dim,
            )
        )
        # Linear pointwise projection: no nonlinearity on the output.
        stages.append(
            ConvBNAct1d(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                activation=nn.Identity(),
            )
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.block(x)
        return x + out if self.use_res_connect else out
class ECGCODE(BaseModel):
    """ECG-CODE-like delineation model with a MobileNet-style 1D backbone.

    For each fixed-size interval and each ECG object (P/QRS/T) the model
    predicts three values:
        1) confidence (is the wave present in the interval)
        2) normalized start position within the interval
        3) normalized end position within the interval

    Output shape (after sigmoid):
        [batch, n_intervals, 3_waves, 3_values(conf,start,end)]

    Custom interval loss (inspired by arXiv:2406.02711):
        CL  = 0 if |pc - tc| < conf_tolerance else (pc - tc)^2
        SS  = (ps - ts)^2 + (pe - te)^2
        SEL = 0 if SS < se_tolerance else SS * tc
        total = mean(CL + SEL)
    """

    # Mask encoding of the three wave classes: P, QRS, T.
    WAVE_LABELS = (1, 2, 3)

    def __init__(
        self,
        dataset: SampleDataset,
        signal_key: str = "signal",
        mask_key: str = "mask",
        width_mult: float = 1.0,
        interval_size: int = 16,
        conf_tolerance: float = 0.25,
        se_tolerance: float = 0.15,
    ) -> None:
        super().__init__(dataset=dataset)

        if interval_size <= 0:
            raise ValueError("interval_size must be positive.")

        self.signal_key = signal_key
        self.mask_key = mask_key
        self.interval_size = int(interval_size)
        self.conf_tolerance = float(conf_tolerance)
        self.se_tolerance = float(se_tolerance)

        # Stem: wide kernel with stride-2 downsampling.
        stem_channels = self._make_divisible(32 * width_mult)
        self.stem = ConvBNAct1d(1, stem_channels, kernel_size=7, stride=2)

        # MobileNetV2-like stages: (expand_ratio, channels, repeats, stride).
        stage_cfg = [
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 3, 2),
            (6, 96, 2, 1),
            (6, 160, 2, 2),
        ]

        stages = []
        channels = stem_channels
        for expand, width, repeats, first_stride in stage_cfg:
            scaled = self._make_divisible(width * width_mult)
            for rep in range(repeats):
                stages.append(
                    InvertedResidual1d(
                        in_channels=channels,
                        out_channels=scaled,
                        # Only the first block of a stage downsamples.
                        stride=first_stride if rep == 0 else 1,
                        expand_ratio=expand,
                    )
                )
                channels = scaled
        self.backbone = nn.Sequential(*stages)

        proj_ch = self._make_divisible(128 * width_mult)
        self.proj = ConvBNAct1d(channels, proj_ch, kernel_size=1, stride=1)

        # Head emits 3 waves x 3 values (conf, start, end) per position.
        self.head = nn.Conv1d(proj_ch, 9, kernel_size=1, bias=True)

    @staticmethod
    def _make_divisible(v: float, divisor: int = 8) -> int:
        """Round channel counts up to the nearest multiple of *divisor*."""
        return int(math.ceil(v / divisor) * divisor)

    @staticmethod
    def _extract_tensor(value) -> torch.Tensor:
        """Return *value* if it is a tensor, else the first tensor in a tuple."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, tuple):
            for item in value:
                if isinstance(item, torch.Tensor):
                    return item
        raise ValueError("Expected a tensor or tuple containing a tensor.")

    def _normalize_signal_shape(self, x: torch.Tensor) -> torch.Tensor:
        """Coerce the input to [B, C, T]; C is typically 1 for ECG."""
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)  # [T] -> [1, 1, T]
        elif x.dim() == 2:
            x = x.unsqueeze(1)  # [B, T] -> [B, 1, T]
        elif x.dim() == 3:
            # Heuristic: a small trailing dim is treated as channels ([B,T,C]).
            if x.shape[-1] <= 4 and x.shape[1] > x.shape[-1]:
                x = x.transpose(1, 2)
        else:
            raise ValueError(f"Unsupported signal shape: {tuple(x.shape)}")
        return x.float()

    def _build_interval_targets(
        self,
        masks: torch.Tensor,
        n_intervals: int,
    ) -> torch.Tensor:
        """Build interval targets of shape [B, n_intervals, 3_waves, 3_values]."""
        if masks.dim() == 3 and masks.shape[1] == 1:
            masks = masks[:, 0, :]
        if masks.dim() != 2:
            raise ValueError(
                f"Expected mask shape [B,T] or [B,1,T], got {tuple(masks.shape)}"
            )

        batch_size, seq_len = masks.shape
        targets = torch.zeros(
            (batch_size, n_intervals, 3, 3),
            device=masks.device,
            dtype=torch.float32,
        )

        for batch_idx in range(batch_size):
            sample_mask = masks[batch_idx].long()
            for interval_idx in range(n_intervals):
                lo = interval_idx * self.interval_size
                hi = min(lo + self.interval_size, seq_len)
                if lo >= hi:
                    continue

                segment = sample_mask[lo:hi]
                # Normalize positions by (segment length - 1) so the last
                # sample maps to exactly 1.0.
                denom = float(max(int(segment.numel()) - 1, 1))

                for wave_idx, wave_label in enumerate(self.WAVE_LABELS):
                    hits = torch.where(segment == wave_label)[0]
                    if hits.numel() == 0:
                        continue
                    targets[batch_idx, interval_idx, wave_idx, 0] = 1.0
                    targets[batch_idx, interval_idx, wave_idx, 1] = (
                        float(hits.min().item()) / denom
                    )
                    targets[batch_idx, interval_idx, wave_idx, 2] = (
                        float(hits.max().item()) / denom
                    )

        return targets

    def _interval_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute (total, CL mean, SEL mean); inputs are [B,N,3,3(conf,start,end)]."""
        pc, ps, pe = pred[..., 0], pred[..., 1], pred[..., 2]
        tc, ts, te = target[..., 0], target[..., 1], target[..., 2]

        # Confidence loss with a dead zone of width conf_tolerance.
        cl = torch.where(
            torch.abs(pc - tc) < self.conf_tolerance,
            torch.zeros_like(pc),
            (pc - tc) ** 2,
        )

        # Start/end loss, gated by the target confidence so that empty
        # intervals contribute nothing beyond the tolerance zero.
        ss = (ps - ts) ** 2 + (pe - te) ** 2
        sel = torch.where(ss < self.se_tolerance, torch.zeros_like(ss), ss * tc)

        total = (cl + sel).mean()
        return total, cl.mean(), sel.mean()

    def _predict_intervals(
        self, signal: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (raw_logits, sigmoid(raw_logits)), each shaped [B, N, 3, 3]."""
        x = self._normalize_signal_shape(signal)
        batch_size, _, seq_len = x.shape
        n_intervals = max(1, math.ceil(seq_len / self.interval_size))

        feat = self.proj(self.backbone(self.stem(x)))
        # Pool to exactly one feature vector per interval.
        feat = F.adaptive_avg_pool1d(feat, n_intervals)
        logits = self.head(feat)  # [B, 9, N]

        raw_logits = logits.permute(0, 2, 1).reshape(batch_size, n_intervals, 3, 3)
        return raw_logits, torch.sigmoid(raw_logits)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward pass; computes the interval loss when a mask is provided."""
        signal = self._extract_tensor(kwargs[self.signal_key]).to(self.device)
        raw_logits, pred = self._predict_intervals(signal)

        if self.mask_key in kwargs:
            mask = self._extract_tensor(kwargs[self.mask_key]).to(self.device)
            targets = self._build_interval_targets(mask, n_intervals=pred.shape[1])
            loss, cl_mean, sel_mean = self._interval_loss(pred, targets)
        else:
            # Inference-only path: no labels, so report zero losses.
            targets = torch.zeros_like(pred, device=self.device)
            loss = torch.tensor(0.0, device=self.device)
            cl_mean = torch.tensor(0.0, device=self.device)
            sel_mean = torch.tensor(0.0, device=self.device)

        return {
            "loss": loss,
            "y_prob": pred,  # interval-level predictions
            "y_true": targets,  # interval-level targets
            "logit": raw_logits,
            "cl_loss": cl_mean,
            "sel_loss": sel_mean,
        }

    def forward_from_embedding(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Compatibility shim for interpretability utilities.

        This model consumes dense signal tensors directly, so this simply
        delegates to :meth:`forward`.
        """
        return self.forward(**kwargs)
b/pyhealth/tasks/ecg_delineation.py new file mode 100644 index 000000000..ca28e3af5 --- /dev/null +++ b/pyhealth/tasks/ecg_delineation.py @@ -0,0 +1,405 @@ +"""ECG wave delineation tasks for LUDB and QTDB. + +This module provides task classes compatible with the BaseDataset + Patient/Event +pipeline for ECG waveform delineation. + +Label encoding +-------------- +0: background +1: P wave +2: QRS complex +3: T wave +""" + +from __future__ import annotations + +import importlib +from pathlib import Path +from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union + +import numpy as np + +from .base_task import BaseTask + +try: + wfdb_module = importlib.import_module("wfdb") +except Exception: # pragma: no cover + wfdb_module = None + + +LUDB_LEADS: List[str] = [ + "i", + "ii", + "iii", + "avr", + "avl", + "avf", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", +] + +QTDB_LEADS: List[str] = ["0", "1"] + +DEFAULT_PULSE_WINDOW: int = 250 + + +def _safe_str(value: Any) -> Optional[str]: + """Convert value to stripped string and normalize missing-like tokens.""" + if value is None: + return None + text = str(value).strip() + if text == "" or text.lower() in {"none", "nan"}: + return None + return text + + +def _extract_extension(value: Any) -> Optional[str]: + """Extract WFDB annotation extension from a path or extension-like value.""" + text = _safe_str(value) + if text is None: + return None + + # Path-like input: take suffix as extension + suffix = Path(text).suffix + if suffix: + return suffix.lstrip(".") + + # Extension-like input (e.g., "i", "pu0") + return text + + +def _parse_annotations(record_base_path: str, extension: str) -> Dict[str, List[dict]]: + """Read WFDB delineation annotations for one record/extension pair. 
+ + Parser expects triplets: + "(" -> onset + wave symbol -> peak + ")" -> offset + + Supported wave peak symbols: + - P-wave: p / P + - QRS: N / n + - T-wave: t / T + """ + try: + if wfdb_module is None: + return {"P": [], "QRS": [], "T": []} + ann = wfdb_module.rdann(record_base_path, extension) + except Exception: + return {"P": [], "QRS": [], "T": []} + + symbols = ann.symbol + samples = ann.sample + waves: Dict[str, List[dict]] = {"P": [], "QRS": [], "T": []} + wave_map = { + "p": "P", + "P": "P", + "N": "QRS", + "n": "QRS", + "t": "T", + "T": "T", + } + + i = 0 + while i < len(symbols): + if symbols[i] == "(" and i + 2 < len(symbols) and symbols[i + 2] == ")": + peak_sym = symbols[i + 1] + if peak_sym in wave_map: + waves[wave_map[peak_sym]].append( + { + "onset": int(samples[i]), + "peak": int(samples[i + 1]), + "offset": int(samples[i + 2]), + } + ) + i += 3 + else: + i += 1 + + return waves + + +def _build_segmentation_mask( + signal_length: int, waves: Dict[str, List[dict]] +) -> np.ndarray: + """Build point-wise segmentation mask with values in {0, 1, 2, 3}.""" + mask = np.zeros(signal_length, dtype=np.int64) + label_map = {"P": 1, "QRS": 2, "T": 3} + + for wave_name, label in label_map.items(): + for w in waves[wave_name]: + onset = max(0, int(w["onset"])) + offset = min(signal_length - 1, int(w["offset"])) + if onset <= offset: + mask[onset : offset + 1] = label + + return mask + + +def _resolve_lead_index(wfdb_record: Any, lead: str, fallback_idx: int) -> int: + """Resolve lead index robustly from WFDB signal names.""" + sig_names = [str(name).strip().lower() for name in wfdb_record.sig_name] + lead_norm = str(lead).strip().lower() + if lead_norm in sig_names: + return sig_names.index(lead_norm) + return fallback_idx + + +def _has_any_wave(waves: Dict[str, List[dict]]) -> bool: + """Return whether parsed waves contain at least one delineated segment.""" + return any(len(waves[k]) > 0 for k in ("P", "QRS", "T")) + + +class ECGDelineationTask(BaseTask): 
+ """Generic ECG delineation task for event-table ECG datasets. + + This task expects one event row per patient that includes: + - `signal_file`: absolute WFDB base record path (without extension) + - lead-related metadata used to resolve annotation extensions + """ + + task_name: str = "ecg_delineation" + input_schema: Dict[str, Union[str, Type]] = {"signal": "tensor"} + output_schema: Dict[str, Union[str, Type]] = { + "mask": "tensor", + "label": "multiclass", + } + + def __init__( + self, + event_type: str, + leads: Sequence[str], + lead_field_map: Optional[Mapping[str, str]] = None, + annotation_field_map: Optional[Mapping[str, Sequence[str]]] = None, + annotation_extension_map: Optional[Mapping[str, Sequence[str]]] = None, + split_by_pulse: bool = False, + pulse_window: int = DEFAULT_PULSE_WINDOW, + filter_incomplete_pulses: bool = False, + ) -> None: + super().__init__() + if pulse_window <= 0: + raise ValueError("pulse_window must be a positive integer.") + self.event_type = event_type + self.leads = list(leads) + self.lead_field_map = dict(lead_field_map or {}) + self.annotation_field_map = { + k: list(v) for k, v in (annotation_field_map or {}).items() + } + self.annotation_extension_map = { + k: list(v) for k, v in (annotation_extension_map or {}).items() + } + self.split_by_pulse = bool(split_by_pulse) + self.pulse_window = int(pulse_window) + self.filter_incomplete_pulses = bool(filter_incomplete_pulses) + + @staticmethod + def _is_complete_pulse_annotation(pulse_mask: np.ndarray) -> bool: + """Return True if pulse window contains P, QRS, and T labels.""" + if pulse_mask.size == 0: + return False + labels = set(np.unique(pulse_mask).tolist()) + return {1, 2, 3}.issubset(labels) + + def _candidate_extensions(self, event: Any, lead: str) -> List[str]: + """Build candidate annotation extensions for a given lead.""" + candidates: List[str] = [] + + # 1) Explicit annotation field list for this lead + for field in self.annotation_field_map.get(lead, []): + 
ext = _extract_extension(getattr(event, field, None)) + if ext: + candidates.append(ext) + + # 2) Explicit extension preference list for this lead + candidates.extend(self.annotation_extension_map.get(lead, [])) + + # 3) Lead field (can be extension or path) + lead_field = self.lead_field_map.get(lead) + if lead_field is not None: + ext = _extract_extension(getattr(event, lead_field, None)) + if ext: + candidates.append(ext) + + # 4) Fallback: use lead itself as extension + candidates.append(lead) + + # Deduplicate while preserving order + seen = set() + unique: List[str] = [] + for ext in candidates: + if ext not in seen: + seen.add(ext) + unique.append(ext) + return unique + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + events = patient.get_events(self.event_type) + if len(events) == 0: + return [] + + event = events[0] + pid = str(patient.patient_id) + visit_id = getattr(event, "visit_id", "ecg") + record_base_path = _safe_str(getattr(event, "signal_file", None)) + if record_base_path is None: + return [] + + try: + if wfdb_module is None: + return [] + wfdb_record = wfdb_module.rdrecord(record_base_path) + except Exception: + return [] + + samples: List[Dict[str, Any]] = [] + + for lead_fallback_idx, lead in enumerate(self.leads): + lead_idx = _resolve_lead_index(wfdb_record, lead, lead_fallback_idx) + if lead_idx >= wfdb_record.p_signal.shape[1]: + continue + + signal = wfdb_record.p_signal[:, lead_idx].astype(np.float32) + + # choose first candidate extension that yields at least one wave + waves = {"P": [], "QRS": [], "T": []} + for ext in self._candidate_extensions(event, lead): + parsed = _parse_annotations(record_base_path, ext) + if _has_any_wave(parsed): + waves = parsed + break + + # Skip if no usable annotation exists + if not _has_any_wave(waves): + continue + + mask = _build_segmentation_mask(len(signal), waves) + + if self.split_by_pulse: + for pulse_idx, qrs in enumerate(waves["QRS"]): + r_peak = int(qrs["peak"]) + start = r_peak - 
self.pulse_window + end = r_peak + self.pulse_window + if start < 0 or end > len(signal): + continue + + pulse_signal = signal[start:end] + pulse_mask = mask[start:end] + if pulse_signal.size == 0: + continue + + if ( + self.filter_incomplete_pulses + and not self._is_complete_pulse_annotation(pulse_mask) + ): + continue + + samples.append( + { + "patient_id": pid, + "visit_id": visit_id, + "record_id": f"{pid}_{lead}_{pulse_idx}", + "lead": lead, + "signal": pulse_signal[np.newaxis, :], + "mask": pulse_mask, + "label": int(np.bincount(pulse_mask).argmax()), + } + ) + else: + samples.append( + { + "patient_id": pid, + "visit_id": visit_id, + "record_id": f"{pid}_{lead}", + "lead": lead, + "signal": signal[np.newaxis, :], + "mask": mask, + "label": int(np.bincount(mask).argmax()), + } + ) + + return samples + + +class ECGDelineationLUDB(ECGDelineationTask): + """LUDB-specific ECG delineation task.""" + + task_name: str = "ecg_delineation_ludb" + + def __init__( + self, + split_by_pulse: bool = False, + pulse_window: int = DEFAULT_PULSE_WINDOW, + filter_incomplete_pulses: bool = False, + ) -> None: + super().__init__( + event_type="ludb", + leads=LUDB_LEADS, + lead_field_map={lead: f"lead_{lead}" for lead in LUDB_LEADS}, + split_by_pulse=split_by_pulse, + pulse_window=pulse_window, + filter_incomplete_pulses=filter_incomplete_pulses, + ) + + +class ECGDelineationQTDB(ECGDelineationTask): + """QTDB-specific ECG delineation task (two leads).""" + + task_name: str = "ecg_delineation_qtdb" + + def __init__( + self, + split_by_pulse: bool = False, + pulse_window: int = DEFAULT_PULSE_WINDOW, + filter_incomplete_pulses: bool = False, + ) -> None: + super().__init__( + event_type="qtdb", + leads=QTDB_LEADS, + lead_field_map={"0": "lead_0", "1": "lead_1"}, + # Prefer QTDB lead-specific automatic delineations first. 
+ annotation_field_map={ + "0": ["ann_pu0", "ann_q1c", "ann_qt1", "ann_man", "ann_atr"], + "1": ["ann_pu1", "ann_q1c", "ann_qt1", "ann_man", "ann_atr"], + }, + annotation_extension_map={ + "0": ["pu0", "q1c", "qt1", "man", "atr"], + "1": ["pu1", "q1c", "qt1", "man", "atr"], + }, + split_by_pulse=split_by_pulse, + pulse_window=pulse_window, + filter_incomplete_pulses=filter_incomplete_pulses, + ) + + +def get_ecg_delineation_ludb_task( + split_by_pulse: bool = False, + pulse_window: int = DEFAULT_PULSE_WINDOW, + filter_incomplete_pulses: bool = False, +) -> ECGDelineationLUDB: + """Factory helper for configurable LUDB delineation task.""" + return ECGDelineationLUDB( + split_by_pulse=split_by_pulse, + pulse_window=pulse_window, + filter_incomplete_pulses=filter_incomplete_pulses, + ) + + +def get_ecg_delineation_qtdb_task( + split_by_pulse: bool = False, + pulse_window: int = DEFAULT_PULSE_WINDOW, + filter_incomplete_pulses: bool = False, +) -> ECGDelineationQTDB: + """Factory helper for configurable QTDB delineation task.""" + return ECGDelineationQTDB( + split_by_pulse=split_by_pulse, + pulse_window=pulse_window, + filter_incomplete_pulses=filter_incomplete_pulses, + ) + + +# Backward-compatible symbol kept for existing imports/usages. 
+ecg_delineation_ludb_fn = ECGDelineationLUDB(split_by_pulse=False) diff --git a/tests/test_ecg_code.py b/tests/test_ecg_code.py new file mode 100644 index 000000000..3d084dda2 --- /dev/null +++ b/tests/test_ecg_code.py @@ -0,0 +1,166 @@ +"""Fast focused unit tests for ECGCODE model forward pass and gradients.""" + +from __future__ import annotations + +from typing import Any, cast + +import pytest +import torch + +from pyhealth.models.ecg_code import ECGCODE + + +class _DummySampleDataset: + """Minimal dataset shim compatible with BaseModel.""" + + input_schema = {"signal": "tensor"} + output_schema = {"mask": "tensor"} + + +def _make_model(interval_size: int = 16, width_mult: float = 0.25) -> ECGCODE: + # width_mult=0.25 keeps tests very fast while preserving behavior. + return ECGCODE( + dataset=cast(Any, _DummySampleDataset()), + interval_size=interval_size, + width_mult=width_mult, + ) + + +def test_instantiation_and_invalid_interval() -> None: + model = _make_model(interval_size=16) + assert model.interval_size == 16 + + with pytest.raises(ValueError, match="interval_size must be positive"): + _make_model(interval_size=0) + + +def test_extract_tensor_supports_tensor_and_tuple() -> None: + model = _make_model() + x = torch.randn(2, 64) + + assert model._extract_tensor(x) is x + assert model._extract_tensor((None, x, "ignored")) is x + + with pytest.raises( + ValueError, match="Expected a tensor or tuple containing a tensor" + ): + model._extract_tensor(("a", "b")) + + +def test_normalize_signal_shape_variants() -> None: + model = _make_model() + + x1 = torch.randn(64) # [T] + y1 = model._normalize_signal_shape(x1) + assert y1.shape == (1, 1, 64) + + x2 = torch.randn(2, 64) # [B, T] + y2 = model._normalize_signal_shape(x2) + assert y2.shape == (2, 1, 64) + + x3 = torch.randn(2, 64, 1) # [B, T, C] -> transpose to [B, C, T] + y3 = model._normalize_signal_shape(x3) + assert y3.shape == (2, 1, 64) + + with pytest.raises(ValueError, match="Unsupported signal 
shape"): + model._normalize_signal_shape(torch.randn(2, 3, 4, 5)) + + +def test_build_interval_targets_shape_and_presence_flags() -> None: + model = _make_model(interval_size=16) + + # B=1, T=32 => N=2 intervals. + # interval 0 has P in [1..3], QRS in [5..7], T absent + # interval 1 has T in [20..24], others absent + mask = (torch.randn(1, 32).abs() * 0).long() + mask[0, 1:4] = 1 + mask[0, 5:8] = 2 + mask[0, 20:25] = 3 + + target = model._build_interval_targets(mask, n_intervals=2) + assert target.shape == (1, 2, 3, 3) + + # confidence slots + conf = target[0, :, :, 0] + # interval 0: P/QRS present, T absent + assert conf[0, 0].item() == 1.0 + assert conf[0, 1].item() == 1.0 + assert conf[0, 2].item() == 0.0 + # interval 1: T present only + assert conf[1, 0].item() == 0.0 + assert conf[1, 1].item() == 0.0 + assert conf[1, 2].item() == 1.0 + + +def test_forward_with_mask_output_shapes_and_range() -> None: + torch.manual_seed(7) + model = _make_model(interval_size=16) + model.eval() + + signal = torch.randn(2, 1, 64) + mask = (torch.randn(2, 64).abs() * 4).long() + + out = model(signal=signal, mask=mask) + + assert {"loss", "y_prob", "y_true", "logit", "cl_loss", "sel_loss"} <= set( + out.keys() + ) + assert out["loss"].ndim == 0 + assert out["cl_loss"].ndim == 0 + assert out["sel_loss"].ndim == 0 + + # T=64, interval_size=16 => N=4 + assert out["logit"].shape == (2, 4, 3, 3) + assert out["y_prob"].shape == (2, 4, 3, 3) + assert out["y_true"].shape == (2, 4, 3, 3) + + # sigmoid output in [0,1] + assert out["y_prob"].min().item() >= 0.0 + assert out["y_prob"].max().item() <= 1.0 + + +def test_forward_without_mask_returns_zero_losses_and_targets() -> None: + model = _make_model(interval_size=16) + signal = torch.randn(2, 64) # [B, T] also supported + + out = model(signal=signal) + + assert out["y_prob"].shape == (2, 4, 3, 3) + assert out["y_true"].abs().sum().item() == pytest.approx(0.0) + assert out["loss"].item() == pytest.approx(0.0) + assert out["cl_loss"].item() 
== pytest.approx(0.0) + assert out["sel_loss"].item() == pytest.approx(0.0) + + +def test_backward_computes_finite_gradients() -> None: + torch.manual_seed(11) + model = _make_model(interval_size=16) + model.train() + + signal = torch.randn(2, 1, 64) + mask = (torch.randn(2, 64).abs() * 4).long() + + out = model(signal=signal, mask=mask) + out["loss"].backward() + + grads = [p.grad for p in model.parameters() if p.requires_grad] + assert any(g is not None for g in grads), ( + "Expected at least one parameter to receive gradient" + ) + assert all( + (((g == g) & (g.abs() < float("inf"))).all().item()) + for g in grads + if g is not None + ) + + +def test_forward_from_embedding_is_alias_of_forward() -> None: + model = _make_model(interval_size=16) + signal = torch.randn(2, 1, 64) + mask = (torch.randn(2, 64).abs() * 4).long() + + out1 = model(signal=signal, mask=mask) + out2 = model.forward_from_embedding(signal=signal, mask=mask) + + assert out1["y_prob"].shape == out2["y_prob"].shape + assert out1["logit"].shape == out2["logit"].shape diff --git a/tests/test_ecg_delineation.py b/tests/test_ecg_delineation.py new file mode 100644 index 000000000..dde34249c --- /dev/null +++ b/tests/test_ecg_delineation.py @@ -0,0 +1,260 @@ +"""Focused fast tests for ECG delineation tasks on synthetic WFDB records.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +import wfdb + +from pyhealth.tasks.ecg_delineation import ( + ECGDelineationLUDB, + ECGDelineationQTDB, + ECGDelineationTask, + _build_segmentation_mask, + _extract_extension, + _has_any_wave, + _parse_annotations, + _resolve_lead_index, + _safe_str, +) + +FS = 250 +SIGNAL_LEN = 1000 + + +def _write_record(base_dir: Path, record_name: str, sig_names: list[str]) -> Path: + t = np.linspace(0, 4 * np.pi, SIGNAL_LEN, dtype=np.float64) + sig = np.stack([0.01 * np.sin(t + i) for i in range(len(sig_names))], axis=1) 
+ wfdb.wrsamp( + record_name=record_name, + fs=FS, + units=["mV"] * len(sig_names), + sig_name=sig_names, + p_signal=sig, + write_dir=str(base_dir), + ) + return base_dir / record_name + + +def _write_triplet_annotation(base_dir: Path, record_name: str, ext: str) -> None: + samples = np.array([90, 100, 110, 240, 250, 260, 390, 400, 410], dtype=np.int32) + symbols = ["(", "p", ")", "(", "N", ")", "(", "t", ")"] + wfdb.wrann( + record_name=record_name, + extension=ext, + sample=samples, + symbol=symbols, + write_dir=str(base_dir), + ) + + +def _write_qrs_only_annotation(base_dir: Path, record_name: str, ext: str) -> None: + samples = np.array([240, 250, 260], dtype=np.int32) + symbols = ["(", "N", ")"] + wfdb.wrann( + record_name=record_name, + extension=ext, + sample=samples, + symbol=symbols, + write_dir=str(base_dir), + ) + + +@dataclass +class _Event: + visit_id: str + signal_file: str + lead_i: str | None = None + lead_ii: str | None = None + lead_0: str | None = None + lead_1: str | None = None + ann_pu0: str | None = None + ann_pu1: str | None = None + ann_q1c: str | None = None + ann_qt1: str | None = None + ann_man: str | None = None + ann_atr: str | None = None + + +class _Patient: + def __init__(self, patient_id: str, event: _Event): + self.patient_id = patient_id + self._event = event + + def get_events(self, event_type: str | None = None, *args: Any, **kwargs: Any): + if event_type is not None: + assert event_type in {"ludb", "qtdb", "toy"} + return [self._event] + + +@pytest.fixture() +def ludb_record(tmp_path: Path): + base = _write_record(tmp_path, "1", ["i", "ii"]) + _write_triplet_annotation(tmp_path, "1", "i") + _write_triplet_annotation(tmp_path, "1", "ii") + return tmp_path, base + + +@pytest.fixture() +def qtdb_record(tmp_path: Path): + base = _write_record(tmp_path, "sel100", ["0", "1"]) + # wfdb annotation extensions must be alphabetic-only; use custom + # extensions to verify field-based fallback in ECGDelineationQTDB. 
+ _write_triplet_annotation(tmp_path, "sel100", "pua") + _write_triplet_annotation(tmp_path, "sel100", "pub") + return tmp_path, base + + +class TestHelpers: + def test_safe_str_and_extract_extension(self): + assert _safe_str(None) is None + assert _safe_str(" nan ") is None + assert _safe_str(" ok ") == "ok" + + assert _extract_extension("i") == "i" + assert _extract_extension("/tmp/a/b/1.pu0") == "pu0" + assert _extract_extension("") is None + + def test_has_any_wave(self): + assert _has_any_wave({"P": [], "QRS": [], "T": []}) is False + assert ( + _has_any_wave( + {"P": [{"onset": 1, "peak": 2, "offset": 3}], "QRS": [], "T": []} + ) + is True + ) + + def test_build_segmentation_mask_clamps_and_labels(self): + waves = { + "P": [{"onset": -10, "peak": 2, "offset": 3}], + "QRS": [{"onset": 5, "peak": 6, "offset": 1000}], + "T": [], + } + mask = _build_segmentation_mask(10, waves) + assert mask.shape == (10,) + assert mask.dtype == np.int64 + assert np.all(mask[0:4] == 1) + assert np.all(mask[5:10] == 2) + + def test_parse_annotations_triplets(self, ludb_record): + _, base = ludb_record + waves = _parse_annotations(str(base), "i") + assert set(waves.keys()) == {"P", "QRS", "T"} + assert len(waves["P"]) == len(waves["QRS"]) == len(waves["T"]) == 1 + assert ( + waves["QRS"][0]["onset"] + < waves["QRS"][0]["peak"] + < waves["QRS"][0]["offset"] + ) + + def test_resolve_lead_index_with_fallback(self, ludb_record): + _, base = ludb_record + record = wfdb.rdrecord(str(base)) + assert _resolve_lead_index(record, "ii", fallback_idx=0) == 1 + assert _resolve_lead_index(record, "v6", fallback_idx=0) == 0 + + +class TestECGDelineationTask: + def test_invalid_pulse_window_raises(self): + with pytest.raises(ValueError, match="pulse_window must be a positive integer"): + ECGDelineationTask(event_type="toy", leads=["i"], pulse_window=0) + + def test_candidate_extension_priority(self, ludb_record): + _, base = ludb_record + event = _Event( + visit_id="ecg", + signal_file=str(base), 
+ lead_i=str(base.with_suffix(".i")), + ) + task = ECGDelineationTask( + event_type="toy", + leads=["i"], + lead_field_map={"i": "lead_i"}, + annotation_extension_map={"i": ["i", "man", "i"]}, + ) + # should deduplicate while preserving order + exts = task._candidate_extensions(event, "i") + assert exts[0] == "i" + assert exts.count("i") == 1 + + +class TestLUDBTask: + def test_full_record_returns_expected_samples(self, ludb_record): + _, base = ludb_record + event = _Event( + visit_id="ecg", + signal_file=str(base), + lead_i=str(base.with_suffix(".i")), + lead_ii=str(base.with_suffix(".ii")), + ) + patient = _Patient("1", event) + task = ECGDelineationLUDB(split_by_pulse=False) + + samples = task(patient) + assert len(samples) == 2 # i + ii + for s in samples: + assert s["signal"].shape == (1, SIGNAL_LEN) + assert s["mask"].shape == (SIGNAL_LEN,) + assert s["label"] in (0, 1, 2, 3) + + def test_pulse_mode_shape_and_filtering(self, tmp_path: Path): + base = _write_record(tmp_path, "2", ["i", "ii"]) + _write_qrs_only_annotation(tmp_path, "2", "i") + + event = _Event( + visit_id="ecg", + signal_file=str(base), + lead_i=str(base.with_suffix(".i")), + ) + patient = _Patient("2", event) + + loose = ECGDelineationLUDB( + split_by_pulse=True, pulse_window=120, filter_incomplete_pulses=False + )(patient) + strict = ECGDelineationLUDB( + split_by_pulse=True, pulse_window=120, filter_incomplete_pulses=True + )(patient) + + assert len(loose) >= 1 + assert len(strict) == 0 + assert loose[0]["signal"].shape == (1, 240) + assert loose[0]["mask"].shape == (240,) + + def test_bad_signal_path_returns_empty(self): + patient = _Patient( + "999", + _Event( + visit_id="ecg", + signal_file="/does/not/exist/record", + lead_i="/does/not/exist/record.i", + ), + ) + assert ECGDelineationLUDB(split_by_pulse=False)(patient) == [] + + +class TestQTDBTask: + def test_qtdb_uses_annotation_field_extensions_when_present(self, qtdb_record): + _, base = qtdb_record + event = _Event( + 
visit_id="ecg", + signal_file=str(base), + lead_0="0", + lead_1="1", + ann_pu0=str(base.with_suffix(".pua")), + ann_pu1=str(base.with_suffix(".pub")), + ) + patient = _Patient("sel100", event) + task = ECGDelineationQTDB(split_by_pulse=False) + + samples = task(patient) + assert len(samples) == 2 + leads = {s["lead"] for s in samples} + assert leads == {"0", "1"} + for s in samples: + assert s["signal"].shape == (1, SIGNAL_LEN) + assert s["mask"].shape == (SIGNAL_LEN,) + assert s["label"] in (0, 1, 2, 3) diff --git a/tests/test_ludb.py b/tests/test_ludb.py new file mode 100644 index 000000000..86a1ffe55 --- /dev/null +++ b/tests/test_ludb.py @@ -0,0 +1,397 @@ +"""Fast synthetic tests for LUDB dataset, ECG delineation task, and ECGCODE model.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import cast + +import numpy as np +import pandas as pd +import pytest +import torch +import wfdb + +from pyhealth.data.data import Event +from pyhealth.datasets.ludb import ( + FS as LUDB_FS, +) +from pyhealth.datasets.ludb import ( + LEADS, + N_SAMPLES, + LUDBDataset, + get_stratified_ludb_split, +) +from pyhealth.tasks.ecg_delineation import ( + ECGDelineationLUDB, + _build_segmentation_mask, + _extract_extension, + _has_any_wave, + _parse_annotations, + _safe_str, + ecg_delineation_ludb_fn, + get_ecg_delineation_ludb_task, +) + +SIGNAL_LEN = LUDB_FS * 10 # 10 seconds + + +def _write_synthetic_record(data_dir: Path, record_id: int) -> None: + """Write one synthetic 12-lead WFDB record.""" + pid = str(record_id) + n_leads = len(LEADS) + t = np.linspace(0, 2 * np.pi, SIGNAL_LEN) + signal = (0.01 * np.sin(t)[:, np.newaxis] * np.ones((1, n_leads))).astype( + np.float64 + ) + + wfdb.wrsamp( + record_name=pid, + fs=LUDB_FS, + units=["mV"] * n_leads, + sig_name=list(LEADS), + p_signal=signal, + write_dir=str(data_dir), + ) + + +def _write_triplet_annotation(data_dir: Path, record_id: int, lead: str) -> None: + 
"""Write one P-QRS-T annotation triplet.""" + pid = str(record_id) + ann_samples = np.array( + [ + 90, + 100, + 110, # P + 240, + 250, + 260, # QRS + 390, + 400, + 410, # T + ], + dtype=np.int32, + ) + ann_symbols = ["(", "p", ")", "(", "N", ")", "(", "t", ")"] + + wfdb.wrann( + record_name=pid, + extension=lead, + sample=ann_samples, + symbol=ann_symbols, + write_dir=str(data_dir), + ) + + +def _write_qrs_only_annotation(data_dir: Path, record_id: int, lead: str) -> None: + """Write QRS-only annotations (for incomplete pulse edge-case tests).""" + pid = str(record_id) + ann_samples = np.array([240, 250, 260], dtype=np.int32) + ann_symbols = ["(", "N", ")"] + + wfdb.wrann( + record_name=pid, + extension=lead, + sample=ann_samples, + symbol=ann_symbols, + write_dir=str(data_dir), + ) + + +@dataclass +class _MockEvent: + visit_id: str + signal_file: str + lead_i: str | None = None + lead_ii: str | None = None + lead_iii: str | None = None + lead_avr: str | None = None + lead_avl: str | None = None + lead_avf: str | None = None + lead_v1: str | None = None + lead_v2: str | None = None + lead_v3: str | None = None + lead_v4: str | None = None + lead_v5: str | None = None + lead_v6: str | None = None + + +class _MockPatient: + def __init__(self, patient_id: str, event: _MockEvent): + self.patient_id = patient_id + self._event = event + + def get_events(self, event_type=None, *args, **kwargs): + if event_type is not None: + assert event_type == "ludb" + return [self._event] + + +@pytest.fixture(scope="module") +def synthetic_ludb_root(tmp_path_factory) -> Path: + """Create a tiny LUDB-like root with 3 records and minimal annotations.""" + root = tmp_path_factory.mktemp("ludb_synth") + for rid in (1, 2, 3): + _write_synthetic_record(root, rid) + + # Record 1: complete on i/ii + _write_triplet_annotation(root, 1, "i") + _write_triplet_annotation(root, 1, "ii") + + # Record 2: incomplete on i (QRS-only), no ii + _write_qrs_only_annotation(root, 2, "i") + + # Record 3: 
complete on i only + _write_triplet_annotation(root, 3, "i") + + # Optional LUDB CSV for metadata enrichment & split tests + pd.DataFrame( + { + "ID": [1, 2, 3], + "Rhythms": ["Sinus rhythm", "Atrial fibrillation", "Sinus rhythm"], + "Electric axis of the heart": [ + "Electric axis of the heart: normal", + "Electric axis of the heart: normal", + "Electric axis of the heart: left", + ], + } + ).to_csv(root / "ludb.csv", index=False) + + return root + + +@pytest.fixture() +def strat_csv_path(tmp_path: Path) -> Path: + path = tmp_path / "ludb.csv" + # Two groups of 10 each -> deterministic 8/1/1 per group + df = pd.DataFrame( + { + "ID": list(range(1, 21)), + "Rhythms": ["Sinus rhythm"] * 10 + ["Atrial fibrillation"] * 10, + "Electric axis of the heart": ["Electric axis of the heart: normal"] * 20, + } + ) + df.to_csv(path, index=False) + return path + + +class TestHelpersAndParsing: + def test_safe_str_and_extract_extension(self): + assert _safe_str(None) is None + assert _safe_str(" nan ") is None + assert _safe_str(" abc ") == "abc" + + assert _extract_extension("i") == "i" + assert _extract_extension("/tmp/1.i") == "i" + assert _extract_extension("") is None + + def test_has_any_wave(self): + assert _has_any_wave({"P": [], "QRS": [], "T": []}) is False + assert ( + _has_any_wave( + {"P": [{"onset": 1, "peak": 2, "offset": 3}], "QRS": [], "T": []} + ) + is True + ) + + def test_parse_annotations_returns_expected_keys(self, synthetic_ludb_root: Path): + waves = _parse_annotations(str(synthetic_ludb_root / "1"), "i") + assert set(waves.keys()) == {"P", "QRS", "T"} + + def test_parse_annotations_triplets(self, synthetic_ludb_root: Path): + waves = _parse_annotations(str(synthetic_ludb_root / "1"), "i") + for wave_name in ("P", "QRS", "T"): + assert len(waves[wave_name]) == 1 + w = waves[wave_name][0] + assert set(w.keys()) == {"onset", "peak", "offset"} + assert w["onset"] < w["peak"] < w["offset"] + + def test_missing_annotation_returns_empty(self, 
synthetic_ludb_root: Path): + waves = _parse_annotations(str(synthetic_ludb_root / "1"), "v6") + assert waves == {"P": [], "QRS": [], "T": []} + + +class TestSegmentationMask: + def test_shape_dtype_and_values(self): + waves = { + "P": [{"onset": 2, "peak": 3, "offset": 4}], + "QRS": [{"onset": 7, "peak": 8, "offset": 9}], + "T": [{"onset": 12, "peak": 13, "offset": 14}], + } + mask = _build_segmentation_mask(20, waves) + assert mask.shape == (20,) + assert mask.dtype == np.int64 + assert np.all(mask[2:5] == 1) + assert np.all(mask[7:10] == 2) + assert np.all(mask[12:15] == 3) + assert mask[0] == 0 + + def test_clamps_out_of_range_boundaries(self): + waves = { + "P": [{"onset": -5, "peak": 1, "offset": 2}], + "QRS": [{"onset": 8, "peak": 9, "offset": 30}], + "T": [], + } + mask = _build_segmentation_mask(10, waves) + assert np.all(mask[0:3] == 1) + assert np.all(mask[8:10] == 2) + + +class TestECGDelineationLUDBTask: + def _make_patient(self, root: Path, pid: str = "1") -> _MockPatient: + base = root / pid + event = _MockEvent( + visit_id="ecg", + signal_file=str(base), + lead_i=str(base.with_suffix(".i")), + lead_ii=str(base.with_suffix(".ii")), + ) + return _MockPatient(patient_id=pid, event=event) + + def test_full_record_mode_returns_samples(self, synthetic_ludb_root: Path): + patient = self._make_patient(synthetic_ludb_root, pid="1") + task = ECGDelineationLUDB(split_by_pulse=False) + samples = task(patient) + + # Record 1 has annotations on i and ii + assert len(samples) == 2 + for s in samples: + assert { + "patient_id", + "visit_id", + "record_id", + "lead", + "signal", + "mask", + "label", + } <= set(s.keys()) + assert s["signal"].shape == (1, SIGNAL_LEN) + assert s["mask"].shape == (SIGNAL_LEN,) + assert s["label"] in (0, 1, 2, 3) + + def test_pulse_mode_returns_fixed_windows(self, synthetic_ludb_root: Path): + patient = self._make_patient(synthetic_ludb_root, pid="1") + task = ECGDelineationLUDB(split_by_pulse=True, pulse_window=250) + samples = 
task(patient) + assert len(samples) > 0 + for s in samples: + assert s["signal"].shape == (1, 500) + assert s["mask"].shape == (500,) + + def test_incomplete_pulse_filtering_edge_case(self, synthetic_ludb_root: Path): + patient = self._make_patient(synthetic_ludb_root, pid="2") # QRS-only on lead i + loose = ECGDelineationLUDB( + split_by_pulse=True, pulse_window=250, filter_incomplete_pulses=False + )(patient) + strict = ECGDelineationLUDB( + split_by_pulse=True, pulse_window=250, filter_incomplete_pulses=True + )(patient) + assert len(loose) >= 1 + assert len(strict) == 0 + + def test_bad_record_returns_empty(self): + patient = _MockPatient( + patient_id="999", + event=_MockEvent( + visit_id="ecg", + signal_file="/non/existent/record", + lead_i="/non/existent/record.i", + ), + ) + task = ECGDelineationLUDB(split_by_pulse=False) + assert task(patient) == [] + + def test_backward_compat_alias_and_factory(self, synthetic_ludb_root: Path): + patient = self._make_patient(synthetic_ludb_root, pid="1") + assert len(ecg_delineation_ludb_fn(patient)) == 2 + + task = get_ecg_delineation_ludb_task(split_by_pulse=True, pulse_window=125) + assert isinstance(task, ECGDelineationLUDB) + assert task.split_by_pulse is True + assert task.pulse_window == 125 + + +class TestLUDBDatasetModern: + def test_metadata_generation_and_columns(self, synthetic_ludb_root: Path): + ds = LUDBDataset(root=str(synthetic_ludb_root), dev=False, num_workers=1) + metadata_path = synthetic_ludb_root / "ludb-pyhealth.csv" + assert metadata_path.exists() + + df = pd.read_csv(metadata_path) + assert set( + ["patient_id", "visit_id", "record_id", "signal_file", "fs", "n_samples"] + ).issubset(df.columns) + for lead in LEADS: + assert f"lead_{lead}" in df.columns + + first = df.iloc[0] + assert Path(str(first["signal_file"])).is_absolute() + assert int(first["fs"]) == LUDB_FS + assert int(first["n_samples"]) == N_SAMPLES + + # event parsing/data integrity + patient = ds.get_patient("1") + events = 
cast(list[Event], patient.get_events("ludb", return_df=False)) + assert len(events) == 1 + ev = events[0] + assert ev.visit_id == "ecg" + assert ev.record_id == "1" + assert Path(ev.signal_file).is_absolute() + assert ev.lead_i.endswith(".i") + + def test_task_integration_with_set_task( + self, synthetic_ludb_root: Path, tmp_path: Path + ): + ds = LUDBDataset( + root=str(synthetic_ludb_root), + dev=False, + num_workers=1, + cache_dir=str(tmp_path / "cache"), + ) + task = ECGDelineationLUDB(split_by_pulse=False) + sample_ds = ds.set_task(task=task, num_workers=1) + assert len(sample_ds) > 0 + + sample = sample_ds[0] + assert "signal" in sample and "mask" in sample and "label" in sample + assert isinstance(sample["signal"], torch.Tensor) + assert isinstance(sample["mask"], torch.Tensor) + assert sample["signal"].ndim == 2 # [C, T] + assert sample["mask"].ndim == 1 + + def test_strict_root_resolution_raises_on_empty_root(self, tmp_path: Path): + empty_root = tmp_path / "empty_ludb" + empty_root.mkdir(parents=True, exist_ok=True) + + with pytest.raises(FileNotFoundError): + LUDBDataset(root=str(empty_root), dev=True, num_workers=1) + + with pytest.raises(FileNotFoundError): + LUDBDataset(root=str(empty_root), dev=False, num_workers=1) + + def test_stats_and_info_smoke(self, synthetic_ludb_root: Path): + ds = LUDBDataset(root=str(synthetic_ludb_root), dev=True, num_workers=1) + ds.stats() + ds.info() + + +class TestStratifiedSplit: + def test_split_sizes_disjoint_complete_and_deterministic( + self, strat_csv_path: Path + ): + train_ids, val_ids, test_ids = get_stratified_ludb_split( + str(strat_csv_path), train_ratio=0.8, val_ratio=0.1, seed=42 + ) + assert len(train_ids) == 16 + assert len(val_ids) == 2 + assert len(test_ids) == 2 + + all_ids = set(train_ids) | set(val_ids) | set(test_ids) + assert all_ids == set(range(1, 21)) + assert set(train_ids).isdisjoint(val_ids) + assert set(train_ids).isdisjoint(test_ids) + assert set(val_ids).isdisjoint(test_ids) + + split1 = 
get_stratified_ludb_split(str(strat_csv_path), seed=123) + split2 = get_stratified_ludb_split(str(strat_csv_path), seed=123) + assert split1 == split2