diff --git a/.gitignore b/.gitignore index 9993737db..d7010da78 100644 --- a/.gitignore +++ b/.gitignore @@ -139,4 +139,5 @@ data/physionet.org/ .vscode/ # Model weight files (large binaries, distributed separately) -weightfiles/ \ No newline at end of file +weightfiles/examples/run_h1_h4_eicu.py +examples/run_h1_h4_eicu.py diff --git a/README_TPC_LOS.md b/README_TPC_LOS.md new file mode 100644 index 000000000..b6d42beb7 --- /dev/null +++ b/README_TPC_LOS.md @@ -0,0 +1,215 @@ +# CS598 Deep Learning for Healthcare projet + +This PR reproduces and contributes an implementation of: + +**Temporal Pointwise Convolutional Networks for Length of Stay Prediction in the Intensive Care Unit +Rocheteau et al., ACM CHIL 2021** +Paper: https://arxiv.org/abs/2007.09483 + + +## Contributors + +* Michael Edukonis (`meduk2@illinois.edu`) +* Keon Young Lee (`kylee7@illinois.edu`) +* Tanmay Thareja (`tanmayt3@illinois.edu`) + +# PR Overview + +## Contribution Types + +This PR includes: + +- Model contribution +- Standalone Task contribution +- Synthetic data tests +- End to End Pipeline (Model + Task + Dataset configs) example scripts + combined dual dataset run + +## Problem Overview + + +Efficient ICU bed management depends critically on estimating how long a patient will remain in the ICU. + +This is formulated as: + + • Input: Patient data up to hour _t_ + • Output: Remaining ICU length of stay _(LoS)_ at time _t_ + +We follow the formulation in the original paper, predicting remaining LoS at each hour of the ICU stay. + +## Implementation Details + +**1) Model Contribution** + +_pyhealth.models.tpc_ + +We implement the Temporal Pointwise Convolution (TPC) model as a PyHealth-compatible model by extending BaseModel and follows the original paper’s architecture with adaptations for PyHealth’s input/output interfaces. Index files were updated to include the new modules accordingly. + +**2) Task Contribution** + +We implement a custom PyHealth task: Hourly Remaining Length of Stay. 
Index files were updated to include the new modules accordingly. + +_pyhealth.tasks.hourly_los_ + +Task Definition + + • Predict remaining LoS at every hour + • Predictions start after first 5 hours + • Continuous regression task + +Motivation + + • Mimics real-world ICU monitoring + • Enables dynamic prediction updates + +**3) Ablation Study/ Example Usage** + +We implemented scripts for running the pipeline end to end with support for different experimental setups or ablations. + +_examples/eicu_hourly_los_tpc.py_ : This provides an end-to-end script for reproducing and evaluating pipeline on EICU dataset. +_examples/mimic4_hourly_los_tpc.py_ : This provides an end-to-end script for reproducing and evaluating pipeline on MIMIC-IV dataset. +_examples/run_dual_dataset_tpc.py_ : This utility script runs the pipeline on both datasets and produces a combined report. + +**4) Test Cases** + +We implemented fast performing test cases for our Model and Task contribution using Synthetic Data. + +_tests/models/test_tpc.py_ +_tests/tasks/test_hourly_los.py_ + + +## Experimental Setup and Findings + +1) Ablations to include/exclude domain specific engineering (skip connections, decay indicators, etc) +
+image +
+2) Comparison of using combined temporal and pointwise convolutions vs using either architecture alone. +
+image +
+3) Feature independent (no weight-sharing) vs weight-shared temporal convolutions. +
+image +
+4) Evaluating MSLE loss vs MSE loss for skewed LoS target regression. +
+image +
+## Testing model with varying hyperparameters. + +We varied key optimization and architectural hyperparameters (e.g. learning rate, dropout rate, etc) while keeping preprocessing and data splits fixed. + +image + + +## File Structure + +```text +. +├── pyhealth/models/tpc.py +├── pyhealth/tasks/hourly_los.py +├── pyhealth/datasets/configs/eicu_tpc.yaml +├── pyhealth/datasets/configs/mimic4_ehr_tpc.yaml +├── examples/eicu_hourly_los_tpc.py +├── examples/mimic4_hourly_los_tpc.py +├── examples/run_dual_dataset_tpc.py +├── tests/models/test_tpc.py +├── tests/tasks/test_hourly_los.py +├── docs/api/models.rst +├── docs/api/models/pyhealth.models.tpc.rst +├── docs/api/tasks.rst +├── docs/api/tasks/pyhealth.tasks.hourly_los.rst +└── README_TPC_LOS.md +``` + +## Setup + +1. Clone the repository. +2. Create and activate a python virtual environment. +3. Install project dependencies. +4. Set dataset paths with environment variables for the example scripts. + +## Quick Start (Synthetic Data) + +### eICU example + +```bash +EICU_ROOT=/path/to/synthetic/eicu/data \ +PYTHONPATH=. python3 examples/eicu_hourly_los_tpc.py \ + --epochs 1 \ + --batch_size 2 \ + --max_samples 8 \ + --model_variant full \ + --loss msle \ + --num_workers 1 +``` + +### MIMIC-IV example + +```bash +MIMIC4_ROOT=/path/to/synthetic/mimic4/data \ +PYTHONPATH=. python3 examples/mimic4_hourly_los_tpc.py \ + --epochs 1 \ + --batch_size 2 \ + --max_samples 8 \ + --loss msle \ + --num_workers 1 +``` + +### Combined dual-dataset run + +```bash +EICU_ROOT=/path/to/synthetic/eicu/data \ +MIMIC4_ROOT=/path/to/synthetic/mimic4/data \ +PYTHONPATH=. 
python3 examples/run_dual_dataset_tpc.py \ + --eicu_cache_dir /path/to/eicu/cache/location + --mimic_cache_dir /path/to/mimic/cache/location + --eicu_max_samples 8 \ + --mimic_max_samples 8 \ + --eicu_epochs 1 \ + --mimic_epochs 1 \ + --model_variant full \ + --loss msle \ + --num_workers 1 +``` + +## Notes on Real Data + +For full eICU or MIMIC-IV experiments, point `EICU_ROOT` or `MIMIC4_ROOT` to the real dataset locations and increase settings such as `--epochs`, `--batch_size`, and `--max_samples` as needed. + +## Tests + +Run the project-specific tests with: + +```bash +python3 -m pytest tests/models/test_tpc.py tests/tasks/test_hourly_los.py -q +``` + +## Documentation + +API documentation entries were added for: + +* `pyhealth.models.tpc` +* `pyhealth.tasks.hourly_los` + +## Output + +The example scripts print compact summary lines for quick validation: + +* `ABLATION_SUMMARY` for eICU +* `MIMIC_SUMMARY` for MIMIC-IV + +The dual-dataset runner parses both and prints a combined summary table. + +## Environment + +This project is designed to run within the PyHealth environment. + +Recommended setup: + +- Python 3.12 +- pyhealth (>= 2.0.0) +- torch +- standard scientific Python stack (numpy, pandas) + +Install PyHealth and dependencies following the main repository instructions. diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..6d4932477 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -206,3 +206,4 @@ API Reference models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs models/pyhealth.models.califorest + models/pyhealth.models.tpc diff --git a/docs/api/models/pyhealth.models.tpc.rst b/docs/api/models/pyhealth.models.tpc.rst new file mode 100644 index 000000000..a16dfea7e --- /dev/null +++ b/docs/api/models/pyhealth.models.tpc.rst @@ -0,0 +1,7 @@ +TPC Model +========= + +.. 
automodule:: pyhealth.models.tpc + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..61ee76504 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -213,6 +213,7 @@ Available Tasks DKA Prediction (MIMIC-IV) Drug Recommendation Length of Stay Prediction + Hourly Length-of-Stay (TPC) Medical Transcriptions Classification Mortality Prediction (Next Visit) Mortality Prediction (StageNet MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.hourly_los.rst b/docs/api/tasks/pyhealth.tasks.hourly_los.rst new file mode 100644 index 000000000..7ac100c10 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.hourly_los.rst @@ -0,0 +1,7 @@ + Hourly Length-of-Stay Task +========================== + +.. automodule:: pyhealth.tasks.hourly_los + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/eicu_hourly_los_tpc.py b/examples/eicu_hourly_los_tpc.py new file mode 100644 index 000000000..9fd6648b8 --- /dev/null +++ b/examples/eicu_hourly_los_tpc.py @@ -0,0 +1,522 @@ +""" +eICU hourly remaining length-of-stay (LoS) training and ablation script +for the Temporal Pointwise Convolution (TPC) model. + +This script runs the TPC model through the true PyHealth ``BaseModel`` +pipeline: + + 1. Load an eICU base dataset with the custom YAML config. + 2. Convert it to a task-specific ``SampleDataset`` using ``HourlyLOSEICU``. + 3. Split the task dataset by patient. + 4. Create dataloaders with PyHealth's ``get_dataloader``. + 5. Instantiate ``TPC(dataset=task_dataset, ...)``. + 6. Train using ``pyhealth.trainer.Trainer``. + 7. Evaluate scalar regression loss, MAE, and RMSE. + +This file is intentionally designed to work with synthetic, dev, or real +eICU roots. For project verification and fast iteration, it should be run +against synthetic data using small sample caps. + +Example: + >>> EICU_ROOT=/path/to/synthetic/eicu_demo PYTHONPATH=. \\ + ... python3 examples/eicu_hourly_los_tpc.py \\ + ... 
--epochs 1 \\ + ... --batch_size 2 \\ + ... --max_samples 24 +""" + +from __future__ import annotations + +import argparse +import os +import random +import sys +from typing import Dict + +import numpy as np +import torch +from pyhealth.datasets import SampleDataset + +DEFAULT_EICU_ROOT = os.environ.get("EICU_ROOT", "YOUR_EICU_DATASET_PATH") + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +from pyhealth.datasets import eICUDataset, get_dataloader, split_by_patient +from pyhealth.models.tpc import TPC +from pyhealth.tasks.hourly_los import HourlyLOSEICU +from pyhealth.trainer import Trainer + + +def set_seed(seed: int = 42) -> None: + """Set deterministic seeds for Python, NumPy, and PyTorch. + + Args: + seed: Seed value applied to Python, NumPy, and PyTorch. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + + + + +def infer_model_dims(task_dataset) -> tuple[int, int]: + """Infer ``input_dim`` and ``static_dim`` from the processed task dataset. + + Args: + task_dataset: PyHealth ``SampleDataset`` returned by ``set_task()``. + + Returns: + A tuple ``(input_dim, static_dim)``. + + Raises: + ValueError: If the first sample does not contain a valid + ``time_series`` or ``static`` field. 
+ """ + first_sample = task_dataset[0] + + if "time_series" not in first_sample: + raise ValueError("Task sample is missing required field 'time_series'") + if "static" not in first_sample: + raise ValueError("Task sample is missing required field 'static'") + + time_series = first_sample["time_series"] + static = first_sample["static"] + + if not isinstance(time_series, torch.Tensor): + time_series = torch.tensor(time_series, dtype=torch.float32) + if not isinstance(static, torch.Tensor): + static = torch.tensor(static, dtype=torch.float32) + + if time_series.ndim != 2: + raise ValueError( + "Expected task sample 'time_series' to have shape [T, 3F], got " + f"{tuple(time_series.shape)}" + ) + + feature_dim = time_series.shape[1] + if feature_dim % 3 != 0: + raise ValueError( + "Expected 'time_series' last dimension divisible by 3 for " + f"[value, mask, decay], got {feature_dim}" + ) + + input_dim = feature_dim // 3 + static_dim = int(static.shape[0]) + + return input_dim, static_dim + + +def evaluate_regression(model: TPC, dataloader) -> Dict[str, float]: + """Evaluate scalar regression loss, MAE, and RMSE. + + Args: + model: Trained TPC model. + dataloader: Evaluation dataloader. + + Returns: + Dictionary containing ``loss``, ``mae``, and ``rmse``. 
+ """ + model.eval() + + total_loss = 0.0 + total_abs_error = 0.0 + total_sq_error = 0.0 + total_count = 0 + + with torch.no_grad(): + for batch in dataloader: + outputs = model(**batch) + + batch_loss = outputs["loss"].item() + y_pred = outputs["y_prob"].detach().cpu().view(-1) + y_true = outputs["y_true"].detach().cpu().view(-1) + + total_loss += batch_loss + total_abs_error += torch.sum(torch.abs(y_pred - y_true)).item() + total_sq_error += torch.sum((y_pred - y_true) ** 2).item() + total_count += int(y_true.numel()) + + mean_loss = total_loss / max(len(dataloader), 1) + mae = total_abs_error / max(total_count, 1) + rmse = (total_sq_error / max(total_count, 1)) ** 0.5 + + return { + "loss": mean_loss, + "mae": mae, + "rmse": rmse, + } + + +def parse_args(): + """Parse command-line arguments for the eICU TPC training script. + + Returns: + Parsed argparse namespace. + """ + parser = argparse.ArgumentParser( + description="Run eICU hourly remaining LoS prediction with TPC." + ) + parser.add_argument( + "--root", + type=str, + default=DEFAULT_EICU_ROOT, + help=( + "Path to eICU dataset root. Defaults to the EICU_ROOT " + "environment variable when set." + ), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="Path to the PyHealth cache directory." + ) + parser.add_argument("--dev", action="store_true", help="Use dev mode dataset") + parser.add_argument("--epochs", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument( + "--max_samples", + type=int, + default=128, + help=( + "Approximate total sample cap across train/val/test splits for " + "fast synthetic or smoke-style runs." 
+ ), + ) + parser.add_argument("--num_layers", type=int, default=2) + parser.add_argument("--temporal_channels", type=int, default=4) + parser.add_argument("--pointwise_channels", type=int, default=4) + parser.add_argument("--kernel_size", type=int, default=3) + parser.add_argument("--fc_dim", type=int, default=16) + parser.add_argument("--dropout", type=float, default=0.1) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument( + "--loss", + type=str, + choices=["msle", "mse"], + default="msle", + help="Loss function for H3 ablation (default: msle).", + ) + parser.add_argument( + "--model_variant", + type=str, + choices=["full", "temporal_only", "pointwise_only"], + default="full", + help="Architectural ablation: full TPC, temporal only, or pointwise only.", + ) + parser.add_argument( + "--shared_temporal", + action="store_true", + help="Use shared temporal convolution weights across features.", + ) + parser.add_argument( + "--no_skip_connections", + action="store_true", + help="Disable concatenative skip connections.", + ) + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="Number of CPU workers for task transformations" + ) + return parser.parse_args() + + +def main(): + """Run the full eICU hourly LoS training and evaluation pipeline.""" + args = parse_args() + set_seed(args.seed) + + print("=" * 80) + print("TPC eICU Hourly LoS Run Configuration") + print("=" * 80) + print(f"root: {args.root}") + print(f"dev: {args.dev}") + print(f"epochs: {args.epochs}") + print(f"batch_size: {args.batch_size}") + print(f"num_workers: {args.num_workers}") + print(f"max_samples: {args.max_samples}") + print(f"num_layers: {args.num_layers}") + print(f"temporal_channels: {args.temporal_channels}") + print(f"pointwise_channels: {args.pointwise_channels}") + print(f"kernel_size: {args.kernel_size}") + print(f"fc_dim: {args.fc_dim}") + print(f"dropout: {args.dropout}") + 
print(f"lr: {args.lr}") + print(f"loss: {args.loss}") + print(f"model_variant: {args.model_variant}") + print(f"shared_temporal: {args.shared_temporal}") + print(f"no_skip_connections: {args.no_skip_connections}") + print(f"seed: {args.seed}") + print("=" * 80) + + base_dataset = eICUDataset( + root=args.root, + tables=[ + "patient", + "lab", + "respiratorycharting", + "nursecharting", + "vitalperiodic", + "vitalaperiodic", + "pasthistory", + "admissiondx", + "diagnosis", + ], + dev=args.dev, + cache_dir=args.cache_dir, + config_path=os.path.join( + REPO_ROOT, + "pyhealth/datasets/configs/eicu_tpc.yaml", + ), + ) + + task_dataset = base_dataset.set_task( + HourlyLOSEICU( + time_series_tables=[ + "lab", + "respiratorycharting", + "nursecharting", + "vitalperiodic", + "vitalaperiodic", + ], + time_series_features={ + "lab": [ + "-basos", + "-eos", + "-lymphs", + "-monos", + "-polys", + "ALT (SGPT)", + "AST (SGOT)", + "BUN", + "Base Excess", + "FiO2", + "HCO3", + "Hct", + "Hgb", + "MCH", + "MCHC", + "MCV", + "MPV", + "O2 Sat (%)", + "PT", + "PT - INR", + "PTT", + "RBC", + "RDW", + "WBC x 1000", + "albumin", + "alkaline phos.", + "anion gap", + "bedside glucose", + "bicarbonate", + "calcium", + "chloride", + "creatinine", + "glucose", + "lactate", + "magnesium", + "pH", + "paCO2", + "paO2", + "phosphate", + "platelets x 1000", + "potassium", + "sodium", + "total bilirubin", + "total protein", + "troponin - I", + "urinary specific gravity", + ], + "respiratorycharting": [ + "Exhaled MV", + "Exhaled TV (patient)", + "LPM O2", + "Mean Airway Pressure", + "Peak Insp. 
Pressure", + "PEEP", + "Plateau Pressure", + "Pressure Support", + "RR (patient)", + "SaO2", + "TV/kg IBW", + "Tidal Volume (set)", + "Total RR", + "Vent Rate", + ], + "nursecharting": [ + "Bedside Glucose", + "Delirium Scale/Score", + "Glasgow coma score", + "Heart Rate", + "Invasive BP", + "Non-Invasive BP", + "O2 Admin Device", + "O2 L/%", + "O2 Saturation", + "Pain Score/Goal", + "Respiratory Rate", + "Sedation Score/Goal", + "Temperature", + ], + "vitalperiodic": [ + "cvp", + "heartrate", + "respiration", + "sao2", + "st1", + "st2", + "st3", + "systemicdiastolic", + "systemicmean", + "systemicsystolic", + "temperature", + ], + "vitalaperiodic": [ + "noninvasivediastolic", + "noninvasivemean", + "noninvasivesystolic", + ], + }, + numeric_static_features=[ + "age", + "admissionheight", + "admissionweight", + ], + # Keep the task contract stable for true BaseModel training. + categorical_static_features=[], + diagnosis_tables=[], + include_diagnoses=False, + diagnosis_time_limit_hours=5, + min_history_hours=5, + max_hours=48, + ), + num_workers=args.num_workers, + ) + + if len(task_dataset) == 0: + raise RuntimeError( + "No samples were generated by HourlyLOSEICU. " + "Check synthetic data, feature mappings, joins, cache, or dataset mode." + ) + + input_dim, static_dim = infer_model_dims(task_dataset) + + print(f"task samples: {len(task_dataset)}") + print(f"model input_dim: {input_dim}") + print(f"model static_dim: {static_dim}") + + # For very small synthetic datasets, use sample-level split instead of patient-level + + # Use patient-level split when possible. For very small synthetic datasets, + # reuse one non-empty split for validation so Trainer can run. 
+ num_samples = len(task_dataset) + + if num_samples < 20: + train_ds, _, test_ds = split_by_patient(task_dataset, [0.5, 0.0, 0.5]) + val_ds = test_ds + else: + train_ds, val_ds, test_ds = split_by_patient(task_dataset, [0.70, 0.15, 0.15]) + + if len(train_ds) == 0 or len(val_ds) == 0 or len(test_ds) == 0: + raise RuntimeError( + "Dataset split produced an empty train, validation, or test split." + ) + + print(f"train samples: {len(train_ds)}") + print(f"val samples: {len(val_ds)}") + print(f"test samples: {len(test_ds)}") + + train_loader = get_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + #num_workers=args.num_workers, + ) + val_loader = get_dataloader(val_ds, + batch_size=args.batch_size, + shuffle=False, + #num_workers=args.num_workers, + ) + test_loader = get_dataloader(test_ds, + batch_size=args.batch_size, + shuffle=False, + #num_workers=args.num_workers, + ) + use_temporal = args.model_variant in {"full", "temporal_only"} + use_pointwise = args.model_variant in {"full", "pointwise_only"} + use_skip_connections = not args.no_skip_connections + + model = TPC( + dataset=task_dataset, + input_dim=input_dim, + static_dim=static_dim, + temporal_channels=args.temporal_channels, + pointwise_channels=args.pointwise_channels, + num_layers=args.num_layers, + kernel_size=args.kernel_size, + fc_dim=args.fc_dim, + dropout=args.dropout, + shared_temporal=args.shared_temporal, + use_temporal=use_temporal, + use_pointwise=use_pointwise, + use_skip_connections=use_skip_connections, + loss_name=args.loss, + ) + + trainer = Trainer( + model=model, + device="cpu", + enable_logging=False, + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + epochs=args.epochs, + optimizer_params={"lr": args.lr}, + ) + + val_results = evaluate_regression(model, val_loader) + test_results = evaluate_regression(model, test_loader) + + print("=" * 80) + print("Run complete") + print(f"model_variant: 
{args.model_variant}") + print(f"shared_temporal: {args.shared_temporal}") + print(f"no_skip_connections: {args.no_skip_connections}") + print(f"loss: {args.loss}") + print(f"val_loss: {val_results['loss']:.4f}") + print(f"val_mae: {val_results['mae']:.4f}") + print(f"val_rmse: {val_results['rmse']:.4f}") + print(f"test_loss: {test_results['loss']:.4f}") + print(f"test_mae: {test_results['mae']:.4f}") + print(f"test_rmse: {test_results['rmse']:.4f}") + print( + "ABLATION_SUMMARY " + f"model_variant={args.model_variant} " + f"shared_temporal={args.shared_temporal} " + f"no_skip_connections={args.no_skip_connections} " + f"loss={args.loss} " + f"val_loss={val_results['loss']:.4f} " + f"val_mae={val_results['mae']:.4f} " + f"val_rmse={val_results['rmse']:.4f} " + f"test_loss={test_results['loss']:.4f} " + f"test_mae={test_results['mae']:.4f} " + f"test_rmse={test_results['rmse']:.4f}" + ) + print("=" * 80) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/mimic4_hourly_los_tpc.py b/examples/mimic4_hourly_los_tpc.py new file mode 100644 index 000000000..84cb66cb1 --- /dev/null +++ b/examples/mimic4_hourly_los_tpc.py @@ -0,0 +1,475 @@ +""" +MIMIC-IV hourly remaining length-of-stay (LoS) training and evaluation script +for the Temporal Pointwise Convolution (TPC) model. + +This script runs the TPC model through the true PyHealth ``BaseModel`` +pipeline: + + 1. Load a MIMIC-IV base dataset with the custom YAML config. + 2. Convert it to a task-specific ``SampleDataset`` using ``HourlyLOSEICU``. + 3. Split the task dataset by patient. + 4. Create dataloaders with PyHealth's ``get_dataloader``. + 5. Instantiate ``TPC(dataset=task_dataset, ...)``. + 6. Train using ``pyhealth.trainer.Trainer``. + 7. Evaluate scalar regression loss, MAE, and RMSE. + +This file is intended to work with synthetic, dev, or real MIMIC-IV roots. +For project verification and fast iteration, it should be run against +synthetic data using a small sample cap. 
+ +Example: + >>> MIMIC4_ROOT=/path/to/synthetic/mimic4_demo PYTHONPATH=. \\ + ... python3 examples/mimic4_hourly_los_tpc.py \\ + ... --epochs 1 \\ + ... --batch_size 2 \\ + ... --max_samples 24 +""" + +from __future__ import annotations + +import argparse +import os +import random +import sys +from typing import Dict + +import numpy as np +import torch + +DEFAULT_MIMIC4_ROOT = os.environ.get("MIMIC4_ROOT", "YOUR_MIMIC4_DATASET_PATH") + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient +from pyhealth.models.tpc import TPC +from pyhealth.tasks.hourly_los import HourlyLOSEICU +from pyhealth.trainer import Trainer + + +def set_seed(seed: int = 42) -> None: + """Set deterministic seeds for Python, NumPy, and PyTorch. + + Args: + seed: Seed value applied to Python, NumPy, and PyTorch. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def infer_model_dims(task_dataset) -> tuple[int, int]: + """Infer ``input_dim`` and ``static_dim`` from the processed task dataset. + + Args: + task_dataset: PyHealth ``SampleDataset`` returned by ``set_task()``. + + Returns: + A tuple ``(input_dim, static_dim)``. + + Raises: + ValueError: If the first sample does not contain a valid + ``time_series`` or ``static`` field. 
+ """ + first_sample = task_dataset[0] + + if "time_series" not in first_sample: + raise ValueError("Task sample is missing required field 'time_series'") + if "static" not in first_sample: + raise ValueError("Task sample is missing required field 'static'") + + time_series = first_sample["time_series"] + static = first_sample["static"] + + if not isinstance(time_series, torch.Tensor): + time_series = torch.tensor(time_series, dtype=torch.float32) + if not isinstance(static, torch.Tensor): + static = torch.tensor(static, dtype=torch.float32) + + if time_series.ndim != 2: + raise ValueError( + "Expected task sample 'time_series' to have shape [T, 3F], got " + f"{tuple(time_series.shape)}" + ) + + feature_dim = time_series.shape[1] + if feature_dim % 3 != 0: + raise ValueError( + "Expected 'time_series' last dimension divisible by 3 for " + f"[value, mask, decay], got {feature_dim}" + ) + + input_dim = feature_dim // 3 + static_dim = int(static.shape[0]) + + return input_dim, static_dim + + +def evaluate_regression(model: TPC, dataloader) -> Dict[str, float]: + """Evaluate scalar regression loss, MAE, and RMSE. + + Args: + model: Trained TPC model. + dataloader: Evaluation dataloader. + + Returns: + Dictionary containing ``loss``, ``mae``, and ``rmse``. 
+ """ + model.eval() + + total_loss = 0.0 + total_abs_error = 0.0 + total_sq_error = 0.0 + total_count = 0 + + with torch.no_grad(): + for batch in dataloader: + outputs = model(**batch) + + batch_loss = outputs["loss"].item() + y_pred = outputs["y_prob"].detach().cpu().view(-1) + y_true = outputs["y_true"].detach().cpu().view(-1) + + total_loss += batch_loss + total_abs_error += torch.sum(torch.abs(y_pred - y_true)).item() + total_sq_error += torch.sum((y_pred - y_true) ** 2).item() + total_count += int(y_true.numel()) + + mean_loss = total_loss / max(len(dataloader), 1) + mae = total_abs_error / max(total_count, 1) + rmse = (total_sq_error / max(total_count, 1)) ** 0.5 + + return { + "loss": mean_loss, + "mae": mae, + "rmse": rmse, + } + + +def parse_args(): + """Parse command-line arguments for the MIMIC-IV TPC training script. + + Returns: + Parsed argparse namespace. + """ + parser = argparse.ArgumentParser( + description="Run MIMIC-IV hourly remaining LoS prediction with TPC." + ) + parser.add_argument( + "--root", + type=str, + default=DEFAULT_MIMIC4_ROOT, + help=( + "Path to the MIMIC-IV dataset root (directory containing hosp/ and " + "icu/). Defaults to the MIMIC4_ROOT environment variable when set." + ), + ) + parser.add_argument("--dev", action="store_true", help="Use dev mode dataset") + parser.add_argument("--epochs", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument( + "--max_samples", + type=int, + default=128, + help=( + "Approximate total sample cap across train/val/test splits for " + "fast synthetic or smoke-style runs." 
+ ), + ) + parser.add_argument("--num_layers", type=int, default=2) + parser.add_argument("--temporal_channels", type=int, default=4) + parser.add_argument("--pointwise_channels", type=int, default=4) + parser.add_argument("--kernel_size", type=int, default=3) + parser.add_argument("--fc_dim", type=int, default=16) + parser.add_argument("--dropout", type=float, default=0.1) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument( + "--loss", + type=str, + choices=["msle", "mse"], + default="msle", + help="Loss function for H3 comparison (default: msle).", + ) + parser.add_argument( + "--max_hours", + type=int, + default=336, + help="Maximum ICU hours to model (14 days = 336 hours).", + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="Number of CPU workers for task transformations" +) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="Path to the PyHealth cache directory.", +) + return parser.parse_args() + + +def main(): + """Run the full MIMIC-IV hourly LoS training and evaluation pipeline.""" + args = parse_args() + set_seed(args.seed) + + print("=" * 80) + print("TPC MIMIC-IV Hourly LoS Run Configuration") + print("=" * 80) + print(f"root: {args.root}") + print(f"dev: {args.dev}") + print(f"epochs: {args.epochs}") + print(f"batch_size: {args.batch_size}") + print(f"num_workers: {args.num_workers}") + print(f"max_samples: {args.max_samples}") + print(f"num_layers: {args.num_layers}") + print(f"temporal_channels: {args.temporal_channels}") + print(f"pointwise_channels: {args.pointwise_channels}") + print(f"kernel_size: {args.kernel_size}") + print(f"fc_dim: {args.fc_dim}") + print(f"dropout: {args.dropout}") + print(f"lr: {args.lr}") + print(f"loss: {args.loss}") + print(f"max_hours: {args.max_hours}") + print(f"seed: {args.seed}") + print("=" * 80) + + base_dataset = MIMIC4Dataset( + ehr_root=args.root, + ehr_tables=[ + "patients", + 
"admissions", + "icustays", + "chartevents", + "labevents", + ], + ehr_config_path=os.path.join( + REPO_ROOT, + "pyhealth/datasets/configs/mimic4_ehr_tpc.yaml", + ), + dev=args.dev, + cache_dir=args.cache_dir, # Added proper cache directory support + ) + + task_dataset = base_dataset.set_task( + HourlyLOSEICU( + time_series_tables=["chartevents", "labevents"], + time_series_features={ + "chartevents": [ + "Activity / Mobility (JH-HLM)", + "Apnea Interval", + "Arterial Blood Pressure Alarm - High", + "Arterial Blood Pressure Alarm - Low", + "Arterial Blood Pressure diastolic", + "Arterial Blood Pressure mean", + "Arterial Blood Pressure systolic", + "Braden Score", + "Current Dyspnea Assessment", + "Daily Weight", + "Expiratory Ratio", + "Fspn High", + "GCS - Eye Opening", + "GCS - Motor Response", + "GCS - Verbal Response", + "Glucose finger stick (range 70-100)", + "Heart Rate", + "Heart Rate Alarm - Low", + "Heart rate Alarm - High", + "Inspired O2 Fraction", + "Mean Airway Pressure", + "Minute Volume", + "Minute Volume Alarm - High", + "Minute Volume Alarm - Low", + "Non Invasive Blood Pressure diastolic", + "Non Invasive Blood Pressure mean", + "Non Invasive Blood Pressure systolic", + "Non-Invasive Blood Pressure Alarm - High", + "Non-Invasive Blood Pressure Alarm - Low", + "O2 Flow", + "O2 Saturation Pulseoxymetry Alarm - Low", + "O2 saturation pulseoxymetry", + "PEEP set", + "PSV Level", + "Pain Level", + "Pain Level Response", + "Paw High", + "Peak Insp. 
Pressure", + "Phosphorous", + "Plateau Pressure", + "Resp Alarm - High", + "Resp Alarm - Low", + "Respiratory Rate", + "Respiratory Rate (Set)", + "Respiratory Rate (Total)", + "Respiratory Rate (spontaneous)", + "Richmond-RAS Scale", + "Strength L Arm", + "Strength L Leg", + "Strength R Arm", + "Strength R Leg", + "Temperature Fahrenheit", + "Tidal Volume (observed)", + "Tidal Volume (set)", + "Tidal Volume (spontaneous)", + "Total PEEP Level", + "Ventilator Mode", + "Vti High", + ], + "labevents": [ + "Alanine Aminotransferase (ALT)", + "Alkaline Phosphatase", + "Anion Gap", + "Asparate Aminotransferase (AST)", + "Base Excess", + "Bicarbonate", + "Bilirubin, Total", + "Calcium, Total", + "Calculated Total CO2", + "Chloride", + "Creatinine", + "Free Calcium", + "Glucose", + "Hematocrit", + "Hematocrit, Calculated", + "Hemoglobin", + "INR(PT)", + "Lactate", + "MCH", + "MCHC", + "MCV", + "Magnesium", + "Oxygen Saturation", + "PT", + "PTT", + "Phosphate", + "Platelet Count", + "Potassium", + "Potassium, Whole Blood", + "RDW", + "RDW-SD", + "Red Blood Cells", + "Sodium", + "Sodium, Whole Blood", + "Temperature", + "Urea Nitrogen", + "White Blood Cells", + "pCO2", + "pH", + "pO2", + ], + }, + numeric_static_features=[], + categorical_static_features=[], + min_history_hours=5, + max_hours=args.max_hours, + ), + num_workers=args.num_workers, + ) + + if len(task_dataset) == 0: + raise RuntimeError( + "No samples were generated by HourlyLOSEICU. " + "Check synthetic data, feature mappings, joins, cache, or dataset mode." 
+ ) + + input_dim, static_dim = infer_model_dims(task_dataset) + + print(f"task samples: {len(task_dataset)}") + print(f"model input_dim: {input_dim}") + print(f"model static_dim: {static_dim}") + + num_samples = len(task_dataset) + + if num_samples < 20: + train_ds, _, test_ds = split_by_patient(task_dataset, [0.5, 0.0, 0.5]) + val_ds = test_ds + else: + train_ds, val_ds, test_ds = split_by_patient(task_dataset, [0.70, 0.15, 0.15]) + + if len(train_ds) == 0 or len(val_ds) == 0 or len(test_ds) == 0: + raise RuntimeError( + "Dataset split produced an empty train, validation, or test split." + ) + + print(f"train samples: {len(train_ds)}") + print(f"val samples: {len(val_ds)}") + print(f"test samples: {len(test_ds)}") + + + train_loader = get_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + #num_workers=dl_workers, + ) + val_loader = get_dataloader(val_ds, + batch_size=args.batch_size, + shuffle=False, + #num_workers=dl_workers, + ) + test_loader = get_dataloader(test_ds, + batch_size=args.batch_size, + shuffle=False, + #num_workers=dl_workers, + ) + + model = TPC( + dataset=task_dataset, + input_dim=input_dim, + static_dim=static_dim, + temporal_channels=args.temporal_channels, + pointwise_channels=args.pointwise_channels, + num_layers=args.num_layers, + kernel_size=args.kernel_size, + fc_dim=args.fc_dim, + dropout=args.dropout, + loss_name=args.loss, + ) + + + trainer = Trainer( + model=model, + device="cuda" if torch.cuda.is_available() else "cpu", + enable_logging=False, + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + epochs=args.epochs, + optimizer_params={"lr": args.lr}, + ) + + val_results = evaluate_regression(model, val_loader) + test_results = evaluate_regression(model, test_loader) + + print("=" * 80) + print("Run complete") + print(f"loss: {args.loss}") + print(f"val_loss: {val_results['loss']:.4f}") + print(f"val_mae: {val_results['mae']:.4f}") + 
print(f"val_rmse: {val_results['rmse']:.4f}") + print(f"test_loss: {test_results['loss']:.4f}") + print(f"test_mae: {test_results['mae']:.4f}") + print(f"test_rmse: {test_results['rmse']:.4f}") + print( + "MIMIC_SUMMARY " + f"loss={args.loss} " + f"val_loss={val_results['loss']:.4f} " + f"val_mae={val_results['mae']:.4f} " + f"val_rmse={val_results['rmse']:.4f} " + f"test_loss={test_results['loss']:.4f} " + f"test_mae={test_results['mae']:.4f} " + f"test_rmse={test_results['rmse']:.4f}" + ) + print("=" * 80) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/run_dual_dataset_tpc.py b/examples/run_dual_dataset_tpc.py new file mode 100644 index 000000000..37f00009f --- /dev/null +++ b/examples/run_dual_dataset_tpc.py @@ -0,0 +1,338 @@ +""" +Run eICU and MIMIC-IV TPC example pipelines sequentially and summarize results. + +This utility script orchestrates back-to-back execution of the eICU and +MIMIC-IV hourly length-of-stay (LoS) Temporal Pointwise Convolution (TPC) +example scripts using the current BaseModel-compatible example interfaces. + +Overview: + The script performs the following steps: + + 1. Build subprocess commands for the eICU and MIMIC-IV example scripts. + 2. Run each script sequentially from the repository root. + 3. Stream stdout to the console in real time. + 4. Parse the emitted summary lines from each script. + 5. Print a combined summary table for quick comparison. 
+
+Inputs:
+    Command-line arguments controlling:
+        - optional dataset roots
+        - dev mode
+        - batch size
+        - epochs
+        - sample counts
+        - whether either dataset run should be skipped
+
+Outputs:
+    - Real-time console output from each child script
+    - Parsed eICU and MIMIC-IV summary metrics
+    - Combined summary table printed to stdout
+    - Nonzero exit if one or more subprocess runs fail
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import re
+import subprocess
+import sys
+
+
+REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+if REPO_ROOT not in sys.path:
+    sys.path.insert(0, REPO_ROOT)
+
+
+# Each child script prints a single machine-readable summary line; these
+# patterns capture everything after the per-dataset marker token.
+SUMMARY_PATTERNS = {
+    "eicu": re.compile(r"^ABLATION_SUMMARY\s+(.*)$"),
+    "mimic": re.compile(r"^MIMIC_SUMMARY\s+(.*)$"),
+}
+
+
+def parse_summary_fields(summary_body: str) -> dict[str, str]:
+    """Parse whitespace-delimited ``key=value`` tokens from a summary line."""
+    fields: dict[str, str] = {}
+    for token in summary_body.strip().split():
+        if "=" in token:
+            key, value = token.split("=", 1)
+            fields[key.strip()] = value.strip()
+    return fields
+
+
+def run_command(
+    label: str,
+    cmd: list[str],
+    cwd: str,
+) -> tuple[int, list[str], str | None]:
+    """Run a subprocess command and capture its printed summary line."""
+    print("=" * 80)
+    print(f"Running {label}")
+    print("=" * 80)
+    print("Command:")
+    print(" ".join(cmd))
+    print("-" * 80)
+
+    process = subprocess.Popen(
+        cmd,
+        cwd=cwd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        bufsize=1,
+    )
+
+    output_lines: list[str] = []
+    summary_line: str | None = None
+
+    # The summary pattern is invariant per run; look it up once, not per line.
+    pattern = SUMMARY_PATTERNS.get(label)
+
+    assert process.stdout is not None
+    for line in process.stdout:
+        print(line, end="")
+        output_lines.append(line.rstrip("\n"))
+
+        if pattern is not None:
+            match = pattern.match(line.strip())
+            if match:
+                summary_line = match.group(1)
+
+    return_code = process.wait()
+
+    print("-" * 80)
+    print(f"{label} return code: {return_code}")
+    print("=" * 80)
+
+    return return_code, output_lines, summary_line
+
+
+def format_value(fields: dict[str, str], key: str) -> str:
+    """Return a formatted field value or ``NA`` if the key is missing."""
+    return fields.get(key, "NA")
+
+
+def print_combined_summary(
+    eicu_fields: dict[str, str] | None,
+    mimic_fields: dict[str, str] | None,
+) -> None:
+    """Print a combined cross-dataset summary report."""
+    print("\n" + "=" * 80)
+    print("COMBINED SUMMARY")
+    print("=" * 80)
+
+    if eicu_fields is None:
+        print("eICU summary: NOT FOUND")
+    else:
+        print("eICU summary:")
+        print(
+            " "
+            f"model_variant={format_value(eicu_fields, 'model_variant')} "
+            f"shared_temporal={format_value(eicu_fields, 'shared_temporal')} "
+            f"no_skip_connections={format_value(eicu_fields, 'no_skip_connections')} "
+            f"loss={format_value(eicu_fields, 'loss')} "
+            f"val_loss={format_value(eicu_fields, 'val_loss')} "
+            f"val_mae={format_value(eicu_fields, 'val_mae')} "
+            f"val_rmse={format_value(eicu_fields, 'val_rmse')} "
+            f"test_loss={format_value(eicu_fields, 'test_loss')} "
+            f"test_mae={format_value(eicu_fields, 'test_mae')} "
+            f"test_rmse={format_value(eicu_fields, 'test_rmse')}"
+        )
+
+    if mimic_fields is None:
+        print("MIMIC-IV summary: NOT FOUND")
+    else:
+        print("MIMIC-IV summary:")
+        print(
+            " "
+            f"loss={format_value(mimic_fields, 'loss')} "
+            f"val_loss={format_value(mimic_fields, 'val_loss')} "
+            f"val_mae={format_value(mimic_fields, 'val_mae')} "
+            f"val_rmse={format_value(mimic_fields, 'val_rmse')} "
+            f"test_loss={format_value(mimic_fields, 'test_loss')} "
+            f"test_mae={format_value(mimic_fields, 'test_mae')} "
+            f"test_rmse={format_value(mimic_fields, 'test_rmse')}"
+        )
+
+    if eicu_fields is not None and mimic_fields is not None:
+        print("\nCompact table:")
+        print(
+            f"{'dataset':<10} {'loss':<8} {'val_loss':<10} {'val_mae':<10} "
+            f"{'val_rmse':<10} {'test_loss':<10} {'test_mae':<10} {'test_rmse':<10}"
+        )
+        print(
+            f"{'eICU':<10} "
+            f"{format_value(eicu_fields, 'loss'):<8} "
+            f"{format_value(eicu_fields, 'val_loss'):<10} "
+            f"{format_value(eicu_fields, 'val_mae'):<10} "
+            f"{format_value(eicu_fields, 'val_rmse'):<10} "
+            f"{format_value(eicu_fields, 'test_loss'):<10} "
+            f"{format_value(eicu_fields, 'test_mae'):<10} "
+            f"{format_value(eicu_fields, 'test_rmse'):<10}"
+        )
+        print(
+            f"{'MIMIC-IV':<10} "
+            f"{format_value(mimic_fields, 'loss'):<8} "
+            f"{format_value(mimic_fields, 'val_loss'):<10} "
+            f"{format_value(mimic_fields, 'val_mae'):<10} "
+            f"{format_value(mimic_fields, 'val_rmse'):<10} "
+            f"{format_value(mimic_fields, 'test_loss'):<10} "
+            f"{format_value(mimic_fields, 'test_mae'):<10} "
+            f"{format_value(mimic_fields, 'test_rmse'):<10}"
+        )
+
+    print("=" * 80)
+
+
+def build_eicu_command(args) -> list[str]:
+    """Build the subprocess command for the eICU example script."""
+    cmd = [
+        sys.executable,
+        "examples/eicu_hourly_los_tpc.py",
+        "--epochs",
+        str(args.eicu_epochs),
+        "--batch_size",
+        str(args.batch_size),
+        "--max_samples",
+        str(args.eicu_max_samples),
+        "--loss",
+        str(args.eicu_loss),
+        # Pass the parallel worker count to leverage the i9
+        "--num_workers",
+        str(args.num_workers),
+    ]
+
+    if args.eicu_root:
+        cmd.extend(["--root", args.eicu_root])
+
+    # Pass the external cache directory to avoid re-processing
+    if hasattr(args, 'eicu_cache_dir') and args.eicu_cache_dir:
+        cmd.extend(["--cache_dir", args.eicu_cache_dir])
+
+    if args.dev:
+        cmd.append("--dev")
+
+    if args.eicu_model_variant:
+        cmd.extend(["--model_variant", args.eicu_model_variant])
+
+    if args.eicu_shared_temporal:
+        cmd.append("--shared_temporal")
+
+    if args.eicu_no_skip_connections:
+        cmd.append("--no_skip_connections")
+
+    return cmd
+
+
+def build_mimic_command(args) -> list[str]:
+    """Build the subprocess command for the MIMIC-IV example script."""
+    cmd = [
+        sys.executable,
+        "examples/mimic4_hourly_los_tpc.py",
+        "--epochs", str(args.mimic_epochs),
+        "--batch_size", str(args.batch_size),
+        "--max_samples", str(args.mimic_max_samples),
+        "--loss", str(args.mimic_loss),
+        "--num_workers", str(args.num_workers),
+    ]
+
+    if args.mimic_root:
+        cmd.extend(["--root", args.mimic_root])
+
+    # Pass the MIMIC-IV cache directory
+    if hasattr(args, 'mimic_cache_dir') and args.mimic_cache_dir:
+        cmd.extend(["--cache_dir", args.mimic_cache_dir])
+
+    if args.dev:
+        cmd.append("--dev")
+
+    return cmd
+
+
+def parse_args():
+    """Parse command-line arguments for the dual-dataset run script."""
+    parser = argparse.ArgumentParser(
+        description=(
+            "Run eICU and MIMIC-IV TPC examples sequentially and print a "
+            "combined summary."
+        )
+    )
+    parser.add_argument("--dev", action="store_true", help="Run both scripts in dev mode")
+    parser.add_argument("--batch_size", type=int, default=2)
+    parser.add_argument("--eicu_epochs", type=int, default=1)
+    parser.add_argument("--mimic_epochs", type=int, default=1)
+    parser.add_argument("--eicu_max_samples", type=int, default=24)
+    parser.add_argument("--mimic_max_samples", type=int, default=24)
+    parser.add_argument("--eicu_root", type=str, default="")
+    parser.add_argument("--mimic_root", type=str, default="")
+    parser.add_argument(
+        "--eicu_loss",
+        type=str,
+        choices=["msle", "mse"],
+        default="msle",
+        help="Loss to pass to the eICU script.",
+    )
+    parser.add_argument(
+        "--mimic_loss",
+        type=str,
+        choices=["msle", "mse"],
+        default="msle",
+        help="Loss to pass to the MIMIC-IV script.",
+    )
+    parser.add_argument(
+        "--eicu_model_variant",
+        type=str,
+        choices=["full", "temporal_only", "pointwise_only"],
+        default="full",
+    )
+    parser.add_argument(
+        "--eicu_shared_temporal",
+        action="store_true",
+        help="Pass --shared_temporal to the eICU script.",
+    )
+    parser.add_argument(
+        "--eicu_no_skip_connections",
+        action="store_true",
+        help="Pass --no_skip_connections to the eICU script.",
+    )
+    parser.add_argument("--skip_eicu", action="store_true", help="Skip the eICU run")
+    parser.add_argument("--skip_mimic",
+                        action="store_true", help="Skip the MIMIC-IV run")
+    parser.add_argument("--num_workers", type=int, default=1, help="Parallel workers for both datasets.")
+    parser.add_argument("--eicu_cache_dir", type=str, default=None, help="Cache path for eICU.")
+    parser.add_argument("--mimic_cache_dir", type=str, default=None, help="Cache path for MIMIC-IV.")
+    return parser.parse_args()
+
+
+def main():
+    """Run the configured eICU and MIMIC-IV example scripts and summarize results."""
+    args = parse_args()
+
+    eicu_fields: dict[str, str] | None = None
+    mimic_fields: dict[str, str] | None = None
+
+    eicu_rc = 0
+    mimic_rc = 0
+
+    if not args.skip_eicu:
+        eicu_cmd = build_eicu_command(args)
+        eicu_rc, _, eicu_summary = run_command("eicu", eicu_cmd, cwd=REPO_ROOT)
+        if eicu_summary is not None:
+            eicu_fields = parse_summary_fields(eicu_summary)
+
+    if not args.skip_mimic:
+        mimic_cmd = build_mimic_command(args)
+        mimic_rc, _, mimic_summary = run_command("mimic", mimic_cmd, cwd=REPO_ROOT)
+        if mimic_summary is not None:
+            mimic_fields = parse_summary_fields(mimic_summary)
+
+    print_combined_summary(eicu_fields, mimic_fields)
+
+    if eicu_rc != 0 or mimic_rc != 0:
+        failed = []
+        if eicu_rc != 0:
+            failed.append(f"eICU rc={eicu_rc}")
+        if mimic_rc != 0:
+            failed.append(f"MIMIC-IV rc={mimic_rc}")
+        raise SystemExit("One or more runs failed: " + ", ".join(failed))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/pyhealth/datasets/configs/eicu_tpc.yaml b/pyhealth/datasets/configs/eicu_tpc.yaml
new file mode 100644
index 000000000..d0dba8b06
--- /dev/null
+++ b/pyhealth/datasets/configs/eicu_tpc.yaml
@@ -0,0 +1,362 @@
+# Custom eICU dataset configuration for the TPC hourly length-of-stay project.
+#
+# This configuration extends the standard PyHealth eICU dataset mapping to
+# support the Temporal Pointwise Convolution (TPC) replication pipeline.
+# It exposes the tables and attributes needed for hourly ICU remaining +# length-of-stay prediction, including time-series measurements, static +# patient features, and diagnosis/history sources used by the custom task. +# +# Intended use: +# - pyhealth.tasks.hourly_los.HourlyLOSEICU +# - examples/eicu_hourly_los_tpc.py +# +# Notes: +# - Table names are normalized to the lowercase names used by the custom +# task pipeline. +# - Several tables join through patient.csv on patientunitstayid in order +# to recover uniquepid and visit-level context for task construction. +# - This file is project-specific and should be kept aligned with the +# expected table/attribute names used in the eICU TPC task code. +version: "2.0" +tables: + patient: + file_path: "patient.csv" + patient_id: "uniquepid" + timestamp: null + attributes: + - "patientunitstayid" + - "patienthealthsystemstayid" + - "gender" + - "age" + - "ethnicity" + - "hospitalid" + - "wardid" + - "apacheadmissiondx" + - "admissionheight" + - "admissionweight" + - "dischargeweight" + - "hospitaladmittime24" + - "hospitaladmitoffset" + - "hospitaladmitsource" + - "hospitaldischargeyear" + - "hospitaldischargeoffset" + - "hospitaldischargestatus" + - "hospitaldischargetime24" + - "unittype" + - "unitadmittime24" + - "unitadmitsource" + - "unitvisitnumber" + - "unitstaytype" + - "unitdischargeoffset" + - "unitdischargestatus" + - "unitdischargetime24" + - "unitdischargelocation" + + hospital: + file_path: "hospital.csv" + patient_id: null + timestamp: null + attributes: + - "hospitalid" + - "numbedscategory" + - "teachingstatus" + - "region" + + diagnosis: + file_path: "diagnosis.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - 
"hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "diagnosisoffset" + - "diagnosisstring" + - "icd9code" + - "diagnosispriority" + + medication: + file_path: "medication.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "drugstartoffset" + - "drugstopoffset" + - "drugname" + - "drughiclseqno" + - "dosage" + - "routeadmin" + - "frequency" + - "loadingdose" + - "prn" + - "drugordercancelled" + + treatment: + file_path: "treatment.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "treatmentoffset" + - "treatmentstring" + - "activeupondischarge" + + lab: + file_path: "lab.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "labresultoffset" + - "labname" + - "labresult" + - "labresulttext" + - "labmeasurenamesystem" + - "labmeasurenameinterface" + - "labtypeid" + + physicalexam: + file_path: "physicalExam.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - 
"uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "physicalexamoffset" + - "physicalexampath" + - "physicalexamtext" + - "physicalexamvalue" + + admissiondx: + file_path: "admissionDx.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "admitdxenteredoffset" + - "admitdxpath" + - "admitdxname" + - "admitdxtext" + + respiratorycharting: + file_path: "respiratoryCharting.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "respchartoffset" + - "respchartentryoffset" + - "respcharttypecat" + - "respchartvaluelabel" + - "respchartvalue" + + nursecharting: + file_path: "nurseCharting.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "nursingchartoffset" + - "nursingchartentryoffset" + - "nursingchartcelltypecat" + - "nursingchartcelltypevallabel" + - 
"nursingchartcelltypevalname" + - "nursingchartvalue" + + vitalperiodic: + file_path: "vitalPeriodic.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "observationoffset" + - "temperature" + - "sao2" + - "heartrate" + - "respiration" + - "cvp" + - "etco2" + - "systemicsystolic" + - "systemicdiastolic" + - "systemicmean" + - "pasystolic" + - "padiastolic" + - "pamean" + - "st1" + - "st2" + - "st3" + - "icp" + + vitalaperiodic: + file_path: "vitalAperiodic.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "observationoffset" + - "noninvasivesystolic" + - "noninvasivediastolic" + - "noninvasivemean" + - "paop" + - "cardiacoutput" + - "cardiacinput" + - "svr" + - "svri" + - "pvr" + - "pvri" + + pasthistory: + file_path: "pastHistory.csv" + patient_id: "uniquepid" + timestamp: null + join: + - file_path: "patient.csv" + "on": "patientunitstayid" + how: "inner" + columns: + - "uniquepid" + - "patienthealthsystemstayid" + - "unitvisitnumber" + - "hospitaladmitoffset" + - "hospitaldischargeoffset" + - "unitdischargeoffset" + - "hospitaldischargeyear" + - "hospitaldischargestatus" + attributes: + - "patientunitstayid" + - "pasthistoryoffset" + - "pasthistoryenteredoffset" + - "pasthistorynotetype" + - "pasthistorypath" + - "pasthistoryvalue" + - "pasthistoryvaluetext" + + diff --git 
a/pyhealth/datasets/configs/mimic4_ehr_tpc.yaml b/pyhealth/datasets/configs/mimic4_ehr_tpc.yaml new file mode 100644 index 000000000..7bc0a1ea0 --- /dev/null +++ b/pyhealth/datasets/configs/mimic4_ehr_tpc.yaml @@ -0,0 +1,173 @@ +# Custom MIMIC-IV dataset configuration for the TPC hourly length-of-stay project. +# +# This configuration defines the table mappings and joins required to load +# MIMIC-IV data into PyHealth for the Temporal Pointwise Convolution (TPC) +# replication pipeline. It is tailored to support hourly ICU remaining +# length-of-stay prediction using time-series EHR data. +# +# Intended use: +# - pyhealth.tasks.hourly_los.HourlyLOSEICU (MIMIC-IV variant) +# - examples/mimic4_hourly_los_tpc.py +# +# Notes: +# - Core ICU time-series data is sourced from `chartevents`, joined with +# `d_items` to obtain human-readable labels and feature metadata. +# - Laboratory features are sourced from `labevents`, joined with +# `d_labitems` for standardized naming and categorization. +# - Static and contextual patient information is derived from `patients`, +# `admissions`, and `icustays`. +# - Diagnosis and procedure codes are included for optional feature +# engineering and cohort analysis. +# - Timestamps are defined per table to enable temporal alignment and +# hourly resampling within the task pipeline. +# - This configuration is project-specific and should remain aligned with +# the feature extraction logic implemented in the MIMIC-IV LoS task. 
+version: "2.2" +tables: + patients: + file_path: "hosp/patients.csv.gz" + patient_id: "subject_id" + timestamp: null + attributes: + - "gender" + - "anchor_age" + - "anchor_year" + - "anchor_year_group" + - "dod" + + chartevents: + file_path: "icu/chartevents.csv.gz" + patient_id: "subject_id" + join: + - file_path: "icu/d_items.csv.gz" + "on": "itemid" + how: "inner" + columns: + - "label" + - "abbreviation" + - "linksto" + - "category" + - "unitname" + - "param_type" + timestamp: "charttime" + attributes: + - "hadm_id" + - "stay_id" + - "itemid" + - "label" + - "abbreviation" + - "linksto" + - "category" + - "unitname" + - "param_type" + - "value" + - "valuenum" + - "valueuom" + - "storetime" + - "warning" + + admissions: + file_path: "hosp/admissions.csv.gz" + patient_id: "subject_id" + timestamp: "admittime" + attributes: + - "hadm_id" + - "admission_type" + - "admission_location" + - "insurance" + - "language" + - "marital_status" + - "race" + - "discharge_location" + - "dischtime" + - "hospital_expire_flag" + + icustays: + file_path: "icu/icustays.csv.gz" + patient_id: "subject_id" + timestamp: "intime" + attributes: + - "hadm_id" + - "stay_id" + - "first_careunit" + - "last_careunit" + - "outtime" + + diagnoses_icd: + file_path: "hosp/diagnoses_icd.csv.gz" + patient_id: "subject_id" + join: + - file_path: "hosp/admissions.csv.gz" + "on": "hadm_id" + how: "inner" + columns: + - "dischtime" + timestamp: "dischtime" + attributes: + - "hadm_id" + - "icd_code" + - "icd_version" + - "seq_num" + + procedures_icd: + file_path: "hosp/procedures_icd.csv.gz" + patient_id: "subject_id" + join: + - file_path: "hosp/admissions.csv.gz" + "on": "hadm_id" + how: "inner" + columns: + - "dischtime" + timestamp: "dischtime" + attributes: + - "hadm_id" + - "icd_code" + - "icd_version" + - "seq_num" + + prescriptions: + file_path: "hosp/prescriptions.csv.gz" + patient_id: "subject_id" + timestamp: "starttime" + attributes: + - "hadm_id" + - "drug" + - "ndc" + - "prod_strength" + 
- "dose_val_rx" + - "dose_unit_rx" + - "route" + - "stoptime" + + labevents: + file_path: "hosp/labevents.csv.gz" + patient_id: "subject_id" + join: + - file_path: "hosp/d_labitems.csv.gz" + "on": "itemid" + how: "inner" + columns: + - "label" + - "fluid" + - "category" + timestamp: "charttime" + attributes: + - "hadm_id" + - "itemid" + - "label" + - "fluid" + - "category" + - "value" + - "valuenum" + - "valueuom" + - "flag" + - "storetime" + + hcpcsevents: + file_path: "hosp/hcpcsevents.csv.gz" + patient_id: "subject_id" + timestamp: "chartdate" + attributes: + - "hcpcs_cd" + - "seq_num" + - "short_description" \ No newline at end of file diff --git a/pyhealth/models/tpc.py b/pyhealth/models/tpc.py new file mode 100644 index 000000000..547c3b255 --- /dev/null +++ b/pyhealth/models/tpc.py @@ -0,0 +1,810 @@ +""" +Temporal Pointwise Convolution (TPC) model for hourly ICU remaining +length-of-stay (LoS) regression in PyHealth. + +This module provides a dataset-backed PyHealth ``BaseModel`` implementation +of the Temporal Pointwise Convolution (TPC) architecture described in: + +Rocheteau, E., Liò, P., and Hyland, S. (2021). +"Temporal Pointwise Convolutional Networks for Length of Stay Prediction +in the Intensive Care Unit." + +Overview: + TPC is designed for multivariate, irregularly sampled EHR time series. + It combines two complementary operations at each layer: + + 1. Temporal convolution: + Feature-wise or shared causal convolutions over time, allowing + each clinical variable to learn temporal dynamics. + + 2. Pointwise convolution: + Per-time-step feature mixing to capture cross-feature interactions + without temporal leakage. 
+ + In this PyHealth implementation, the task supplies: + - ``time_series``: per-sample hourly history encoded as [T, 3F] + in [value, mask, decay] order for each feature + - ``static``: optional static feature vector + - ``target_los_hours``: scalar regression target + + The model consumes task-processed batch fields via ``forward(**kwargs)`` + and returns the standard PyHealth output dictionary: + - ``loss`` + - ``y_prob`` + - ``y_true`` + - ``logit`` + +Key Components: + - ``TemporalConvBlock``: causal temporal convolution per feature + - ``PointwiseConvBlock``: per-time-step feature interaction layer + - ``TPCLayer``: combined temporal + pointwise block + - ``TPC``: full stacked regression model + +Implementation Notes: + - The public model contract is now fully dataset-backed through + ``BaseModel``. + - The model predicts one scalar remaining LoS value per sample. + - Nonlinearities use ``nn.Module`` variants to remain compatible with + PyHealth interpretability expectations. + - The input ``time_series`` field is internally split into value, mask, + and decay channels from a [T, 3F] representation. +""" + +from __future__ import annotations + +from typing import Any, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.models import BaseModel + + +class TemporalConvBlock(nn.Module): + """Feature-wise or shared causal temporal convolution block. + + This block applies a 1-D causal convolution over time for each feature. + It supports two modes: + + 1. Feature-wise temporal convolution: + A separate ``nn.Conv1d`` is created for each feature so temporal + filters are not shared across features. + + 2. Shared temporal convolution: + A single ``nn.Conv1d`` is reused for all features, allowing an + ablation against the feature-specific version. 
+ + Input: + x: Tensor of shape ``[B, T, F, C_in]`` + + Output: + y: Tensor of shape ``[B, T, F, C_out]`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_features: int, + kernel_size: int, + dilation: int, + dropout: float = 0.0, + shared_temporal: bool = False, + ) -> None: + """Initialize the temporal convolution block. + + Args: + in_channels: Number of per-feature input channels. + out_channels: Number of per-feature output channels. + num_features: Number of time-series features. + kernel_size: Temporal convolution kernel size. + dilation: Temporal dilation factor. + dropout: Dropout probability applied after convolution. + shared_temporal: Whether to share one temporal convolution across + all features. + + Raises: + ValueError: If any required dimensional argument is invalid. + """ + super().__init__() + if in_channels <= 0: + raise ValueError("in_channels must be positive") + if out_channels <= 0: + raise ValueError("out_channels must be positive") + if num_features <= 0: + raise ValueError("num_features must be positive") + if kernel_size <= 0: + raise ValueError("kernel_size must be positive") + if dilation <= 0: + raise ValueError("dilation must be positive") + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_features = num_features + self.kernel_size = kernel_size + self.dilation = dilation + self.shared_temporal = shared_temporal + + if self.shared_temporal: + self.shared_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + bias=True, + ) + self.feature_convs = None + else: + self.shared_conv = None + self.feature_convs = nn.ModuleList( + [ + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + bias=True, + ) + for _ in range(num_features) + ] + ) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply causal 
temporal convolution to each feature. + + Args: + x: Input tensor of shape ``[B, T, F, C_in]``. + + Returns: + Output tensor of shape ``[B, T, F, C_out]``. + + Raises: + ValueError: If the input tensor has incompatible shape. + """ + if x.ndim != 4: + raise ValueError( + f"Expected x to have shape [B, T, F, C], got {tuple(x.shape)}" + ) + + _, _, num_features, in_channels = x.shape + if num_features != self.num_features: + raise ValueError( + f"Expected {self.num_features} features, got {num_features}" + ) + if in_channels != self.in_channels: + raise ValueError( + f"Expected {self.in_channels} input channels, got {in_channels}" + ) + + outputs = [] + left_pad = self.dilation * (self.kernel_size - 1) + + for feat_idx in range(self.num_features): + feat_x = x[:, :, feat_idx, :].transpose(1, 2) # [B, C_in, T] + feat_x = F.pad(feat_x, (left_pad, 0)) + + if self.shared_temporal: + feat_y = self.shared_conv(feat_x) + else: + feat_y = self.feature_convs[feat_idx](feat_x) + + feat_y = feat_y.transpose(1, 2) # [B, T, C_out] + outputs.append(feat_y.unsqueeze(2)) # [B, T, 1, C_out] + + y = torch.cat(outputs, dim=2) # [B, T, F, C_out] + y = self.dropout(y) + return y + + +class PointwiseConvBlock(nn.Module): + """Pointwise transformation applied independently at each time step. + + This block is implemented as a linear layer operating on the flattened + per-time-step representation. It is equivalent in spirit to a per-time-step + 1x1 convolution across the feature/channel dimension. + + Input: + x: Tensor of shape ``[B, T, D_in]`` + + Output: + y: Tensor of shape ``[B, T, D_out]`` + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + dropout: float = 0.0, + ) -> None: + """Initialize the pointwise block. + + Args: + input_dim: Input dimension per time step. + output_dim: Output dimension per time step. + dropout: Dropout probability applied after projection. + + Raises: + ValueError: If either dimension is not positive. 
+ """ + super().__init__() + if input_dim <= 0: + raise ValueError("input_dim must be positive") + if output_dim <= 0: + raise ValueError("output_dim must be positive") + + self.linear = nn.Linear(input_dim, output_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the pointwise transformation. + + Args: + x: Input tensor of shape ``[B, T, D_in]``. + + Returns: + Output tensor of shape ``[B, T, D_out]``. + + Raises: + ValueError: If the input tensor shape is invalid. + """ + if x.ndim != 3: + raise ValueError( + f"Expected x to have shape [B, T, D], got {tuple(x.shape)}" + ) + y = self.linear(x) + y = self.dropout(y) + return y + + +class TPCLayer(nn.Module): + """One Temporal Pointwise Convolution layer. + + This layer combines: + - optional temporal convolution branch + - optional pointwise branch + - optional concatenative skip connections + + Input: + x: Tensor of shape ``[B, T, F, C_in]`` + decay: Tensor of shape ``[B, T, F]`` or ``None`` + static: Tensor of shape ``[B, S]`` or ``None`` + + Output: + fused: Tensor of shape ``[B, T, F, C_out]`` + """ + + def __init__( + self, + num_features: int, + in_channels: int, + temporal_channels: int, + pointwise_channels: int, + static_dim: int, + kernel_size: int, + dilation: int, + dropout: float = 0.0, + use_decay_in_pointwise: bool = True, + shared_temporal: bool = False, + use_temporal: bool = True, + use_pointwise: bool = True, + use_skip_connections: bool = True, + ) -> None: + """Initialize the TPC layer. + + Args: + num_features: Number of time-series features. + in_channels: Number of input channels per feature. + temporal_channels: Number of temporal branch output channels. + pointwise_channels: Number of pointwise branch output channels. + static_dim: Static feature dimension. + kernel_size: Temporal kernel size. + dilation: Temporal dilation factor. + dropout: Dropout probability. 
            use_decay_in_pointwise: Whether decay indicators are concatenated
                into the pointwise branch input.
            shared_temporal: Whether temporal filters are shared across
                features.
            use_temporal: Whether to enable the temporal branch.
            use_pointwise: Whether to enable the pointwise branch.
            use_skip_connections: Whether to concatenate the prior input
                representation into the layer output.

        Raises:
            ValueError: If layer dimensions are invalid or both branches
                are disabled.
        """
        super().__init__()
        if num_features <= 0:
            raise ValueError("num_features must be positive")
        if in_channels <= 0:
            raise ValueError("in_channels must be positive")
        if temporal_channels <= 0:
            raise ValueError("temporal_channels must be positive")
        if pointwise_channels <= 0:
            raise ValueError("pointwise_channels must be positive")
        if static_dim < 0:
            raise ValueError("static_dim must be non-negative")

        self.num_features = num_features
        self.in_channels = in_channels
        self.temporal_channels = temporal_channels
        self.pointwise_channels = pointwise_channels
        self.static_dim = static_dim
        self.use_decay_in_pointwise = use_decay_in_pointwise
        self.shared_temporal = shared_temporal
        self.use_temporal = use_temporal
        self.use_pointwise = use_pointwise
        self.use_skip_connections = use_skip_connections

        if not self.use_temporal and not self.use_pointwise:
            raise ValueError(
                "At least one of use_temporal or use_pointwise must be True"
            )

        # Per-feature output width: the skip connection keeps the input
        # channels, and each enabled branch appends its own channels.
        self.output_channels = 0
        if self.use_skip_connections:
            self.output_channels += in_channels
        if self.use_temporal:
            self.output_channels += temporal_channels
        if self.use_pointwise:
            self.output_channels += pointwise_channels

        if self.use_temporal:
            self.temporal = TemporalConvBlock(
                in_channels=in_channels,
                out_channels=temporal_channels,
                num_features=num_features,
                kernel_size=kernel_size,
                dilation=dilation,
                dropout=dropout,
                shared_temporal=shared_temporal,
            )
        else:
            self.temporal = None

        if self.use_pointwise:
            point_input_dim = (num_features * in_channels) + static_dim
            if use_decay_in_pointwise:
                point_input_dim += num_features

            self.pointwise = PointwiseConvBlock(
                input_dim=point_input_dim,
                output_dim=pointwise_channels,
                dropout=dropout,
            )
        else:
            self.pointwise = None

        self.activation = nn.ReLU()

    def forward(
        self,
        x: torch.Tensor,
        decay: Optional[torch.Tensor] = None,
        static: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the TPC layer.

        Args:
            x: Input tensor of shape ``[B, T, F, C_in]``.
            decay: Optional decay tensor of shape ``[B, T, F]``.
            static: Optional static tensor of shape ``[B, S]``.

        Returns:
            Fused tensor of shape ``[B, T, F, C_out]``.

        Raises:
            ValueError: If tensor shapes are incompatible.
        """
        if x.ndim != 4:
            raise ValueError(
                f"Expected x to have shape [B, T, F, C], got {tuple(x.shape)}"
            )

        bsz, seq_len, num_features, in_channels = x.shape
        if num_features != self.num_features:
            raise ValueError(
                f"Expected {self.num_features} features, got {num_features}"
            )
        if in_channels != self.in_channels:
            raise ValueError(
                f"Expected {self.in_channels} input channels, got {in_channels}"
            )

        if decay is not None:
            if decay.ndim != 3:
                raise ValueError(
                    f"Expected decay to have shape [B, T, F], got {tuple(decay.shape)}"
                )
            if decay.shape[:3] != (bsz, seq_len, num_features):
                raise ValueError(
                    "Expected decay shape "
                    f"{(bsz, seq_len, num_features)}, got {tuple(decay.shape)}"
                )

        if static is not None:
            if static.ndim != 2:
                raise ValueError(
                    f"Expected static to have shape [B, S], got {tuple(static.shape)}"
                )
            if static.shape[0] != bsz:
                raise ValueError(
                    f"Expected static batch size {bsz}, got {static.shape[0]}"
                )

        parts_to_concat = []

        if self.use_skip_connections:
            parts_to_concat.append(x)

        if self.use_temporal:
            temp_out = self.temporal(x)
            parts_to_concat.append(temp_out)

        if self.use_pointwise:
            flat_x = x.reshape(bsz, seq_len, num_features * in_channels)
            point_parts = [flat_x]

            if static is not None:
                static_rep = static.unsqueeze(1).expand(-1, seq_len, -1)
                point_parts.append(static_rep)

            if self.use_decay_in_pointwise:
                if decay is None:
                    raise ValueError(
                        "decay must be provided when use_decay_in_pointwise=True"
                    )
                point_parts.append(decay)

            point_in = torch.cat(point_parts, dim=-1)
            point_out = self.pointwise(point_in)
            # The pointwise output is shared across features, so broadcast it
            # to every feature position before fusion.
            point_broadcast = point_out.unsqueeze(2).expand(
                -1, -1, num_features, -1
            )
            parts_to_concat.append(point_broadcast)

        fused = torch.cat(parts_to_concat, dim=-1)
        fused = self.activation(fused)
        return fused


class TPC(BaseModel):
    """Temporal Pointwise Convolution model for scalar LoS regression.

    This is a true dataset-backed PyHealth ``BaseModel`` implementation.

    Expected task contract:
    - input field ``time_series`` containing per-sample history encoded as
      ``[T, 3F]`` with interleaved [value, mask, decay] channels
    - input field ``static`` containing optional static features
    - output field ``target_los_hours`` declared as ``"regression"``

    The model predicts one scalar remaining length-of-stay value for each
    sample using the full observed history in that sample.

    Args:
        dataset: PyHealth ``SampleDataset`` produced by ``dataset.set_task()``.
        input_dim: Number of base time-series features ``F`` before expansion
            into [value, mask, decay].
        static_dim: Static feature dimension.
        temporal_channels: Temporal branch output channels per layer.
        pointwise_channels: Pointwise branch output channels per layer.
        num_layers: Number of stacked TPC layers.
        kernel_size: Temporal convolution kernel size.
        fc_dim: Hidden dimension in the final regression head.
        dropout: Dropout probability.
        loss_name: Loss function name. Use ``"msle"`` for mean squared
            logarithmic error or ``"mse"`` for mean squared error.
        use_decay_in_pointwise: Whether to inject decay channels into the
            pointwise branch.
        positive_output: Whether to enforce non-negative predictions with
            ``Softplus``.
        shared_temporal: Whether temporal filters are shared across features.
        use_temporal: Whether to enable the temporal branch.
        use_pointwise: Whether to enable the pointwise branch.
        use_skip_connections: Whether to use concatenative skip connections.

    Notes:
        - This model follows the standard PyHealth ``forward(**kwargs)``
          contract.
        - The output head size is derived from
          ``BaseModel.get_output_size()``.
    """

    def __init__(
        self,
        dataset,
        input_dim: int,
        static_dim: int = 0,
        temporal_channels: int = 8,
        pointwise_channels: int = 8,
        num_layers: int = 3,
        kernel_size: int = 3,
        fc_dim: int = 32,
        dropout: float = 0.1,
        loss_name: str = "msle",
        use_decay_in_pointwise: bool = True,
        positive_output: bool = True,
        shared_temporal: bool = False,
        use_temporal: bool = True,
        use_pointwise: bool = True,
        use_skip_connections: bool = True,
    ) -> None:
        """Initialize the TPC model."""
        super().__init__(dataset=dataset)

        if input_dim <= 0:
            raise ValueError("input_dim must be positive")
        if static_dim < 0:
            raise ValueError("static_dim must be non-negative")
        if temporal_channels <= 0:
            raise ValueError("temporal_channels must be positive")
        if pointwise_channels <= 0:
            raise ValueError("pointwise_channels must be positive")
        if num_layers <= 0:
            raise ValueError("num_layers must be positive")
        if kernel_size <= 0:
            raise ValueError("kernel_size must be positive")
        if fc_dim <= 0:
            raise ValueError("fc_dim must be positive")
        if loss_name not in {"msle", "mse"}:
            raise ValueError("loss_name must be one of {'msle', 'mse'}")

        self.label_key = self.label_keys[0]

        required_feature_keys = {"time_series", "static"}
        missing = required_feature_keys.difference(set(self.feature_keys))
        if missing:
            raise ValueError(
                "TPC requires task input_schema to contain the feature keys "
                f"{sorted(required_feature_keys)}; missing {sorted(missing)}"
            )

        self.input_dim = input_dim
        self.static_dim = static_dim
        self.temporal_channels = temporal_channels
        self.pointwise_channels = pointwise_channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.fc_dim = fc_dim
        self.dropout = dropout
        self.loss_name = loss_name
        self.use_decay_in_pointwise = use_decay_in_pointwise
        self.positive_output = positive_output
        self.shared_temporal = shared_temporal
        self.use_temporal = use_temporal
        self.use_pointwise = use_pointwise
        self.use_skip_connections = use_skip_connections

        layers = []
        in_channels = 2  # value + decay

        for layer_idx in range(num_layers):
            layer = TPCLayer(
                num_features=input_dim,
                in_channels=in_channels,
                temporal_channels=temporal_channels,
                pointwise_channels=pointwise_channels,
                static_dim=static_dim,
                kernel_size=kernel_size,
                # Dilation grows with depth, widening the receptive field.
                dilation=layer_idx + 1,
                dropout=dropout,
                use_decay_in_pointwise=use_decay_in_pointwise,
                shared_temporal=shared_temporal,
                use_temporal=use_temporal,
                use_pointwise=use_pointwise,
                use_skip_connections=use_skip_connections,
            )
            layers.append(layer)
            in_channels = layer.output_channels

        self.layers = nn.ModuleList(layers)

        final_input_dim = (input_dim * in_channels) + static_dim
        self.final_fc1 = nn.Linear(final_input_dim, fc_dim)
        self.final_fc2 = nn.Linear(fc_dim, self.get_output_size())

        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()

    def _unpack_feature_value(self, key: str, feature: Any) -> torch.Tensor:
        """Extract the processor 'value' tensor for a feature key.

        PyHealth passes either:
        - a raw tensor, or
        - a tuple aligned with the processor schema

        Args:
            key: Feature key in the task schema.
            feature: Batch feature value from ``kwargs``.

        Returns:
            The tensor corresponding to the processor's ``value`` field.

        Raises:
            ValueError: If the processor schema does not contain ``value``.
        """
        if isinstance(feature, torch.Tensor):
            return feature

        if not isinstance(feature, (tuple, list)):
            raise ValueError(
                f"Expected feature '{key}' to be a Tensor or tuple/list, "
                f"got {type(feature).__name__}"
            )

        schema = self.dataset.input_processors[key].schema()
        if "value" not in schema:
            raise ValueError(
                f"Processor schema for feature '{key}' does not contain 'value': "
                f"{schema}"
            )
        return feature[schema.index("value")]

    def _split_value_mask_decay(
        self,
        time_series: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split a batch time series tensor into values, masks, and decay.

        Supported input layouts:
        - ``[B, T, 3F]`` with interleaved [value, mask, decay]
        - ``[B, T, F, 3]`` with final channel order [value, mask, decay]

        Args:
            time_series: Batch time series tensor.

        Returns:
            Tuple ``(values, masks, decay)`` each with shape ``[B, T, F]``.

        Raises:
            ValueError: If the tensor shape is incompatible.
+ """ + if time_series.ndim == 3: + batch_size, seq_len, feat_dim = time_series.shape + if feat_dim % 3 != 0: + raise ValueError( + "Expected time_series last dimension divisible by 3 for " + f"[value, mask, decay], got {feat_dim}" + ) + num_features = feat_dim // 3 + if num_features != self.input_dim: + raise ValueError( + f"Expected input_dim={self.input_dim}, got {num_features}" + ) + + values = [] + masks = [] + decay = [] + for feature_idx in range(num_features): + base = feature_idx * 3 + values.append(time_series[:, :, base].unsqueeze(-1)) + masks.append(time_series[:, :, base + 1].unsqueeze(-1)) + decay.append(time_series[:, :, base + 2].unsqueeze(-1)) + + return ( + torch.cat(values, dim=-1), + torch.cat(masks, dim=-1), + torch.cat(decay, dim=-1), + ) + + if time_series.ndim == 4: + batch_size, seq_len, num_features, channels = time_series.shape + if channels != 3: + raise ValueError( + "Expected time_series channel dimension size 3 for " + f"[value, mask, decay], got {channels}" + ) + if num_features != self.input_dim: + raise ValueError( + f"Expected input_dim={self.input_dim}, got {num_features}" + ) + values = time_series[:, :, :, 0] + masks = time_series[:, :, :, 1] + decay = time_series[:, :, :, 2] + return values, masks, decay + + raise ValueError( + "Expected time_series to have shape [B, T, 3F] or [B, T, F, 3], " + f"got {tuple(time_series.shape)}" + ) + + def forward(self, **kwargs) -> dict[str, torch.Tensor]: + """Run the TPC model forward under the PyHealth BaseModel contract. 
+ + Expected kwargs: + - ``time_series``: tensor or processor tuple containing the + [value, mask, decay] history representation + - ``static``: static feature tensor or processor tuple + - ``target_los_hours``: regression targets + + Returns: + Dictionary containing: + - ``loss``: scalar regression loss + - ``y_prob``: prepared output probabilities/identity values + - ``y_true``: ground-truth labels + - ``logit``: raw model outputs + + Raises: + ValueError: If required inputs are missing or malformed. + """ + if "time_series" not in kwargs: + raise ValueError("Missing required batch field 'time_series'") + if "static" not in kwargs: + raise ValueError("Missing required batch field 'static'") + if self.label_key not in kwargs: + raise ValueError( + f"Missing required label field '{self.label_key}' in batch" + ) + + time_series = self._unpack_feature_value("time_series", kwargs["time_series"]) + static = self._unpack_feature_value("static", kwargs["static"]) + + if not isinstance(time_series, torch.Tensor): + raise ValueError("'time_series' value must be a Tensor after unpacking") + if not isinstance(static, torch.Tensor): + raise ValueError("'static' value must be a Tensor after unpacking") + + time_series = time_series.float().to(self.device) + static = static.float().to(self.device) + + if time_series.ndim not in {3, 4}: + raise ValueError( + "Expected time_series batch tensor to have 3 or 4 dimensions, " + f"got {tuple(time_series.shape)}" + ) + + if static.ndim != 2: + raise ValueError( + f"Expected static to have shape [B, S], got {tuple(static.shape)}" + ) + if static.shape[1] != self.static_dim: + raise ValueError( + f"Expected static_dim={self.static_dim}, got {static.shape[1]}" + ) + + x_values, _, x_decay = self._split_value_mask_decay(time_series) + + x = torch.stack([x_values, x_decay], dim=-1) # [B, T, F, 2] + + for layer in self.layers: + x = layer(x, decay=x_decay, static=static) + + batch_size, seq_len, _, _ = x.shape + last_x = x[:, -1, :, 
:].reshape(batch_size, -1) + + if static is not None: + last_x = torch.cat([last_x, static], dim=-1) + + hidden = self.relu(self.final_fc1(last_x)) + logits = self.final_fc2(hidden) + + if self.positive_output: + logits = self.softplus(logits) + + y_true = kwargs[self.label_key].float().to(self.device) + if y_true.ndim == 1: + y_true = y_true.unsqueeze(-1) + elif y_true.ndim > 2: + raise ValueError( + f"Expected y_true to have shape [B] or [B, 1], got {tuple(y_true.shape)}" + ) + + if self.loss_name == "msle": + loss = F.mse_loss(torch.log1p(logits), torch.log1p(y_true)) + elif self.loss_name == "mse": + loss = F.mse_loss(logits, y_true) + else: + raise ValueError( + f"Unsupported loss_name '{self.loss_name}'. Expected 'msle' or 'mse'." + ) + + return { + "loss": loss, + "y_prob": self.prepare_y_prob(logits), + "y_true": y_true, + "logit": logits, + } \ No newline at end of file diff --git a/pyhealth/tasks/hourly_los.py b/pyhealth/tasks/hourly_los.py new file mode 100644 index 000000000..db0002c2c --- /dev/null +++ b/pyhealth/tasks/hourly_los.py @@ -0,0 +1,1060 @@ +""" +Hourly ICU remaining length-of-stay (LoS) task implementation for PyHealth. + +This module defines the ``HourlyLOSEICU`` task, which constructs supervised +samples for predicting the remaining ICU length of stay at hourly intervals +from multivariate EHR time-series data. + +The implementation is designed to align with the preprocessing strategy +described in: + +Rocheteau, E., Liò, P., and Hyland, S. (2021). +"Temporal Pointwise Convolutional Networks for Length of Stay Prediction +in the Intensive Care Unit." + +Overview: + For each ICU stay, this task: + + 1. Extracts observations from both pre-ICU and ICU time windows + (typically up to 24 hours before ICU admission). + 2. Buckets observations into hourly intervals. + 3. Retains the most recent measurement within each hour. + 4. Applies forward-filling to handle missing values. + 5. 
       Computes decay features to represent time since last observation.
    6. Removes pre-ICU rows after feature construction.
    7. Emits one supervised sample per prediction hour.

Task Contract:
    Declared input schema:
        - ``time_series``: processor-backed timeseries field encoding
          per-hour [value, mask, decay] channels for each feature
        - ``static``: processor-backed tensor field for numeric and
          one-hot encoded static features

    Declared output schema:
        - ``target_los_hours``: scalar regression target for remaining
          ICU length of stay at the current prediction hour

Metadata retained in each sample for analysis/debugging:
    - ``target_los_sequence``
    - ``feature_names``
    - ``history_hours``
    - ``categorical_static_raw``
    - ``diagnosis_raw``

Implementation Notes:
    - The public training label is ``target_los_hours``.
    - ``target_los_sequence`` is retained only as auxiliary metadata and is
      not part of the formal output schema.
    - The generated ``time_series`` representation uses [value, mask, decay]
      channel order for each feature, which the TPC model unpacks internally.
"""

from __future__ import annotations

import math
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

from .base_task import BaseTask


class HourlyLOSEICU(BaseTask):
    """Hourly remaining ICU length-of-stay regression task.

    This task builds hourly time-series samples for ICU remaining length-of-stay
    prediction from either eICU-style or MIMIC-IV-style patient records.

    For each ICU stay, the task:
    1. Extracts observations from up to ``pre_icu_hours`` before ICU
       admission through ICU discharge.
    2. Buckets measurements into hourly bins.
    3. Keeps the most recent measurement within each hour.
    4. Forward-fills missing values.
    5. Computes decay features based on hours since the last true
       observation.
    6. Drops pre-ICU rows after forward-fill so the final sequence is
       ICU-only.
    7. Produces one supervised sample per prediction hour with a scalar
       remaining-LoS label.

    Declared sample fields:
        Inputs:
            - ``time_series``: shape [T, 3F] with per-feature
              [value, mask, decay]
            - ``static``: encoded static feature vector

        Outputs:
            - ``target_los_hours``: scalar remaining ICU LoS target

    Notes:
        - This task is intended to serve as the formal task/model contract for
          the TPC ``BaseModel`` implementation.
        - Categorical static features are one-hot encoded with task-local
          vocabularies.
        - Diagnosis strings may be retained as metadata when requested.
    """

    task_name: str = "HourlyLOSEICU"

    def __init__(
        self,
        diagnosis_tables: Optional[List[str]] = None,
        include_diagnoses: bool = False,
        diagnosis_time_limit_hours: int = 5,
        time_series_tables: Optional[List[str]] = None,
        time_series_features: Optional[Dict[str, List[str]]] = None,
        numeric_static_features: Optional[List[str]] = None,
        categorical_static_features: Optional[List[str]] = None,
        min_history_hours: int = 5,
        max_hours: int = 48,
        pre_icu_hours: int = 24,
    ) -> None:
        """Initialize the hourly ICU LoS task.

        Args:
            diagnosis_tables: Diagnosis-related tables to inspect for diagnosis
                extraction, primarily for eICU.
            include_diagnoses: Whether to extract and attach diagnosis strings
                as metadata.
            diagnosis_time_limit_hours: Maximum diagnosis time from ICU
                admission to include, in hours.
            time_series_tables: Tables containing time-series observations.
            time_series_features: Mapping from table name to the list of
                feature names to extract from that table.
            numeric_static_features: Static numeric features to encode directly.
            categorical_static_features: Static categorical features to one-hot
                encode using task-local vocabularies.
            min_history_hours: Minimum history length required before a sample
                is emitted.
            max_hours: Maximum ICU hours to keep from a stay.
            pre_icu_hours: Number of hours before ICU admission to include
                during extraction and forward-fill before cropping back to
                ICU-only rows.
        """
        self.diagnosis_tables = diagnosis_tables or []
        self.include_diagnoses = include_diagnoses
        self.diagnosis_time_limit_hours = diagnosis_time_limit_hours
        self.time_series_tables = time_series_tables or ["lab"]
        self.time_series_features = time_series_features or {"lab": ["-basos"]}
        self.numeric_static_features = numeric_static_features or []
        self.categorical_static_features = categorical_static_features or []
        self.min_history_hours = min_history_hours
        self.max_hours = max_hours
        self.pre_icu_hours = pre_icu_hours

        # Task-local one-hot vocabularies, populated lazily in _encode_static.
        self.static_vocab: Dict[str, Dict[Any, int]] = {
            feature: {} for feature in self.categorical_static_features
        }

        # Formal task/model contract.
        self.input_schema = {
            "time_series": "tensor",
            "static": "tensor",
        }

        self.output_schema = {
            "target_los_hours": "regression",
        }

    def pre_filter(self, global_event_df):
        """Optionally filter the global event dataframe before task processing.

        Args:
            global_event_df: The global event dataframe prepared by the dataset.

        Returns:
            The input dataframe unchanged.
        """
        return global_event_df

    def _split_hierarchical_diagnosis(self, raw: Any) -> List[str]:
        """Split a hierarchical diagnosis string into cumulative prefix tokens.

        Example:
            ``"a | b | c" -> ["a", "a | b", "a | b | c"]``

        Args:
            raw: Raw diagnosis string or other value.

        Returns:
            A list of cumulative hierarchical diagnosis prefixes.
+ """ + if raw is None: + return [] + + s = str(raw).strip() + if not s: + return [] + + parts = [p.strip().lower() for p in s.split("|") if p.strip()] + if not parts: + return [] + + prefixes = [] + current = [] + for part in parts: + current.append(part) + prefixes.append(" | ".join(current)) + return prefixes + + def _extract_eicu_diagnoses(self, patient) -> List[str]: + """Extract hierarchical diagnosis tokens for an eICU patient. + + Only diagnoses within ``diagnosis_time_limit_hours`` are included. + + Args: + patient: PyHealth patient object. + + Returns: + A sorted list of unique diagnosis tokens. + """ + if not self.include_diagnoses: + return [] + + diagnosis_tokens = set() + + for table in self.diagnosis_tables: + try: + events = patient.get_events(table) + except Exception: + events = [] + + for event in events: + attr = event.attr_dict + + offset_minutes = None + for offset_key in [ + "pasthistoryoffset", + "admitdxenteredoffset", + "diagnosisoffset", + ]: + offset_minutes = self._safe_float(attr.get(offset_key)) + if offset_minutes is not None: + break + + if offset_minutes is None: + continue + + offset_hours = offset_minutes / 60.0 + if offset_hours > self.diagnosis_time_limit_hours: + continue + + raw_candidates = [ + attr.get("pasthistorypath"), + attr.get("admitdxpath"), + attr.get("diagnosisstring"), + attr.get("diagnosispath"), + ] + + raw_value = None + for candidate in raw_candidates: + if candidate is not None and str(candidate).strip(): + raw_value = candidate + break + + if raw_value is None: + continue + + for token in self._split_hierarchical_diagnosis(raw_value): + diagnosis_tokens.add(token) + + return sorted(diagnosis_tokens) + + def _safe_float(self, value: Any) -> Optional[float]: + """Convert a value to a finite float if possible. + + Args: + value: Input value. + + Returns: + A finite float, or ``None`` if conversion fails or the value is + NaN/inf. 
+ """ + try: + if value is None: + return None + x = float(value) + if math.isnan(x) or math.isinf(x): + return None + return x + except Exception: + return None + + def _safe_datetime(self, value: Any) -> Optional[datetime]: + """Convert a value to a datetime if possible. + + Supports ISO strings and several common datetime formats. + + Args: + value: Input value. + + Returns: + A datetime object, or ``None`` if parsing fails. + """ + if value is None: + return None + if isinstance(value, datetime): + return value + + s = str(value).strip() + if not s: + return None + + s = s.replace("Z", "+00:00") + try: + return datetime.fromisoformat(s) + except Exception: + pass + + fmts = [ + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M:%S.%f", + "%Y-%m-%d", + ] + for fmt in fmts: + try: + return datetime.strptime(s, fmt) + except Exception: + continue + return None + + def _norm_name(self, x: Any) -> str: + """Normalize a feature name to a lowercase stripped string. + + Args: + x: Raw feature name. + + Returns: + Normalized feature name string. + """ + if x is None: + return "" + return str(x).strip().lower() + + def _encode_static(self, attr: Dict[str, Any]) -> List[float]: + """Encode static numeric and categorical features into a flat vector. + + Numeric features are inserted directly after safe float conversion. + Categorical features are one-hot encoded with a task-local vocabulary. + + Args: + attr: Source attribute dictionary. + + Returns: + Encoded static feature vector. 
+ """ + numeric = [] + categorical = [] + + for feature_name in self.numeric_static_features: + val = self._safe_float(attr.get(feature_name)) + if val is None: + val = 0.0 + numeric.append(float(val)) + + for feature_name in self.categorical_static_features: + raw = attr.get(feature_name) + category = "__MISSING__" if raw is None else str(raw) + + vocab = self.static_vocab.setdefault(feature_name, {}) + if category not in vocab: + vocab[category] = len(vocab) + + one_hot = [0.0] * len(vocab) + one_hot[vocab[category]] = 1.0 + categorical.extend(one_hot) + + return numeric + categorical + + def _build_feature_index(self) -> Tuple[List[str], Dict[str, int]]: + """Build the normalized feature list and index mapping. + + Returns: + A tuple of: + - normalized feature names in deterministic order + - mapping from feature name to feature index + + Raises: + ValueError: If no time-series features are defined. + """ + names = [] + for table in self.time_series_tables: + names.extend(self.time_series_features.get(table, [])) + + normalized_names = [self._norm_name(name) for name in names] + + if len(normalized_names) == 0: + raise ValueError("No time-series features defined.") + + if len(set(normalized_names)) != len(normalized_names): + seen = set() + deduped = [] + for name in normalized_names: + if name not in seen: + seen.add(name) + deduped.append(name) + normalized_names = deduped + + return normalized_names, {name: i for i, name in enumerate(normalized_names)} + + def _combine_value_mask_decay( + self, + filled: List[List[float]], + mask: List[List[float]], + decay: List[List[float]], + ) -> List[List[float]]: + """Interleave filled values, masks, and decay as [value, mask, decay]. + + Args: + filled: Forward-filled values, shape [T, F]. + mask: Observation mask, shape [T, F]. + decay: Decay features, shape [T, F]. + + Returns: + Combined feature matrix of shape [T, 3F]. 
+ """ + combined = [] + for hour_idx in range(len(filled)): + row = [] + for feat_idx in range(len(filled[hour_idx])): + row.append(filled[hour_idx][feat_idx]) + row.append(mask[hour_idx][feat_idx]) + row.append(decay[hour_idx][feat_idx]) + combined.append(row) + return combined + + def _build_feature_names(self, normalized_names: List[str]) -> List[str]: + """Build output feature names for value, mask, and decay channels. + + Args: + normalized_names: Base time-series feature names. + + Returns: + Expanded feature names in + ``[name_val, name_mask, name_decay]`` order. + """ + feature_names = [] + for name in normalized_names: + feature_names.extend( + [f"{name}_val", f"{name}_mask", f"{name}_decay"] + ) + return feature_names + + def _make_hourly_tensor( + self, + observations: List[Tuple[int, int, float, float]], + usable_hours: int, + num_features: int, + ) -> List[List[float]]: + """Build an hourly [value, mask, decay] tensor from observations. + + Each observation tuple is: + ``(hour_index, feature_index, value, precise_offset_hours)`` + + For each hour and feature, the most recent measurement within that hour + is retained. Missing values are forward-filled. Mask indicates whether + the value was observed in that exact hour. Decay follows ``0.75 ** j`` + where ``j`` is the number of hours since the last real observation. + + Args: + observations: Bucketed observations. + usable_hours: Number of hours in the timeline being built. + num_features: Number of time-series features. + + Returns: + Combined [T, 3F] tensor as a Python list of rows. 
+ """ + latest_vals = [[None] * num_features for _ in range(usable_hours)] + latest_time = [[-float("inf")] * num_features for _ in range(usable_hours)] + + for hour_idx, feat_idx, value, precise_offset in observations: + if hour_idx < 0 or hour_idx >= usable_hours: + continue + if precise_offset >= latest_time[hour_idx][feat_idx]: + latest_time[hour_idx][feat_idx] = precise_offset + latest_vals[hour_idx][feat_idx] = value + + filled = [] + mask = [] + decay = [] + + last_val = [0.0] * num_features + last_seen = [None] * num_features + + for hour_idx in range(usable_hours): + filled_row = [] + mask_row = [] + decay_row = [] + + for feat_idx in range(num_features): + value = latest_vals[hour_idx][feat_idx] + + if value is not None: + last_val[feat_idx] = value + last_seen[feat_idx] = hour_idx + filled_row.append(value) + mask_row.append(1.0) + decay_row.append(1.0) + else: + filled_row.append(last_val[feat_idx]) + mask_row.append(0.0) + + if last_seen[feat_idx] is None: + decay_row.append(0.0) + else: + gap = hour_idx - last_seen[feat_idx] + decay_row.append(float(0.75 ** gap)) + + filled.append(filled_row) + mask.append(mask_row) + decay.append(decay_row) + + return self._combine_value_mask_decay(filled, mask, decay) + + def _make_cropped_hourly_tensor( + self, + observations: List[Tuple[int, int, float, float]], + total_hours: float, + num_features: int, + ) -> List[List[float]]: + """Build an extended timeline tensor and crop it to ICU-only rows. + + The timeline begins ``pre_icu_hours`` before ICU admission, allowing + pre-ICU observations to participate in hourly bucketing and forward-fill. + After the extended tensor is built, the leading pre-ICU rows are removed. + + Args: + observations: Observations indexed on an extended timeline whose + zero point is ICU admission minus ``self.pre_icu_hours``. + total_hours: ICU stay length in hours. + num_features: Number of time-series features. + + Returns: + Cropped ICU-only [T, 3F] tensor as a Python list of rows. 
+ """ + usable_hours = int(min(total_hours, self.max_hours)) + extended_hours = self.pre_icu_hours + usable_hours + + full_ts = self._make_hourly_tensor( + observations=observations, + usable_hours=extended_hours, + num_features=num_features, + ) + + return full_ts[self.pre_icu_hours : self.pre_icu_hours + usable_hours] + + def _make_samples_for_stay( + self, + patient, + visit_id: Any, + total_hours: float, + static_attr: Dict[str, Any], + observations: List[Tuple[int, int, float, float]], + normalized_feature_names: List[str], + diagnosis_raw: Optional[List[str]] = None, + ) -> List[Dict[str, Any]]: + """Create hourly supervised samples for a single ICU stay. + + Each emitted sample uses history through hour ``t`` and predicts the + scalar remaining length of stay at that hour. + + Args: + patient: PyHealth patient object. + visit_id: Stay or visit identifier. + total_hours: Total ICU stay length in hours. + static_attr: Static attributes to encode. + observations: Extended-timeline observations. + normalized_feature_names: Base feature names. + diagnosis_raw: Optional diagnosis tokens. + + Returns: + A list of task samples, one for each prediction hour from + ``min_history_hours`` through ``usable_hours``. 
+ """ + samples = [] + + if total_hours is None or total_hours < self.min_history_hours: + return samples + + usable_hours = int(min(total_hours, self.max_hours)) + if usable_hours < self.min_history_hours: + return samples + + static_vec = self._encode_static(static_attr) + num_features = len(normalized_feature_names) + + time_series = self._make_cropped_hourly_tensor( + observations=observations, + total_hours=total_hours, + num_features=num_features, + ) + feature_names = self._build_feature_names(normalized_feature_names) + + raw_cats = {} + for feature_name in self.categorical_static_features: + raw_cats[feature_name] = ( + str(static_attr.get(feature_name)) + if static_attr.get(feature_name) is not None + else "__MISSING__" + ) + + for history_hours in range(self.min_history_hours, usable_hours + 1): + remaining_hours = max(total_hours - history_hours, 0.0) + + target_los_sequence = [ + float(max(total_hours - hour_idx, 0.0)) + for hour_idx in range(1, history_hours + 1) + ] + + samples.append( + { + "patient_id": str(patient.patient_id), + "visit_id": str(visit_id) if visit_id is not None else "unknown", + "time_series": time_series[:history_hours], + "static": [float(x) for x in static_vec], + "target_los_hours": float(remaining_hours), + "target_los_sequence": [float(x) for x in target_los_sequence], + "history_hours": int(history_hours), + } + ) + + return samples + + def _build_eicu_samples(self, patient) -> List[Dict[str, Any]]: + """Build hourly LoS samples for an eICU patient. + + This method extracts time-series observations using eICU offset fields, + shifts them onto an extended timeline that includes pre-ICU hours, + and creates ICU-only samples after cropping. + + Args: + patient: PyHealth patient object. + + Returns: + List of generated task samples. 
+ """ + samples = [] + + try: + patient_events = patient.get_events("patient") + except Exception: + patient_events = [] + + if not patient_events: + return samples + + anchor_attr = patient_events[0].attr_dict + total_minutes = self._safe_float(anchor_attr.get("unitdischargeoffset")) + if total_minutes is None: + return samples + + total_hours = total_minutes / 60.0 + if total_hours < self.min_history_hours: + return samples + + usable_hours = int(min(total_hours, self.max_hours)) + if usable_hours < self.min_history_hours: + return samples + + normalized_names, feature_index = self._build_feature_index() + if not normalized_names: + return samples + + extra_features = ["time in the icu", "time of day"] + for feature_name in extra_features: + if feature_name not in normalized_names: + normalized_names.append(feature_name) + + feature_index = { + feature_name: idx for idx, feature_name in enumerate(normalized_names) + } + + observations = [] + + for table in self.time_series_tables: + try: + events = patient.get_events(table) + except Exception: + events = [] + + schema = self._get_eicu_table_schema(table) + if schema is None: + continue + + name_key, value_key, offset_key = schema + + for event in events: + attr = event.attr_dict + + if name_key is None and value_key is None: + minutes = self._safe_float(attr.get(offset_key)) + if minutes is None: + continue + + offset_hours = minutes / 60.0 + extended_hour = int(offset_hours) + self.pre_icu_hours + + for raw_name, raw_val in attr.items(): + norm_name = self._norm_name(raw_name) + if norm_name not in feature_index: + continue + + value = self._safe_float(raw_val) + if value is None: + continue + + feat_idx = feature_index[norm_name] + observations.append( + (extended_hour, feat_idx, value, offset_hours) + ) + + time_in_icu_idx = feature_index.get("time in the icu") + if time_in_icu_idx is not None: + observations.append( + ( + extended_hour, + time_in_icu_idx, + offset_hours, + offset_hours, + ) + ) + + 
time_of_day = None + admit_time_str = anchor_attr.get("hospitaladmittime24") + if admit_time_str: + try: + h, m, s = map(int, admit_time_str.split(":")) + base_hour = h + m / 60.0 + s / 3600.0 + time_of_day = (base_hour + offset_hours) % 24 + except Exception: + pass + + time_of_day_idx = feature_index.get("time of day") + if time_of_day_idx is not None and time_of_day is not None: + observations.append( + ( + extended_hour, + time_of_day_idx, + time_of_day, + offset_hours, + ) + ) + + continue + + name = self._norm_name(attr.get(name_key)) + if name not in feature_index: + continue + + value = self._safe_float(attr.get(value_key)) + minutes = self._safe_float(attr.get(offset_key)) + if value is None or minutes is None: + continue + + offset_hours = minutes / 60.0 + extended_hour = int(offset_hours) + self.pre_icu_hours + feat_idx = feature_index[name] + observations.append((extended_hour, feat_idx, value, offset_hours)) + + time_in_icu_idx = feature_index.get("time in the icu") + if time_in_icu_idx is not None: + observations.append( + ( + extended_hour, + time_in_icu_idx, + offset_hours, + offset_hours, + ) + ) + + time_of_day = None + admit_time_str = anchor_attr.get("hospitaladmittime24") + if admit_time_str: + try: + h, m, s = map(int, admit_time_str.split(":")) + base_hour = h + m / 60.0 + s / 3600.0 + time_of_day = (base_hour + offset_hours) % 24 + except Exception: + pass + + time_of_day_idx = feature_index.get("time of day") + if time_of_day_idx is not None and time_of_day is not None: + observations.append( + ( + extended_hour, + time_of_day_idx, + time_of_day, + offset_hours, + ) + ) + + visit_id = ( + anchor_attr.get("patientunitstayid") + or anchor_attr.get("visit_id") + or patient.patient_id + ) + + static_attr = dict(anchor_attr) + diagnosis_raw = self._extract_eicu_diagnoses(patient) + + samples.extend( + self._make_samples_for_stay( + patient=patient, + visit_id=visit_id, + total_hours=total_hours, + static_attr=static_attr, + 
observations=observations, + normalized_feature_names=normalized_names, + diagnosis_raw=diagnosis_raw, + ) + ) + + return samples + + def _get_eicu_table_schema( + self, + table: str, + ) -> Optional[Tuple[Optional[str], Optional[str], str]]: + """Return the schema tuple for supported eICU time-series tables. + + Args: + table: Table name. + + Returns: + A tuple of ``(name_key, value_key, offset_key)``, or ``None`` if the + table is unsupported. + """ + table = str(table).strip().lower() + + schema = { + "lab": ("labname", "labresult", "labresultoffset"), + "respiratorycharting": ( + "respchartvaluelabel", + "respchartvalue", + "respchartoffset", + ), + "nursecharting": ( + "nursingchartcelltypevallabel", + "nursingchartvalue", + "nursingchartoffset", + ), + "vitalperiodic": (None, None, "observationoffset"), + "vitalaperiodic": (None, None, "observationoffset"), + } + + return schema.get(table) + + def _get_mimic_table_schema( + self, + table: str, + ) -> Optional[Tuple[str, Optional[str], str]]: + """Return the schema tuple for supported MIMIC-IV time-series tables. + + Args: + table: Table name. + + Returns: + A tuple of ``(name_key, value_key, time_key)``, or ``None`` if the + table is unsupported. + """ + table = str(table).strip().lower() + + schema = { + "labevents": ("label", None, "timestamp"), + "chartevents": ("label", None, "timestamp"), + } + + return schema.get(table) + + def _build_mimic_samples(self, patient) -> List[Dict[str, Any]]: + """Build hourly LoS samples for a MIMIC-IV patient. + + This method extracts time-series observations from up to + ``pre_icu_hours`` before ICU admission through ICU discharge, shifts + them onto an extended timeline, and creates ICU-only samples after + cropping. + + Args: + patient: PyHealth patient object. + + Returns: + List of generated task samples. 
+ """ + samples = [] + + try: + patient_rows = patient.get_events("patients") + except Exception: + patient_rows = [] + + try: + admission_rows = patient.get_events("admissions") + except Exception: + admission_rows = [] + + try: + icu_rows = patient.get_events("icustays") + except Exception: + icu_rows = [] + + if not icu_rows: + return samples + + patient_static = patient_rows[0].attr_dict if patient_rows else {} + + admissions_by_hadm = {} + for admission in admission_rows: + hadm_id = admission.attr_dict.get("hadm_id") + if hadm_id is not None and hadm_id not in admissions_by_hadm: + admissions_by_hadm[hadm_id] = admission + + normalized_names, feature_index = self._build_feature_index() + if not normalized_names: + return samples + + extra_features = ["time in the icu", "time of day"] + for feature_name in extra_features: + if feature_name not in normalized_names: + normalized_names.append(feature_name) + + feature_index = { + feature_name: idx for idx, feature_name in enumerate(normalized_names) + } + + for icu_event in icu_rows: + icu_attr = dict(icu_event.attr_dict) + hadm_id = icu_attr.get("hadm_id") + stay_id = icu_attr.get("stay_id") + + intime = getattr(icu_event, "timestamp", None) + outtime = self._safe_datetime(icu_attr.get("outtime")) + + if intime is None or outtime is None: + continue + + total_hours = (outtime - intime).total_seconds() / 3600.0 + if total_hours < self.min_history_hours: + continue + + static_attr = dict(patient_static) + if hadm_id in admissions_by_hadm: + static_attr.update(admissions_by_hadm[hadm_id].attr_dict) + static_attr.update(icu_attr) + + observations = [] + pre_icu_start = intime - timedelta(hours=self.pre_icu_hours) + + for table in self.time_series_tables: + try: + events = patient.get_events(table) + except Exception: + events = [] + + schema = self._get_mimic_table_schema(table) + if schema is None: + continue + + name_key, _, _ = schema + + for event in events: + attr = event.attr_dict + + event_hadm_id = 
attr.get("hadm_id") + if ( + hadm_id is not None + and event_hadm_id is not None + and str(event_hadm_id) != str(hadm_id) + ): + continue + + name = self._norm_name(attr.get(name_key)) + if name not in feature_index: + continue + + value = self._safe_float(attr.get("valuenum")) + if value is None: + value = self._safe_float(attr.get("value")) + if value is None: + continue + + event_time = getattr(event, "timestamp", None) + if event_time is None: + event_time = self._safe_datetime(attr.get("storetime")) + if event_time is None: + event_time = self._safe_datetime(attr.get("charttime")) + if event_time is None: + continue + + if event_time < pre_icu_start or event_time > outtime: + continue + + offset_hours = ( + event_time - intime + ).total_seconds() / 3600.0 + extended_hour = int(offset_hours) + self.pre_icu_hours + feat_idx = feature_index[name] + observations.append((extended_hour, feat_idx, value, offset_hours)) + + time_in_icu_idx = feature_index.get("time in the icu") + if time_in_icu_idx is not None: + observations.append( + ( + extended_hour, + time_in_icu_idx, + offset_hours, + offset_hours, + ) + ) + + time_of_day_idx = feature_index.get("time of day") + if time_of_day_idx is not None: + time_of_day = ( + event_time.hour + + event_time.minute / 60.0 + + event_time.second / 3600.0 + ) + observations.append( + ( + extended_hour, + time_of_day_idx, + time_of_day, + offset_hours, + ) + ) + + samples.extend( + self._make_samples_for_stay( + patient=patient, + visit_id=stay_id or hadm_id or patient.patient_id, + total_hours=total_hours, + static_attr=static_attr, + observations=observations, + normalized_feature_names=normalized_names, + ) + ) + + return samples + + def __call__(self, patient): + """Dispatch task construction based on patient contents. + + Args: + patient: PyHealth patient object. + + Returns: + A list of generated task samples for the patient. 
+ """ + try: + if patient.get_events("patient"): + return self._build_eicu_samples(patient) + except Exception: + pass + + try: + if patient.get_events("icustays"): + return self._build_mimic_samples(patient) + except Exception: + pass + + return [] \ No newline at end of file diff --git a/tests/models/test_tpc.py b/tests/models/test_tpc.py new file mode 100644 index 000000000..4155c8788 --- /dev/null +++ b/tests/models/test_tpc.py @@ -0,0 +1,336 @@ +""" +Unit tests for the Temporal Pointwise Convolution (TPC) PyHealth model. + +This module contains fast synthetic tests for validating the dataset-backed +``BaseModel`` implementation of TPC. The tests are designed to run quickly +without requiring real eICU or MIMIC-IV data. + +Overview: + The test suite checks: + + 1. Model instantiation under a dataset-backed BaseModel contract. + 2. Forward-pass correctness and required output keys. + 3. Output shapes for scalar regression. + 4. Forward-pass behavior across major ablation variants. + 5. Gradient propagation through the model during backpropagation. + 6. Failure behavior when required batch fields are missing. + +Implementation Notes: + - Uses only small synthetic tensors for speed and reproducibility. + - Uses a minimal mocked dataset/processor contract to exercise the model's + true BaseModel-facing path. + - Does not depend on real datasets. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List + +import pytest +import torch +import torch.nn as nn + +import pyhealth.models.tpc as tpc_module +from pyhealth.models.tpc import TPC + + +@dataclass +class DummyProcessor: + """Minimal processor stub exposing a schema with a ``value`` field.""" + + names: List[str] + + def schema(self) -> List[str]: + """Return the processor schema.""" + return self.names + + +class DummyDataset: + """Minimal dataset stub matching the fields TPC expects from BaseModel. + + This object is intentionally lightweight. 
It provides only the fields + needed by the rewritten TPC model and the patched BaseModel methods in + this test module. + """ + + def __init__(self) -> None: + """Initialize the dummy dataset.""" + self.input_schema = { + "time_series": "timeseries", + "static": "tensor", + } + self.output_schema = { + "target_los_hours": "regression", + } + self.input_processors: Dict[str, DummyProcessor] = { + "time_series": DummyProcessor(["value"]), + "static": DummyProcessor(["value"]), + } + self.output_processors = {} + + +@pytest.fixture() +def patch_basemodel(monkeypatch): + """Patch BaseModel helpers so TPC can be tested with synthetic inputs only. + + The project rubric requires fast synthetic model tests, and the current + test environment does not rely on the full trainer/runtime stack. + This fixture patches the minimum BaseModel behaviors the TPC model uses: + - dataset-backed initialization + - output size resolution + - regression loss resolution + - prediction preparation + """ + + def fake_init(self, dataset): + nn.Module.__init__(self) + self.dataset = dataset + self.feature_keys = list(dataset.input_schema.keys()) + self.label_keys = list(dataset.output_schema.keys()) + self.__dict__["device"] = torch.device("cpu") + + class DummyProcessor: + def size(self): + return 1 + + self.dataset.output_processors = { + self.label_keys[0]: DummyProcessor() + } + + monkeypatch.setattr(tpc_module.BaseModel, "__init__", fake_init) + monkeypatch.setattr(tpc_module.BaseModel, "get_output_size", lambda self: 1) + monkeypatch.setattr( + tpc_module.BaseModel, + "get_loss_function", + lambda self: nn.MSELoss(), + ) + monkeypatch.setattr( + tpc_module.BaseModel, + "prepare_y_prob", + lambda self, logit: logit, + ) + monkeypatch.setattr( + tpc_module.BaseModel, + "device", + property(lambda self: torch.device("cpu")), + ) + + +def make_batch( + batch_size: int = 2, + seq_len: int = 6, + num_features: int = 5, + static_dim: int = 3, +) -> Dict[str, torch.Tensor | 
tuple[torch.Tensor]]: + """Create a small synthetic batch for BaseModel-style testing. + + The ``time_series`` tensor uses the task/model contract shape [B, T, 3F] + with interleaved [value, mask, decay] channels per feature. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + num_features: Number of base time-series features. + static_dim: Static feature dimension. + + Returns: + A batch dictionary compatible with ``TPC.forward(**kwargs)``. + """ + values = torch.randn(batch_size, seq_len, num_features) + masks = torch.randint(0, 2, (batch_size, seq_len, num_features)).float() + decay = torch.rand(batch_size, seq_len, num_features) + + pieces = [] + for feature_idx in range(num_features): + pieces.append(values[:, :, feature_idx].unsqueeze(-1)) + pieces.append(masks[:, :, feature_idx].unsqueeze(-1)) + pieces.append(decay[:, :, feature_idx].unsqueeze(-1)) + time_series = torch.cat(pieces, dim=-1) # [B, T, 3F] + + static = torch.randn(batch_size, static_dim) + target = torch.rand(batch_size, 1) * 10.0 + + return { + "time_series": (time_series,), + "static": (static,), + "target_los_hours": target, + } + + +def test_tpc_init_requires_expected_feature_keys(patch_basemodel) -> None: + """Test that TPC instantiates with the expected dataset-backed contract.""" + dataset = DummyDataset() + + model = TPC( + dataset=dataset, + input_dim=5, + static_dim=3, + temporal_channels=4, + pointwise_channels=4, + num_layers=2, + kernel_size=3, + fc_dim=8, + ) + + assert model.label_key == "target_los_hours" + assert set(model.feature_keys) == {"time_series", "static"} + + +def test_tpc_forward_returns_required_basemodel_keys(patch_basemodel) -> None: + """Test that forward returns the required BaseModel output dictionary.""" + dataset = DummyDataset() + batch = make_batch() + + model = TPC( + dataset=dataset, + input_dim=5, + static_dim=3, + temporal_channels=4, + pointwise_channels=4, + num_layers=2, + kernel_size=3, + fc_dim=8, + ) + + outputs = model(**batch) + + 
assert set(outputs.keys()) == {"loss", "y_prob", "y_true", "logit"} + assert torch.is_tensor(outputs["loss"]) + assert outputs["y_prob"].shape == (2, 1) + assert outputs["y_true"].shape == (2, 1) + assert outputs["logit"].shape == (2, 1) + assert torch.isfinite(outputs["logit"]).all() + + +def test_tpc_forward_accepts_tensor_inputs_without_processor_tuple( + patch_basemodel, +) -> None: + """Test that forward accepts raw tensors as well as processor tuples.""" + dataset = DummyDataset() + batch = make_batch() + + raw_batch = { + "time_series": batch["time_series"][0], + "static": batch["static"][0], + "target_los_hours": batch["target_los_hours"], + } + + model = TPC( + dataset=dataset, + input_dim=5, + static_dim=3, + temporal_channels=4, + pointwise_channels=4, + num_layers=2, + kernel_size=3, + fc_dim=8, + ) + + outputs = model(**raw_batch) + assert outputs["logit"].shape == (2, 1) + assert torch.isfinite(outputs["logit"]).all() + + +@pytest.mark.parametrize( + "model_kwargs", + [ + {"use_temporal": True, "use_pointwise": True}, + {"use_temporal": True, "use_pointwise": False}, + {"use_temporal": False, "use_pointwise": True}, + {"use_temporal": True, "use_pointwise": True, "shared_temporal": True}, + {"use_temporal": True, "use_pointwise": True, "use_skip_connections": False}, + ], +) +def test_tpc_ablation_variants_forward( + patch_basemodel, + model_kwargs: dict, +) -> None: + """Test forward pass across major ablation variants.""" + dataset = DummyDataset() + batch = make_batch() + + model = TPC( + dataset=dataset, + input_dim=5, + static_dim=3, + temporal_channels=4, + pointwise_channels=4, + num_layers=2, + kernel_size=3, + fc_dim=8, + **model_kwargs, + ) + + outputs = model(**batch) + assert outputs["logit"].shape == (2, 1) + assert torch.isfinite(outputs["logit"]).all() + assert torch.isfinite(outputs["loss"]).all() + + +def test_tpc_requires_at_least_one_branch(patch_basemodel) -> None: + """Test that TPC rejects configuration with both branches 
disabled.""" + dataset = DummyDataset() + + with pytest.raises(ValueError): + TPC( + dataset=dataset, + input_dim=5, + static_dim=3, + temporal_channels=4, + pointwise_channels=4, + num_layers=2, + kernel_size=3, + fc_dim=8, + use_temporal=False, + use_pointwise=False, + ) + + +def test_tpc_missing_label_raises(patch_basemodel) -> None: + """Test that missing label data raises a clear error.""" + dataset = DummyDataset() + batch = make_batch() + batch.pop("target_los_hours") + + model = TPC( + dataset=dataset, + input_dim=5, + static_dim=3, + temporal_channels=4, + pointwise_channels=4, + num_layers=2, + kernel_size=3, + fc_dim=8, + ) + + with pytest.raises(ValueError, match="Missing required label field"): + model(**batch) + + +def test_tpc_backward_pass(patch_basemodel) -> None: + """Test that gradients propagate through the dataset-backed model path.""" + dataset = DummyDataset() + batch = make_batch() + + model = TPC( + dataset=dataset, + input_dim=5, + static_dim=3, + temporal_channels=4, + pointwise_channels=4, + num_layers=2, + kernel_size=3, + fc_dim=8, + ) + + outputs = model(**batch) + loss = outputs["loss"] + loss.backward() + + has_grad = any( + parameter.grad is not None + for parameter in model.parameters() + if parameter.requires_grad + ) + assert has_grad \ No newline at end of file diff --git a/tests/tasks/test_hourly_los.py b/tests/tasks/test_hourly_los.py new file mode 100644 index 000000000..475ba428d --- /dev/null +++ b/tests/tasks/test_hourly_los.py @@ -0,0 +1,289 @@ +""" +Unit tests for the Hourly ICU length-of-stay (LoS) task. + +This module contains synthetic unit tests for validating the behavior of the +``HourlyLOSEICU`` task implementation. The tests verify correct construction +of hourly time-series features, target generation, and dataset-specific +handling for both eICU- and MIMIC-IV-style inputs. + +Overview: + The test suite checks: + + 1. 
Formal task schema:
+        - ``time_series`` is declared as a tensor input
+        - ``static`` is declared as a tensor input
+        - ``target_los_hours`` is declared as a regression output
+
+    2. Hourly time-series construction:
+        - Latest observation within each hour is retained
+        - Forward-filling of missing values
+        - Correct decay feature computation (0.75 ** j)
+
+    3. Pre-ICU handling:
+        - Inclusion of pre-ICU observations during processing
+        - Proper cropping of pre-ICU rows after feature construction
+
+    4. Sample generation:
+        - eICU-style patient processing (offset-based timestamps)
+        - MIMIC-IV-style patient processing (datetime-based timestamps)
+        - Presence and correctness of expected output fields
+
+Implementation Notes:
+    - Tests use lightweight synthetic data for speed and reproducibility.
+    - No dependency on real eICU or MIMIC-IV datasets.
+    - Designed to validate core preprocessing logic independent of model code.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+
+from pyhealth.tasks.hourly_los import HourlyLOSEICU
+
+
+class DummyEvent:
+    """Minimal event object used for task unit tests."""
+
+    def __init__(self, attr_dict, timestamp=None):
+        """Initialize the dummy event.
+
+        Args:
+            attr_dict: Event attribute dictionary.
+            timestamp: Optional event timestamp.
+        """
+        self.attr_dict = attr_dict
+        self.timestamp = timestamp
+
+
+class DummyPatient:
+    """Minimal patient object used for task unit tests."""
+
+    def __init__(self, patient_id, tables):
+        """Initialize the dummy patient.
+
+        Args:
+            patient_id: Patient identifier.
+            tables: Mapping from table name to event list.
+        """
+        self.patient_id = patient_id
+        self.tables = tables
+
+    def get_events(self, table):
+        """Return events for the requested table.
+
+        Args:
+            table: Table name.
+
+        Returns:
+            List of events for that table.
+ """ + return self.tables.get(table, []) + + +def test_hourly_los_declares_expected_schema() -> None: + """Test that the task declares the expected BaseModel-facing schema.""" + task = HourlyLOSEICU( + time_series_tables=["lab"], + time_series_features={"lab": ["creatinine"]}, + ) + + assert task.input_schema == { + "time_series": "tensor", + "static": "tensor", + } + assert task.output_schema == { + "target_los_hours": "regression", + } + + +def test_make_hourly_tensor_keeps_latest_and_forward_fills() -> None: + """Test that the latest within-hour measurement is kept and gaps are filled.""" + task = HourlyLOSEICU( + time_series_tables=["lab"], + time_series_features={"lab": ["creatinine"]}, + min_history_hours=1, + max_hours=6, + ) + + observations = [ + (0, 0, 1.0, 0.1), + (0, 0, 2.0, 0.9), # later in same hour, should win + (2, 0, 3.0, 2.2), + ] + + time_series = task._make_hourly_tensor( + observations=observations, + usable_hours=4, + num_features=1, + ) + + assert time_series[0][0] == 2.0 + assert time_series[0][1] == 1.0 + assert time_series[1][0] == 2.0 + assert time_series[1][1] == 0.0 + assert time_series[2][0] == 3.0 + assert time_series[2][1] == 1.0 + + +def test_make_hourly_tensor_decay_behavior() -> None: + """Test that decay follows the expected ``0.75 ** j`` rule.""" + task = HourlyLOSEICU( + time_series_tables=["lab"], + time_series_features={"lab": ["creatinine"]}, + min_history_hours=1, + max_hours=6, + ) + + observations = [ + (0, 0, 5.0, 0.2), + ] + + time_series = task._make_hourly_tensor( + observations=observations, + usable_hours=4, + num_features=1, + ) + + assert time_series[0][2] == 1.0 + assert abs(time_series[1][2] - 0.75) < 1e-6 + assert abs(time_series[2][2] - (0.75 ** 2)) < 1e-6 + assert abs(time_series[3][2] - (0.75 ** 3)) < 1e-6 + + +def test_cropped_hourly_tensor_removes_pre_icu_rows() -> None: + """Test that pre-ICU rows are removed after extended-timeline fill.""" + task = HourlyLOSEICU( + time_series_tables=["lab"], + 
time_series_features={"lab": ["creatinine"]}, + min_history_hours=1, + max_hours=6, + pre_icu_hours=2, + ) + + observations = [ + (0, 0, 10.0, -2.0), + (1, 0, 11.0, -1.0), + (2, 0, 12.0, 0.0), + ] + + time_series = task._make_cropped_hourly_tensor( + observations=observations, + total_hours=3.0, + num_features=1, + ) + + assert len(time_series) == 3 + assert time_series[0][0] == 12.0 + + +def test_eicu_patient_generates_samples() -> None: + """Test eICU-style patient sample generation.""" + task = HourlyLOSEICU( + time_series_tables=["lab"], + time_series_features={"lab": ["creatinine"]}, + numeric_static_features=["age"], + categorical_static_features=["gender"], + min_history_hours=2, + max_hours=6, + ) + + patient_event = DummyEvent( + { + "patientunitstayid": "stay1", + "unitdischargeoffset": 240.0, + "age": 65, + "gender": "Male", + "hospitaladmittime24": "08:00:00", + } + ) + + lab_events = [ + DummyEvent( + { + "labname": "creatinine", + "labresult": 1.2, + "labresultoffset": 0.0, + } + ), + DummyEvent( + { + "labname": "creatinine", + "labresult": 1.4, + "labresultoffset": 120.0, + } + ), + ] + + patient = DummyPatient( + patient_id="p1", + tables={ + "patient": [patient_event], + "lab": lab_events, + }, + ) + + samples = task(patient) + + assert len(samples) > 0 + sample = samples[0] + + assert "time_series" in sample + assert "static" in sample + assert "target_los_hours" in sample + assert "target_los_sequence" in sample + assert isinstance(sample["time_series"], list) + assert isinstance(sample["static"], list) + assert isinstance(sample["target_los_hours"], float) + assert isinstance(sample["target_los_sequence"], list) + assert len(sample["target_los_sequence"]) == sample["history_hours"] + + +def test_mimic_patient_generates_samples() -> None: + """Test MIMIC-IV-style patient sample generation.""" + task = HourlyLOSEICU( + time_series_tables=["labevents"], + time_series_features={"labevents": ["creatinine"]}, + min_history_hours=2, + max_hours=6, + 
pre_icu_hours=2, + ) + + intime = datetime(2020, 1, 1, 10, 0, 0) + outtime = intime + timedelta(hours=4) + + icu_event = DummyEvent( + { + "hadm_id": "hadm1", + "stay_id": "stay1", + "outtime": outtime.isoformat(), + }, + timestamp=intime, + ) + + lab_event = DummyEvent( + { + "hadm_id": "hadm1", + "label": "creatinine", + "valuenum": 1.5, + }, + timestamp=intime + timedelta(hours=1), + ) + + patient = DummyPatient( + patient_id="p2", + tables={ + "patients": [DummyEvent({})], + "admissions": [DummyEvent({"hadm_id": "hadm1"})], + "icustays": [icu_event], + "labevents": [lab_event], + }, + ) + + samples = task(patient) + + assert len(samples) > 0 + sample = samples[0] + assert "time_series" in sample + assert "static" in sample + assert "target_los_hours" in sample + assert isinstance(sample["target_los_hours"], float) \ No newline at end of file