diff --git a/.gitignore b/.gitignore
index 9993737db..d7010da78 100644
--- a/.gitignore
+++ b/.gitignore
@@ -139,4 +139,5 @@ data/physionet.org/
.vscode/
# Model weight files (large binaries, distributed separately)
-weightfiles/
\ No newline at end of file
+weightfiles/
+examples/run_h1_h4_eicu.py
diff --git a/README_TPC_LOS.md b/README_TPC_LOS.md
new file mode 100644
index 000000000..b6d42beb7
--- /dev/null
+++ b/README_TPC_LOS.md
@@ -0,0 +1,215 @@
+# CS598 Deep Learning for Healthcare Project
+
+This PR reproduces and contributes an implementation of:
+
+**Temporal Pointwise Convolutional Networks for Length of Stay Prediction in the Intensive Care Unit
+Rocheteau et al., ACM CHIL 2021**
+Paper: https://arxiv.org/abs/2007.09483
+
+
+## Contributors
+
+* Michael Edukonis (`meduk2@illinois.edu`)
+* Keon Young Lee (`kylee7@illinois.edu`)
+* Tanmay Thareja (`tanmayt3@illinois.edu`)
+
+# PR Overview
+
+## Contribution Types
+
+This PR includes:
+
+- Model contribution
+- Standalone Task contribution
+- Synthetic data tests
+- End-to-end pipeline (model + task + dataset configs) example scripts, plus a combined dual-dataset run
+
+## Problem Overview
+
+
+Efficient ICU bed management depends critically on estimating how long a patient will remain in the ICU.
+
+This is formulated as:
+
+- Input: patient data up to hour _t_
+- Output: remaining ICU length of stay (LoS) at time _t_
+
+We follow the formulation in the original paper, predicting remaining LoS at each hour of the ICU stay.
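
As a minimal sketch of this labeling scheme (hypothetical helper, not the repository's actual preprocessing), each hour from the first prediction point onward gets a target equal to the time left in the stay:

```python
# Hypothetical illustration of the hourly labeling scheme: given a total ICU
# length of stay, the regression target at hour t is the time remaining.
def remaining_los_targets(total_los_hours, min_history_hours=5):
    """Return (hour, remaining_los) pairs for each prediction hour."""
    return [
        (t, total_los_hours - t)
        for t in range(min_history_hours, int(total_los_hours))
    ]

# A 10-hour stay yields one target per hour from hour 5 through hour 9.
print(remaining_los_targets(10))  # [(5, 5), (6, 4), (7, 3), (8, 2), (9, 1)]
```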
+
+## Implementation Details
+
+**1) Model Contribution**
+
+_pyhealth.models.tpc_
+
+We implement the Temporal Pointwise Convolution (TPC) model as a PyHealth-compatible model that extends `BaseModel` and follows the original paper's architecture, with adaptations for PyHealth's input/output interfaces. Index files were updated to include the new module.
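
As context for the model's input contract, here is a minimal sketch (illustrative only, based on the shape checks in the example scripts rather than the model code itself): each hourly time-series row concatenates per-feature `[value, mask, decay]` channels, so the feature count can be recovered from the row width.

```python
# Each time-series row has width 3F: per-feature [value, mask, decay]
# channels, mirroring the divisible-by-3 check in the example scripts.
def infer_input_dim(row_width):
    if row_width % 3 != 0:
        raise ValueError(
            "expected row width divisible by 3 for [value, mask, decay]"
        )
    return row_width // 3

print(infer_input_dim(138))  # 46 features
```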
+
+**2) Task Contribution**
+
+We implement a custom PyHealth task, Hourly Remaining Length of Stay. Index files were updated to include the new module.
+
+_pyhealth.tasks.hourly_los_
+
+Task Definition
+
+- Predict remaining LoS at every hour
+- Predictions start after the first 5 hours
+- Continuous regression task
+
+Motivation
+
+- Mimics real-world ICU monitoring
+- Enables dynamic prediction updates
+
+**3) Ablation Study / Example Usage**
+
+We implemented scripts for running the pipeline end to end, with support for different experimental setups and ablations.
+
+_examples/eicu_hourly_los_tpc.py_: An end-to-end script for reproducing and evaluating the pipeline on the eICU dataset.
+_examples/mimic4_hourly_los_tpc.py_: An end-to-end script for reproducing and evaluating the pipeline on the MIMIC-IV dataset.
+_examples/run_dual_dataset_tpc.py_: A utility script that runs the pipeline on both datasets and produces a combined report.
+
+**4) Test Cases**
+
+We implemented fast-running test cases for our model and task contributions using synthetic data.
+
+_tests/models/test_tpc.py_
+_tests/tasks/test_hourly_los.py_
+
+
+## Experimental Setup and Findings
+
+1) Ablations that include/exclude domain-specific engineering (skip connections, decay indicators, etc.)
+
+
+
+2) Comparison of combined temporal and pointwise convolutions vs. either architecture alone.
+
+
+
+3) Feature-independent (no weight sharing) vs. weight-shared temporal convolutions.
+
+
+
+4) Evaluation of MSLE loss vs. MSE loss for the skewed LoS regression target.
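
To make the loss comparison concrete, here is a minimal pure-Python sketch (illustrative only, not the model's loss code) of why MSLE suits skewed LoS targets: it measures error in log space, so the same absolute error counts far more on a short stay than on a long one.

```python
import math

def mse(y_pred, y_true):
    # Mean squared error in the original (hours) scale.
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)

def msle(y_pred, y_true):
    # Mean squared logarithmic error: compares log1p-transformed values.
    return sum(
        (math.log1p(p) - math.log1p(t)) ** 2 for p, t in zip(y_pred, y_true)
    ) / len(y_true)

# The same 2-hour error on a short stay vs. a long stay:
print(mse([4.0], [2.0]), mse([102.0], [100.0]))    # identical under MSE
print(msle([4.0], [2.0]), msle([102.0], [100.0]))  # MSLE penalizes the short stay more
```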
+
+
+
+## Testing the model with varying hyperparameters
+
+We varied key optimization and architectural hyperparameters (e.g., learning rate and dropout rate) while keeping preprocessing and data splits fixed.
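
A sweep of this kind can be sketched as follows (hypothetical driver, not included in the PR; the grid values shown are illustrative): it simply enumerates command lines for the eICU example script over a small hyperparameter grid.

```python
import itertools

# Illustrative grid; the actual sweep values are a choice, not fixed.
learning_rates = [1e-3, 1e-4]
dropouts = [0.1, 0.3]

commands = [
    f"python3 examples/eicu_hourly_los_tpc.py --lr {lr} --dropout {dp} --epochs 5"
    for lr, dp in itertools.product(learning_rates, dropouts)
]
for cmd in commands:
    print(cmd)
```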
+
+
+
+
+## File Structure
+
+```text
+.
+├── pyhealth/models/tpc.py
+├── pyhealth/tasks/hourly_los.py
+├── pyhealth/datasets/configs/eicu_tpc.yaml
+├── pyhealth/datasets/configs/mimic4_ehr_tpc.yaml
+├── examples/eicu_hourly_los_tpc.py
+├── examples/mimic4_hourly_los_tpc.py
+├── examples/run_dual_dataset_tpc.py
+├── tests/models/test_tpc.py
+├── tests/tasks/test_hourly_los.py
+├── docs/api/models.rst
+├── docs/api/models/pyhealth.models.tpc.rst
+├── docs/api/tasks.rst
+├── docs/api/tasks/pyhealth.tasks.hourly_los.rst
+└── README_TPC_LOS.md
+```
+
+## Setup
+
+1. Clone the repository.
+2. Create and activate a python virtual environment.
+3. Install project dependencies.
+4. Set dataset paths with environment variables for the example scripts.
+
+## Quick Start (Synthetic Data)
+
+### eICU example
+
+```bash
+EICU_ROOT=/path/to/synthetic/eicu/data \
+PYTHONPATH=. python3 examples/eicu_hourly_los_tpc.py \
+ --epochs 1 \
+ --batch_size 2 \
+ --max_samples 8 \
+ --model_variant full \
+ --loss msle \
+ --num_workers 1
+```
+
+### MIMIC-IV example
+
+```bash
+MIMIC4_ROOT=/path/to/synthetic/mimic4/data \
+PYTHONPATH=. python3 examples/mimic4_hourly_los_tpc.py \
+ --epochs 1 \
+ --batch_size 2 \
+ --max_samples 8 \
+ --loss msle \
+ --num_workers 1
+```
+
+### Combined dual-dataset run
+
+```bash
+EICU_ROOT=/path/to/synthetic/eicu/data \
+MIMIC4_ROOT=/path/to/synthetic/mimic4/data \
+PYTHONPATH=. python3 examples/run_dual_dataset_tpc.py \
+    --eicu_cache_dir /path/to/eicu/cache/location \
+    --mimic_cache_dir /path/to/mimic/cache/location \
+ --eicu_max_samples 8 \
+ --mimic_max_samples 8 \
+ --eicu_epochs 1 \
+ --mimic_epochs 1 \
+ --model_variant full \
+ --loss msle \
+ --num_workers 1
+```
+
+## Notes on Real Data
+
+For full eICU or MIMIC-IV experiments, point `EICU_ROOT` or `MIMIC4_ROOT` to the real dataset locations and increase settings such as `--epochs`, `--batch_size`, and `--max_samples` as needed.
+
+## Tests
+
+Run the project-specific tests with:
+
+```bash
+python3 -m pytest tests/models/test_tpc.py tests/tasks/test_hourly_los.py -q
+```
+
+## Documentation
+
+API documentation entries were added for:
+
+* `pyhealth.models.tpc`
+* `pyhealth.tasks.hourly_los`
+
+## Output
+
+The example scripts print compact summary lines for quick validation:
+
+* `ABLATION_SUMMARY` for eICU
+* `MIMIC_SUMMARY` for MIMIC-IV
+
+The dual-dataset runner parses both and prints a combined summary table.
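
The summary lines are plain `key=value` pairs after a prefix token, so they can be parsed in a few lines of Python (hypothetical helper shown here; the dual-dataset runner has its own parsing):

```python
def parse_summary(line):
    # Split "PREFIX k1=v1 k2=v2 ..." into a dict keyed by field name.
    prefix, _, rest = line.partition(" ")
    fields = dict(kv.split("=", 1) for kv in rest.split())
    return {"source": prefix, **fields}

line = "MIMIC_SUMMARY loss=msle val_loss=0.4321 test_loss=0.4567"
print(parse_summary(line))
```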
+
+## Environment
+
+This project is designed to run within the PyHealth environment.
+
+Recommended setup:
+
+- Python 3.12
+- pyhealth (>= 2.0.0)
+- torch
+- standard scientific Python stack (numpy, pandas)
+
+Install PyHealth and dependencies following the main repository instructions.
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7c3ac7c4b..6d4932477 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -206,3 +206,4 @@ API Reference
models/pyhealth.models.BIOT
models/pyhealth.models.unified_multimodal_embedding_docs
models/pyhealth.models.califorest
+ models/pyhealth.models.tpc
diff --git a/docs/api/models/pyhealth.models.tpc.rst b/docs/api/models/pyhealth.models.tpc.rst
new file mode 100644
index 000000000..a16dfea7e
--- /dev/null
+++ b/docs/api/models/pyhealth.models.tpc.rst
@@ -0,0 +1,7 @@
+TPC Model
+=========
+
+.. automodule:: pyhealth.models.tpc
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 23a4e06e5..61ee76504 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -213,6 +213,7 @@ Available Tasks
DKA Prediction (MIMIC-IV)
Drug Recommendation
Length of Stay Prediction
+ Hourly Length-of-Stay (TPC)
Medical Transcriptions Classification
Mortality Prediction (Next Visit)
Mortality Prediction (StageNet MIMIC-IV)
diff --git a/docs/api/tasks/pyhealth.tasks.hourly_los.rst b/docs/api/tasks/pyhealth.tasks.hourly_los.rst
new file mode 100644
index 000000000..7ac100c10
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.hourly_los.rst
@@ -0,0 +1,7 @@
+Hourly Length-of-Stay Task
+==========================
+
+.. automodule:: pyhealth.tasks.hourly_los
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/eicu_hourly_los_tpc.py b/examples/eicu_hourly_los_tpc.py
new file mode 100644
index 000000000..9fd6648b8
--- /dev/null
+++ b/examples/eicu_hourly_los_tpc.py
@@ -0,0 +1,522 @@
+"""
+eICU hourly remaining length-of-stay (LoS) training and ablation script
+for the Temporal Pointwise Convolution (TPC) model.
+
+This script runs the TPC model through the true PyHealth ``BaseModel``
+pipeline:
+
+ 1. Load an eICU base dataset with the custom YAML config.
+ 2. Convert it to a task-specific ``SampleDataset`` using ``HourlyLOSEICU``.
+ 3. Split the task dataset by patient.
+ 4. Create dataloaders with PyHealth's ``get_dataloader``.
+ 5. Instantiate ``TPC(dataset=task_dataset, ...)``.
+ 6. Train using ``pyhealth.trainer.Trainer``.
+ 7. Evaluate scalar regression loss, MAE, and RMSE.
+
+This file is intentionally designed to work with synthetic, dev, or real
+eICU roots. For project verification and fast iteration, it should be run
+against synthetic data using small sample caps.
+
+Example:
+ >>> EICU_ROOT=/path/to/synthetic/eicu_demo PYTHONPATH=. \\
+ ... python3 examples/eicu_hourly_los_tpc.py \\
+ ... --epochs 1 \\
+ ... --batch_size 2 \\
+ ... --max_samples 24
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import random
+import sys
+from typing import Dict
+
+import numpy as np
+import torch
+
+DEFAULT_EICU_ROOT = os.environ.get("EICU_ROOT", "YOUR_EICU_DATASET_PATH")
+
+REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+if REPO_ROOT not in sys.path:
+ sys.path.insert(0, REPO_ROOT)
+
+from pyhealth.datasets import eICUDataset, get_dataloader, split_by_patient
+from pyhealth.models.tpc import TPC
+from pyhealth.tasks.hourly_los import HourlyLOSEICU
+from pyhealth.trainer import Trainer
+
+
+def set_seed(seed: int = 42) -> None:
+ """Set deterministic seeds for Python, NumPy, and PyTorch.
+
+ Args:
+ seed: Seed value applied to Python, NumPy, and PyTorch.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+
+def infer_model_dims(task_dataset) -> tuple[int, int]:
+ """Infer ``input_dim`` and ``static_dim`` from the processed task dataset.
+
+ Args:
+ task_dataset: PyHealth ``SampleDataset`` returned by ``set_task()``.
+
+ Returns:
+ A tuple ``(input_dim, static_dim)``.
+
+ Raises:
+ ValueError: If the first sample does not contain a valid
+ ``time_series`` or ``static`` field.
+ """
+ first_sample = task_dataset[0]
+
+ if "time_series" not in first_sample:
+ raise ValueError("Task sample is missing required field 'time_series'")
+ if "static" not in first_sample:
+ raise ValueError("Task sample is missing required field 'static'")
+
+ time_series = first_sample["time_series"]
+ static = first_sample["static"]
+
+ if not isinstance(time_series, torch.Tensor):
+ time_series = torch.tensor(time_series, dtype=torch.float32)
+ if not isinstance(static, torch.Tensor):
+ static = torch.tensor(static, dtype=torch.float32)
+
+ if time_series.ndim != 2:
+ raise ValueError(
+ "Expected task sample 'time_series' to have shape [T, 3F], got "
+ f"{tuple(time_series.shape)}"
+ )
+
+ feature_dim = time_series.shape[1]
+ if feature_dim % 3 != 0:
+ raise ValueError(
+ "Expected 'time_series' last dimension divisible by 3 for "
+ f"[value, mask, decay], got {feature_dim}"
+ )
+
+ input_dim = feature_dim // 3
+ static_dim = int(static.shape[0])
+
+ return input_dim, static_dim
+
+
+def evaluate_regression(model: TPC, dataloader) -> Dict[str, float]:
+ """Evaluate scalar regression loss, MAE, and RMSE.
+
+ Args:
+ model: Trained TPC model.
+ dataloader: Evaluation dataloader.
+
+ Returns:
+ Dictionary containing ``loss``, ``mae``, and ``rmse``.
+ """
+ model.eval()
+
+ total_loss = 0.0
+ total_abs_error = 0.0
+ total_sq_error = 0.0
+ total_count = 0
+
+ with torch.no_grad():
+ for batch in dataloader:
+ outputs = model(**batch)
+
+ batch_loss = outputs["loss"].item()
+ y_pred = outputs["y_prob"].detach().cpu().view(-1)
+ y_true = outputs["y_true"].detach().cpu().view(-1)
+
+ total_loss += batch_loss
+ total_abs_error += torch.sum(torch.abs(y_pred - y_true)).item()
+ total_sq_error += torch.sum((y_pred - y_true) ** 2).item()
+ total_count += int(y_true.numel())
+
+ mean_loss = total_loss / max(len(dataloader), 1)
+ mae = total_abs_error / max(total_count, 1)
+ rmse = (total_sq_error / max(total_count, 1)) ** 0.5
+
+ return {
+ "loss": mean_loss,
+ "mae": mae,
+ "rmse": rmse,
+ }
+
+
+def parse_args():
+ """Parse command-line arguments for the eICU TPC training script.
+
+ Returns:
+ Parsed argparse namespace.
+ """
+ parser = argparse.ArgumentParser(
+ description="Run eICU hourly remaining LoS prediction with TPC."
+ )
+ parser.add_argument(
+ "--root",
+ type=str,
+ default=DEFAULT_EICU_ROOT,
+ help=(
+ "Path to eICU dataset root. Defaults to the EICU_ROOT "
+ "environment variable when set."
+ ),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="Path to the PyHealth cache directory."
+ )
+ parser.add_argument("--dev", action="store_true", help="Use dev mode dataset")
+ parser.add_argument("--epochs", type=int, default=5)
+ parser.add_argument("--batch_size", type=int, default=8)
+ parser.add_argument(
+ "--max_samples",
+ type=int,
+ default=128,
+ help=(
+ "Approximate total sample cap across train/val/test splits for "
+ "fast synthetic or smoke-style runs."
+ ),
+ )
+ parser.add_argument("--num_layers", type=int, default=2)
+ parser.add_argument("--temporal_channels", type=int, default=4)
+ parser.add_argument("--pointwise_channels", type=int, default=4)
+ parser.add_argument("--kernel_size", type=int, default=3)
+ parser.add_argument("--fc_dim", type=int, default=16)
+ parser.add_argument("--dropout", type=float, default=0.1)
+ parser.add_argument("--lr", type=float, default=1e-3)
+ parser.add_argument(
+ "--loss",
+ type=str,
+ choices=["msle", "mse"],
+ default="msle",
+ help="Loss function for H3 ablation (default: msle).",
+ )
+ parser.add_argument(
+ "--model_variant",
+ type=str,
+ choices=["full", "temporal_only", "pointwise_only"],
+ default="full",
+ help="Architectural ablation: full TPC, temporal only, or pointwise only.",
+ )
+ parser.add_argument(
+ "--shared_temporal",
+ action="store_true",
+ help="Use shared temporal convolution weights across features.",
+ )
+ parser.add_argument(
+ "--no_skip_connections",
+ action="store_true",
+ help="Disable concatenative skip connections.",
+ )
+ parser.add_argument("--seed", type=int, default=42)
+
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="Number of CPU workers for task transformations"
+ )
+ return parser.parse_args()
+
+
+def main():
+ """Run the full eICU hourly LoS training and evaluation pipeline."""
+ args = parse_args()
+ set_seed(args.seed)
+
+ print("=" * 80)
+ print("TPC eICU Hourly LoS Run Configuration")
+ print("=" * 80)
+ print(f"root: {args.root}")
+ print(f"dev: {args.dev}")
+ print(f"epochs: {args.epochs}")
+ print(f"batch_size: {args.batch_size}")
+ print(f"num_workers: {args.num_workers}")
+ print(f"max_samples: {args.max_samples}")
+ print(f"num_layers: {args.num_layers}")
+ print(f"temporal_channels: {args.temporal_channels}")
+ print(f"pointwise_channels: {args.pointwise_channels}")
+ print(f"kernel_size: {args.kernel_size}")
+ print(f"fc_dim: {args.fc_dim}")
+ print(f"dropout: {args.dropout}")
+ print(f"lr: {args.lr}")
+ print(f"loss: {args.loss}")
+ print(f"model_variant: {args.model_variant}")
+ print(f"shared_temporal: {args.shared_temporal}")
+ print(f"no_skip_connections: {args.no_skip_connections}")
+ print(f"seed: {args.seed}")
+ print("=" * 80)
+
+ base_dataset = eICUDataset(
+ root=args.root,
+ tables=[
+ "patient",
+ "lab",
+ "respiratorycharting",
+ "nursecharting",
+ "vitalperiodic",
+ "vitalaperiodic",
+ "pasthistory",
+ "admissiondx",
+ "diagnosis",
+ ],
+ dev=args.dev,
+ cache_dir=args.cache_dir,
+ config_path=os.path.join(
+ REPO_ROOT,
+ "pyhealth/datasets/configs/eicu_tpc.yaml",
+ ),
+ )
+
+ task_dataset = base_dataset.set_task(
+ HourlyLOSEICU(
+ time_series_tables=[
+ "lab",
+ "respiratorycharting",
+ "nursecharting",
+ "vitalperiodic",
+ "vitalaperiodic",
+ ],
+ time_series_features={
+ "lab": [
+ "-basos",
+ "-eos",
+ "-lymphs",
+ "-monos",
+ "-polys",
+ "ALT (SGPT)",
+ "AST (SGOT)",
+ "BUN",
+ "Base Excess",
+ "FiO2",
+ "HCO3",
+ "Hct",
+ "Hgb",
+ "MCH",
+ "MCHC",
+ "MCV",
+ "MPV",
+ "O2 Sat (%)",
+ "PT",
+ "PT - INR",
+ "PTT",
+ "RBC",
+ "RDW",
+ "WBC x 1000",
+ "albumin",
+ "alkaline phos.",
+ "anion gap",
+ "bedside glucose",
+ "bicarbonate",
+ "calcium",
+ "chloride",
+ "creatinine",
+ "glucose",
+ "lactate",
+ "magnesium",
+ "pH",
+ "paCO2",
+ "paO2",
+ "phosphate",
+ "platelets x 1000",
+ "potassium",
+ "sodium",
+ "total bilirubin",
+ "total protein",
+ "troponin - I",
+ "urinary specific gravity",
+ ],
+ "respiratorycharting": [
+ "Exhaled MV",
+ "Exhaled TV (patient)",
+ "LPM O2",
+ "Mean Airway Pressure",
+ "Peak Insp. Pressure",
+ "PEEP",
+ "Plateau Pressure",
+ "Pressure Support",
+ "RR (patient)",
+ "SaO2",
+ "TV/kg IBW",
+ "Tidal Volume (set)",
+ "Total RR",
+ "Vent Rate",
+ ],
+ "nursecharting": [
+ "Bedside Glucose",
+ "Delirium Scale/Score",
+ "Glasgow coma score",
+ "Heart Rate",
+ "Invasive BP",
+ "Non-Invasive BP",
+ "O2 Admin Device",
+ "O2 L/%",
+ "O2 Saturation",
+ "Pain Score/Goal",
+ "Respiratory Rate",
+ "Sedation Score/Goal",
+ "Temperature",
+ ],
+ "vitalperiodic": [
+ "cvp",
+ "heartrate",
+ "respiration",
+ "sao2",
+ "st1",
+ "st2",
+ "st3",
+ "systemicdiastolic",
+ "systemicmean",
+ "systemicsystolic",
+ "temperature",
+ ],
+ "vitalaperiodic": [
+ "noninvasivediastolic",
+ "noninvasivemean",
+ "noninvasivesystolic",
+ ],
+ },
+ numeric_static_features=[
+ "age",
+ "admissionheight",
+ "admissionweight",
+ ],
+ # Keep the task contract stable for true BaseModel training.
+ categorical_static_features=[],
+ diagnosis_tables=[],
+ include_diagnoses=False,
+ diagnosis_time_limit_hours=5,
+ min_history_hours=5,
+ max_hours=48,
+ ),
+ num_workers=args.num_workers,
+ )
+
+ if len(task_dataset) == 0:
+ raise RuntimeError(
+ "No samples were generated by HourlyLOSEICU. "
+ "Check synthetic data, feature mappings, joins, cache, or dataset mode."
+ )
+
+ input_dim, static_dim = infer_model_dims(task_dataset)
+
+ print(f"task samples: {len(task_dataset)}")
+ print(f"model input_dim: {input_dim}")
+ print(f"model static_dim: {static_dim}")
+
+ # Use patient-level split when possible. For very small synthetic datasets,
+ # reuse one non-empty split for validation so Trainer can run.
+ num_samples = len(task_dataset)
+
+ if num_samples < 20:
+ train_ds, _, test_ds = split_by_patient(task_dataset, [0.5, 0.0, 0.5])
+ val_ds = test_ds
+ else:
+ train_ds, val_ds, test_ds = split_by_patient(task_dataset, [0.70, 0.15, 0.15])
+
+ if len(train_ds) == 0 or len(val_ds) == 0 or len(test_ds) == 0:
+ raise RuntimeError(
+ "Dataset split produced an empty train, validation, or test split."
+ )
+
+ print(f"train samples: {len(train_ds)}")
+ print(f"val samples: {len(val_ds)}")
+ print(f"test samples: {len(test_ds)}")
+
+    train_loader = get_dataloader(
+        train_ds,
+        batch_size=args.batch_size,
+        shuffle=True,
+    )
+    val_loader = get_dataloader(
+        val_ds,
+        batch_size=args.batch_size,
+        shuffle=False,
+    )
+    test_loader = get_dataloader(
+        test_ds,
+        batch_size=args.batch_size,
+        shuffle=False,
+    )
+
+ use_temporal = args.model_variant in {"full", "temporal_only"}
+ use_pointwise = args.model_variant in {"full", "pointwise_only"}
+ use_skip_connections = not args.no_skip_connections
+
+ model = TPC(
+ dataset=task_dataset,
+ input_dim=input_dim,
+ static_dim=static_dim,
+ temporal_channels=args.temporal_channels,
+ pointwise_channels=args.pointwise_channels,
+ num_layers=args.num_layers,
+ kernel_size=args.kernel_size,
+ fc_dim=args.fc_dim,
+ dropout=args.dropout,
+ shared_temporal=args.shared_temporal,
+ use_temporal=use_temporal,
+ use_pointwise=use_pointwise,
+ use_skip_connections=use_skip_connections,
+ loss_name=args.loss,
+ )
+
+ trainer = Trainer(
+ model=model,
+ device="cpu",
+ enable_logging=False,
+ )
+
+ trainer.train(
+ train_dataloader=train_loader,
+ val_dataloader=val_loader,
+ test_dataloader=test_loader,
+ epochs=args.epochs,
+ optimizer_params={"lr": args.lr},
+ )
+
+ val_results = evaluate_regression(model, val_loader)
+ test_results = evaluate_regression(model, test_loader)
+
+ print("=" * 80)
+ print("Run complete")
+ print(f"model_variant: {args.model_variant}")
+ print(f"shared_temporal: {args.shared_temporal}")
+ print(f"no_skip_connections: {args.no_skip_connections}")
+ print(f"loss: {args.loss}")
+ print(f"val_loss: {val_results['loss']:.4f}")
+ print(f"val_mae: {val_results['mae']:.4f}")
+ print(f"val_rmse: {val_results['rmse']:.4f}")
+ print(f"test_loss: {test_results['loss']:.4f}")
+ print(f"test_mae: {test_results['mae']:.4f}")
+ print(f"test_rmse: {test_results['rmse']:.4f}")
+ print(
+ "ABLATION_SUMMARY "
+ f"model_variant={args.model_variant} "
+ f"shared_temporal={args.shared_temporal} "
+ f"no_skip_connections={args.no_skip_connections} "
+ f"loss={args.loss} "
+ f"val_loss={val_results['loss']:.4f} "
+ f"val_mae={val_results['mae']:.4f} "
+ f"val_rmse={val_results['rmse']:.4f} "
+ f"test_loss={test_results['loss']:.4f} "
+ f"test_mae={test_results['mae']:.4f} "
+ f"test_rmse={test_results['rmse']:.4f}"
+ )
+ print("=" * 80)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/mimic4_hourly_los_tpc.py b/examples/mimic4_hourly_los_tpc.py
new file mode 100644
index 000000000..84cb66cb1
--- /dev/null
+++ b/examples/mimic4_hourly_los_tpc.py
@@ -0,0 +1,475 @@
+"""
+MIMIC-IV hourly remaining length-of-stay (LoS) training and evaluation script
+for the Temporal Pointwise Convolution (TPC) model.
+
+This script runs the TPC model through the true PyHealth ``BaseModel``
+pipeline:
+
+ 1. Load a MIMIC-IV base dataset with the custom YAML config.
+ 2. Convert it to a task-specific ``SampleDataset`` using ``HourlyLOSEICU``.
+ 3. Split the task dataset by patient.
+ 4. Create dataloaders with PyHealth's ``get_dataloader``.
+ 5. Instantiate ``TPC(dataset=task_dataset, ...)``.
+ 6. Train using ``pyhealth.trainer.Trainer``.
+ 7. Evaluate scalar regression loss, MAE, and RMSE.
+
+This file is intended to work with synthetic, dev, or real MIMIC-IV roots.
+For project verification and fast iteration, it should be run against
+synthetic data using a small sample cap.
+
+Example:
+ >>> MIMIC4_ROOT=/path/to/synthetic/mimic4_demo PYTHONPATH=. \\
+ ... python3 examples/mimic4_hourly_los_tpc.py \\
+ ... --epochs 1 \\
+ ... --batch_size 2 \\
+ ... --max_samples 24
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import random
+import sys
+from typing import Dict
+
+import numpy as np
+import torch
+
+DEFAULT_MIMIC4_ROOT = os.environ.get("MIMIC4_ROOT", "YOUR_MIMIC4_DATASET_PATH")
+
+REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+if REPO_ROOT not in sys.path:
+ sys.path.insert(0, REPO_ROOT)
+
+from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient
+from pyhealth.models.tpc import TPC
+from pyhealth.tasks.hourly_los import HourlyLOSEICU
+from pyhealth.trainer import Trainer
+
+
+def set_seed(seed: int = 42) -> None:
+ """Set deterministic seeds for Python, NumPy, and PyTorch.
+
+ Args:
+ seed: Seed value applied to Python, NumPy, and PyTorch.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+
+def infer_model_dims(task_dataset) -> tuple[int, int]:
+ """Infer ``input_dim`` and ``static_dim`` from the processed task dataset.
+
+ Args:
+ task_dataset: PyHealth ``SampleDataset`` returned by ``set_task()``.
+
+ Returns:
+ A tuple ``(input_dim, static_dim)``.
+
+ Raises:
+ ValueError: If the first sample does not contain a valid
+ ``time_series`` or ``static`` field.
+ """
+ first_sample = task_dataset[0]
+
+ if "time_series" not in first_sample:
+ raise ValueError("Task sample is missing required field 'time_series'")
+ if "static" not in first_sample:
+ raise ValueError("Task sample is missing required field 'static'")
+
+ time_series = first_sample["time_series"]
+ static = first_sample["static"]
+
+ if not isinstance(time_series, torch.Tensor):
+ time_series = torch.tensor(time_series, dtype=torch.float32)
+ if not isinstance(static, torch.Tensor):
+ static = torch.tensor(static, dtype=torch.float32)
+
+ if time_series.ndim != 2:
+ raise ValueError(
+ "Expected task sample 'time_series' to have shape [T, 3F], got "
+ f"{tuple(time_series.shape)}"
+ )
+
+ feature_dim = time_series.shape[1]
+ if feature_dim % 3 != 0:
+ raise ValueError(
+ "Expected 'time_series' last dimension divisible by 3 for "
+ f"[value, mask, decay], got {feature_dim}"
+ )
+
+ input_dim = feature_dim // 3
+ static_dim = int(static.shape[0])
+
+ return input_dim, static_dim
+
+
+def evaluate_regression(model: TPC, dataloader) -> Dict[str, float]:
+ """Evaluate scalar regression loss, MAE, and RMSE.
+
+ Args:
+ model: Trained TPC model.
+ dataloader: Evaluation dataloader.
+
+ Returns:
+ Dictionary containing ``loss``, ``mae``, and ``rmse``.
+ """
+ model.eval()
+
+ total_loss = 0.0
+ total_abs_error = 0.0
+ total_sq_error = 0.0
+ total_count = 0
+
+ with torch.no_grad():
+ for batch in dataloader:
+ outputs = model(**batch)
+
+ batch_loss = outputs["loss"].item()
+ y_pred = outputs["y_prob"].detach().cpu().view(-1)
+ y_true = outputs["y_true"].detach().cpu().view(-1)
+
+ total_loss += batch_loss
+ total_abs_error += torch.sum(torch.abs(y_pred - y_true)).item()
+ total_sq_error += torch.sum((y_pred - y_true) ** 2).item()
+ total_count += int(y_true.numel())
+
+ mean_loss = total_loss / max(len(dataloader), 1)
+ mae = total_abs_error / max(total_count, 1)
+ rmse = (total_sq_error / max(total_count, 1)) ** 0.5
+
+ return {
+ "loss": mean_loss,
+ "mae": mae,
+ "rmse": rmse,
+ }
+
+
+def parse_args():
+ """Parse command-line arguments for the MIMIC-IV TPC training script.
+
+ Returns:
+ Parsed argparse namespace.
+ """
+ parser = argparse.ArgumentParser(
+ description="Run MIMIC-IV hourly remaining LoS prediction with TPC."
+ )
+ parser.add_argument(
+ "--root",
+ type=str,
+ default=DEFAULT_MIMIC4_ROOT,
+ help=(
+ "Path to the MIMIC-IV dataset root (directory containing hosp/ and "
+ "icu/). Defaults to the MIMIC4_ROOT environment variable when set."
+ ),
+ )
+ parser.add_argument("--dev", action="store_true", help="Use dev mode dataset")
+ parser.add_argument("--epochs", type=int, default=5)
+ parser.add_argument("--batch_size", type=int, default=8)
+ parser.add_argument(
+ "--max_samples",
+ type=int,
+ default=128,
+ help=(
+ "Approximate total sample cap across train/val/test splits for "
+ "fast synthetic or smoke-style runs."
+ ),
+ )
+ parser.add_argument("--num_layers", type=int, default=2)
+ parser.add_argument("--temporal_channels", type=int, default=4)
+ parser.add_argument("--pointwise_channels", type=int, default=4)
+ parser.add_argument("--kernel_size", type=int, default=3)
+ parser.add_argument("--fc_dim", type=int, default=16)
+ parser.add_argument("--dropout", type=float, default=0.1)
+ parser.add_argument("--lr", type=float, default=1e-3)
+ parser.add_argument(
+ "--loss",
+ type=str,
+ choices=["msle", "mse"],
+ default="msle",
+ help="Loss function for H3 comparison (default: msle).",
+ )
+ parser.add_argument(
+ "--max_hours",
+ type=int,
+ default=336,
+ help="Maximum ICU hours to model (14 days = 336 hours).",
+ )
+ parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="Number of CPU workers for task transformations.",
+    )
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="Path to the PyHealth cache directory.",
+    )
+ return parser.parse_args()
+
+
+def main():
+ """Run the full MIMIC-IV hourly LoS training and evaluation pipeline."""
+ args = parse_args()
+ set_seed(args.seed)
+
+ print("=" * 80)
+ print("TPC MIMIC-IV Hourly LoS Run Configuration")
+ print("=" * 80)
+ print(f"root: {args.root}")
+ print(f"dev: {args.dev}")
+ print(f"epochs: {args.epochs}")
+ print(f"batch_size: {args.batch_size}")
+ print(f"num_workers: {args.num_workers}")
+ print(f"max_samples: {args.max_samples}")
+ print(f"num_layers: {args.num_layers}")
+ print(f"temporal_channels: {args.temporal_channels}")
+ print(f"pointwise_channels: {args.pointwise_channels}")
+ print(f"kernel_size: {args.kernel_size}")
+ print(f"fc_dim: {args.fc_dim}")
+ print(f"dropout: {args.dropout}")
+ print(f"lr: {args.lr}")
+ print(f"loss: {args.loss}")
+ print(f"max_hours: {args.max_hours}")
+ print(f"seed: {args.seed}")
+ print("=" * 80)
+
+ base_dataset = MIMIC4Dataset(
+ ehr_root=args.root,
+ ehr_tables=[
+ "patients",
+ "admissions",
+ "icustays",
+ "chartevents",
+ "labevents",
+ ],
+ ehr_config_path=os.path.join(
+ REPO_ROOT,
+ "pyhealth/datasets/configs/mimic4_ehr_tpc.yaml",
+ ),
+ dev=args.dev,
+        cache_dir=args.cache_dir,
+ )
+
+ task_dataset = base_dataset.set_task(
+ HourlyLOSEICU(
+ time_series_tables=["chartevents", "labevents"],
+ time_series_features={
+ "chartevents": [
+ "Activity / Mobility (JH-HLM)",
+ "Apnea Interval",
+ "Arterial Blood Pressure Alarm - High",
+ "Arterial Blood Pressure Alarm - Low",
+ "Arterial Blood Pressure diastolic",
+ "Arterial Blood Pressure mean",
+ "Arterial Blood Pressure systolic",
+ "Braden Score",
+ "Current Dyspnea Assessment",
+ "Daily Weight",
+ "Expiratory Ratio",
+ "Fspn High",
+ "GCS - Eye Opening",
+ "GCS - Motor Response",
+ "GCS - Verbal Response",
+ "Glucose finger stick (range 70-100)",
+ "Heart Rate",
+ "Heart Rate Alarm - Low",
+ "Heart rate Alarm - High",
+ "Inspired O2 Fraction",
+ "Mean Airway Pressure",
+ "Minute Volume",
+ "Minute Volume Alarm - High",
+ "Minute Volume Alarm - Low",
+ "Non Invasive Blood Pressure diastolic",
+ "Non Invasive Blood Pressure mean",
+ "Non Invasive Blood Pressure systolic",
+ "Non-Invasive Blood Pressure Alarm - High",
+ "Non-Invasive Blood Pressure Alarm - Low",
+ "O2 Flow",
+ "O2 Saturation Pulseoxymetry Alarm - Low",
+ "O2 saturation pulseoxymetry",
+ "PEEP set",
+ "PSV Level",
+ "Pain Level",
+ "Pain Level Response",
+ "Paw High",
+ "Peak Insp. Pressure",
+ "Phosphorous",
+ "Plateau Pressure",
+ "Resp Alarm - High",
+ "Resp Alarm - Low",
+ "Respiratory Rate",
+ "Respiratory Rate (Set)",
+ "Respiratory Rate (Total)",
+ "Respiratory Rate (spontaneous)",
+ "Richmond-RAS Scale",
+ "Strength L Arm",
+ "Strength L Leg",
+ "Strength R Arm",
+ "Strength R Leg",
+ "Temperature Fahrenheit",
+ "Tidal Volume (observed)",
+ "Tidal Volume (set)",
+ "Tidal Volume (spontaneous)",
+ "Total PEEP Level",
+ "Ventilator Mode",
+ "Vti High",
+ ],
+ "labevents": [
+ "Alanine Aminotransferase (ALT)",
+ "Alkaline Phosphatase",
+ "Anion Gap",
+ "Asparate Aminotransferase (AST)",
+ "Base Excess",
+ "Bicarbonate",
+ "Bilirubin, Total",
+ "Calcium, Total",
+ "Calculated Total CO2",
+ "Chloride",
+ "Creatinine",
+ "Free Calcium",
+ "Glucose",
+ "Hematocrit",
+ "Hematocrit, Calculated",
+ "Hemoglobin",
+ "INR(PT)",
+ "Lactate",
+ "MCH",
+ "MCHC",
+ "MCV",
+ "Magnesium",
+ "Oxygen Saturation",
+ "PT",
+ "PTT",
+ "Phosphate",
+ "Platelet Count",
+ "Potassium",
+ "Potassium, Whole Blood",
+ "RDW",
+ "RDW-SD",
+ "Red Blood Cells",
+ "Sodium",
+ "Sodium, Whole Blood",
+ "Temperature",
+ "Urea Nitrogen",
+ "White Blood Cells",
+ "pCO2",
+ "pH",
+ "pO2",
+ ],
+ },
+ numeric_static_features=[],
+ categorical_static_features=[],
+ min_history_hours=5,
+ max_hours=args.max_hours,
+ ),
+ num_workers=args.num_workers,
+ )
+
+ if len(task_dataset) == 0:
+ raise RuntimeError(
+ "No samples were generated by HourlyLOSEICU. "
+ "Check synthetic data, feature mappings, joins, cache, or dataset mode."
+ )
+
+ input_dim, static_dim = infer_model_dims(task_dataset)
+
+ print(f"task samples: {len(task_dataset)}")
+ print(f"model input_dim: {input_dim}")
+ print(f"model static_dim: {static_dim}")
+
+ num_samples = len(task_dataset)
+
+    if num_samples < 20:
+        # Tiny runs (e.g. synthetic or dev data): skip a separate validation
+        # split and reuse the test split so training still has a val loader.
+        train_ds, _, test_ds = split_by_patient(task_dataset, [0.5, 0.0, 0.5])
+        val_ds = test_ds
+ else:
+ train_ds, val_ds, test_ds = split_by_patient(task_dataset, [0.70, 0.15, 0.15])
+
+ if len(train_ds) == 0 or len(val_ds) == 0 or len(test_ds) == 0:
+ raise RuntimeError(
+ "Dataset split produced an empty train, validation, or test split."
+ )
+
+ print(f"train samples: {len(train_ds)}")
+ print(f"val samples: {len(val_ds)}")
+ print(f"test samples: {len(test_ds)}")
+
+    train_loader = get_dataloader(
+        train_ds,
+        batch_size=args.batch_size,
+        shuffle=True,
+    )
+    val_loader = get_dataloader(
+        val_ds,
+        batch_size=args.batch_size,
+        shuffle=False,
+    )
+    test_loader = get_dataloader(
+        test_ds,
+        batch_size=args.batch_size,
+        shuffle=False,
+    )
+
+ model = TPC(
+ dataset=task_dataset,
+ input_dim=input_dim,
+ static_dim=static_dim,
+ temporal_channels=args.temporal_channels,
+ pointwise_channels=args.pointwise_channels,
+ num_layers=args.num_layers,
+ kernel_size=args.kernel_size,
+ fc_dim=args.fc_dim,
+ dropout=args.dropout,
+ loss_name=args.loss,
+ )
+
+
+ trainer = Trainer(
+ model=model,
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ enable_logging=False,
+ )
+
+ trainer.train(
+ train_dataloader=train_loader,
+ val_dataloader=val_loader,
+ test_dataloader=test_loader,
+ epochs=args.epochs,
+ optimizer_params={"lr": args.lr},
+ )
+
+ val_results = evaluate_regression(model, val_loader)
+ test_results = evaluate_regression(model, test_loader)
+
+ print("=" * 80)
+ print("Run complete")
+ print(f"loss: {args.loss}")
+ print(f"val_loss: {val_results['loss']:.4f}")
+ print(f"val_mae: {val_results['mae']:.4f}")
+ print(f"val_rmse: {val_results['rmse']:.4f}")
+ print(f"test_loss: {test_results['loss']:.4f}")
+ print(f"test_mae: {test_results['mae']:.4f}")
+ print(f"test_rmse: {test_results['rmse']:.4f}")
+ print(
+ "MIMIC_SUMMARY "
+ f"loss={args.loss} "
+ f"val_loss={val_results['loss']:.4f} "
+ f"val_mae={val_results['mae']:.4f} "
+ f"val_rmse={val_results['rmse']:.4f} "
+ f"test_loss={test_results['loss']:.4f} "
+ f"test_mae={test_results['mae']:.4f} "
+ f"test_rmse={test_results['rmse']:.4f}"
+ )
+ print("=" * 80)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
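The example script reports loss, MAE, and RMSE on the validation and test splits. As a reference point, here is a minimal sketch of the metric definitions assumed for `evaluate_regression` (the actual helper may operate on tensors rather than lists); MSLE is the paper's preferred loss because remaining LoS is heavily right-skewed, so relative rather than absolute errors should be penalized:

```python
import math

def msle(y_true, y_pred):
    # Mean squared log error: log1p compresses long stays, so a 10-hour
    # miss on a 200-hour stay costs less than on a 20-hour stay.
    return sum((math.log1p(t) - math.log1p(p)) ** 2
               for t, p in zip(y_true, y_pred)) / len(y_true)

def mae(y_true, y_pred):
    # Mean absolute error in hours.
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def rmse(y_true, y_pred):
    # Root mean squared error in hours.
    return math.sqrt(
        sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)
    )

true_los = [12.0, 48.0, 240.0]  # remaining LoS in hours (illustrative)
pred_los = [10.0, 60.0, 200.0]
print(round(mae(true_los, pred_los), 2))  # 18.0
```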
diff --git a/examples/run_dual_dataset_tpc.py b/examples/run_dual_dataset_tpc.py
new file mode 100644
index 000000000..37f00009f
--- /dev/null
+++ b/examples/run_dual_dataset_tpc.py
@@ -0,0 +1,338 @@
+"""
+Run eICU and MIMIC-IV TPC example pipelines sequentially and summarize results.
+
+This utility script orchestrates back-to-back execution of the eICU and
+MIMIC-IV hourly length-of-stay (LoS) Temporal Pointwise Convolution (TPC)
+example scripts using the current BaseModel-compatible example interfaces.
+
+Overview:
+ The script performs the following steps:
+
+ 1. Build subprocess commands for the eICU and MIMIC-IV example scripts.
+ 2. Run each script sequentially from the repository root.
+ 3. Stream stdout to the console in real time.
+ 4. Parse the emitted summary lines from each script.
+ 5. Print a combined summary table for quick comparison.
+
+Inputs:
+ Command-line arguments controlling:
+ - optional dataset roots
+ - dev mode
+ - batch size
+ - epochs
+ - sample counts
+ - whether either dataset run should be skipped
+
+Outputs:
+ - Real-time console output from each child script
+ - Parsed eICU and MIMIC-IV summary metrics
+ - Combined summary table printed to stdout
+ - Nonzero exit if one or more subprocess runs fail
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import re
+import subprocess
+import sys
+from typing import Dict, List, Optional, Tuple
+
+
+REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+if REPO_ROOT not in sys.path:
+ sys.path.insert(0, REPO_ROOT)
+
+
+SUMMARY_PATTERNS = {
+ "eicu": re.compile(r"^ABLATION_SUMMARY\s+(.*)$"),
+ "mimic": re.compile(r"^MIMIC_SUMMARY\s+(.*)$"),
+}
+
+
+def parse_summary_fields(summary_body: str) -> Dict[str, str]:
+ """Parse whitespace-delimited key-value tokens from a summary line."""
+ fields: Dict[str, str] = {}
+ for token in summary_body.strip().split():
+ if "=" in token:
+ key, value = token.split("=", 1)
+ fields[key.strip()] = value.strip()
+ return fields
+
+
+def run_command(
+ label: str,
+ cmd: List[str],
+ cwd: str,
+) -> Tuple[int, List[str], Optional[str]]:
+ """Run a subprocess command and capture its printed summary line."""
+ print("=" * 80)
+ print(f"Running {label}")
+ print("=" * 80)
+ print("Command:")
+ print(" ".join(cmd))
+ print("-" * 80)
+
+ process = subprocess.Popen(
+ cmd,
+ cwd=cwd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1,
+ )
+
+ output_lines: List[str] = []
+ summary_line: Optional[str] = None
+
+ assert process.stdout is not None
+ for line in process.stdout:
+ print(line, end="")
+ output_lines.append(line.rstrip("\n"))
+
+ pattern = SUMMARY_PATTERNS.get(label)
+ if pattern is not None:
+ match = pattern.match(line.strip())
+ if match:
+ summary_line = match.group(1)
+
+ return_code = process.wait()
+
+ print("-" * 80)
+ print(f"{label} return code: {return_code}")
+ print("=" * 80)
+
+ return return_code, output_lines, summary_line
+
+
+def format_value(fields: Dict[str, str], key: str) -> str:
+ """Return a formatted field value or ``NA`` if the key is missing."""
+ return fields.get(key, "NA")
+
+
+def print_combined_summary(
+ eicu_fields: Optional[Dict[str, str]],
+ mimic_fields: Optional[Dict[str, str]],
+) -> None:
+ """Print a combined cross-dataset summary report."""
+ print("\n" + "=" * 80)
+ print("COMBINED SUMMARY")
+ print("=" * 80)
+
+ if eicu_fields is None:
+ print("eICU summary: NOT FOUND")
+ else:
+ print("eICU summary:")
+ print(
+ " "
+ f"model_variant={format_value(eicu_fields, 'model_variant')} "
+ f"shared_temporal={format_value(eicu_fields, 'shared_temporal')} "
+ f"no_skip_connections={format_value(eicu_fields, 'no_skip_connections')} "
+ f"loss={format_value(eicu_fields, 'loss')} "
+ f"val_loss={format_value(eicu_fields, 'val_loss')} "
+ f"val_mae={format_value(eicu_fields, 'val_mae')} "
+ f"val_rmse={format_value(eicu_fields, 'val_rmse')} "
+ f"test_loss={format_value(eicu_fields, 'test_loss')} "
+ f"test_mae={format_value(eicu_fields, 'test_mae')} "
+ f"test_rmse={format_value(eicu_fields, 'test_rmse')}"
+ )
+
+ if mimic_fields is None:
+ print("MIMIC-IV summary: NOT FOUND")
+ else:
+ print("MIMIC-IV summary:")
+ print(
+ " "
+ f"loss={format_value(mimic_fields, 'loss')} "
+ f"val_loss={format_value(mimic_fields, 'val_loss')} "
+ f"val_mae={format_value(mimic_fields, 'val_mae')} "
+ f"val_rmse={format_value(mimic_fields, 'val_rmse')} "
+ f"test_loss={format_value(mimic_fields, 'test_loss')} "
+ f"test_mae={format_value(mimic_fields, 'test_mae')} "
+ f"test_rmse={format_value(mimic_fields, 'test_rmse')}"
+ )
+
+ if eicu_fields is not None and mimic_fields is not None:
+ print("\nCompact table:")
+ print(
+ f"{'dataset':<10} {'loss':<8} {'val_loss':<10} {'val_mae':<10} "
+ f"{'val_rmse':<10} {'test_loss':<10} {'test_mae':<10} {'test_rmse':<10}"
+ )
+ print(
+ f"{'eICU':<10} "
+ f"{format_value(eicu_fields, 'loss'):<8} "
+ f"{format_value(eicu_fields, 'val_loss'):<10} "
+ f"{format_value(eicu_fields, 'val_mae'):<10} "
+ f"{format_value(eicu_fields, 'val_rmse'):<10} "
+ f"{format_value(eicu_fields, 'test_loss'):<10} "
+ f"{format_value(eicu_fields, 'test_mae'):<10} "
+ f"{format_value(eicu_fields, 'test_rmse'):<10}"
+ )
+ print(
+ f"{'MIMIC-IV':<10} "
+ f"{format_value(mimic_fields, 'loss'):<8} "
+ f"{format_value(mimic_fields, 'val_loss'):<10} "
+ f"{format_value(mimic_fields, 'val_mae'):<10} "
+ f"{format_value(mimic_fields, 'val_rmse'):<10} "
+ f"{format_value(mimic_fields, 'test_loss'):<10} "
+ f"{format_value(mimic_fields, 'test_mae'):<10} "
+ f"{format_value(mimic_fields, 'test_rmse'):<10}"
+ )
+
+ print("=" * 80)
+
+
+def build_eicu_command(args) -> List[str]:
+ """Build the subprocess command for the eICU example script."""
+ cmd = [
+ sys.executable,
+ "examples/eicu_hourly_los_tpc.py",
+ "--epochs",
+ str(args.eicu_epochs),
+ "--batch_size",
+ str(args.batch_size),
+ "--max_samples",
+ str(args.eicu_max_samples),
+ "--loss",
+ str(args.eicu_loss),
+        # Pass the parallel worker count through to the child script
+ "--num_workers",
+ str(args.num_workers),
+ ]
+
+ if args.eicu_root:
+ cmd.extend(["--root", args.eicu_root])
+
+    # Pass the external cache directory to avoid re-processing
+    # (argparse always defines the attribute, so no hasattr guard is needed)
+    if args.eicu_cache_dir:
+        cmd.extend(["--cache_dir", args.eicu_cache_dir])
+
+ if args.dev:
+ cmd.append("--dev")
+
+ if args.eicu_model_variant:
+ cmd.extend(["--model_variant", args.eicu_model_variant])
+
+ if args.eicu_shared_temporal:
+ cmd.append("--shared_temporal")
+
+ if args.eicu_no_skip_connections:
+ cmd.append("--no_skip_connections")
+
+ return cmd
+
+
+def build_mimic_command(args) -> List[str]:
+ """Build the subprocess command for the MIMIC-IV example script."""
+ cmd = [
+ sys.executable,
+ "examples/mimic4_hourly_los_tpc.py",
+ "--epochs", str(args.mimic_epochs),
+ "--batch_size", str(args.batch_size),
+ "--max_samples", str(args.mimic_max_samples),
+ "--loss", str(args.mimic_loss),
+ "--num_workers", str(args.num_workers),
+ ]
+
+ if args.mimic_root:
+ cmd.extend(["--root", args.mimic_root])
+
+    # Pass the MIMIC-IV cache directory to avoid re-processing
+    if args.mimic_cache_dir:
+        cmd.extend(["--cache_dir", args.mimic_cache_dir])
+
+ if args.dev:
+ cmd.append("--dev")
+
+ return cmd
+
+
+def parse_args():
+ """Parse command-line arguments for the dual-dataset run script."""
+ parser = argparse.ArgumentParser(
+ description=(
+ "Run eICU and MIMIC-IV TPC examples sequentially and print a "
+ "combined summary."
+ )
+ )
+ parser.add_argument("--dev", action="store_true", help="Run both scripts in dev mode")
+ parser.add_argument("--batch_size", type=int, default=2)
+ parser.add_argument("--eicu_epochs", type=int, default=1)
+ parser.add_argument("--mimic_epochs", type=int, default=1)
+ parser.add_argument("--eicu_max_samples", type=int, default=24)
+ parser.add_argument("--mimic_max_samples", type=int, default=24)
+ parser.add_argument("--eicu_root", type=str, default="")
+ parser.add_argument("--mimic_root", type=str, default="")
+ parser.add_argument(
+ "--eicu_loss",
+ type=str,
+ choices=["msle", "mse"],
+ default="msle",
+ help="Loss to pass to the eICU script.",
+ )
+ parser.add_argument(
+ "--mimic_loss",
+ type=str,
+ choices=["msle", "mse"],
+ default="msle",
+ help="Loss to pass to the MIMIC-IV script.",
+ )
+ parser.add_argument(
+ "--eicu_model_variant",
+ type=str,
+ choices=["full", "temporal_only", "pointwise_only"],
+ default="full",
+ )
+ parser.add_argument(
+ "--eicu_shared_temporal",
+ action="store_true",
+ help="Pass --shared_temporal to the eICU script.",
+ )
+ parser.add_argument(
+ "--eicu_no_skip_connections",
+ action="store_true",
+ help="Pass --no_skip_connections to the eICU script.",
+ )
+ parser.add_argument("--skip_eicu", action="store_true", help="Skip the eICU run")
+ parser.add_argument("--skip_mimic", action="store_true", help="Skip the MIMIC-IV run")
+ parser.add_argument("--num_workers", type=int, default=1, help="Parallel workers for both datasets.")
+ parser.add_argument("--eicu_cache_dir", type=str, default=None, help="Cache path for eICU.")
+ parser.add_argument("--mimic_cache_dir", type=str, default=None, help="Cache path for MIMIC-IV.")
+ return parser.parse_args()
+
+
+def main():
+ """Run the configured eICU and MIMIC-IV example scripts and summarize results."""
+ args = parse_args()
+
+ eicu_fields: Optional[Dict[str, str]] = None
+ mimic_fields: Optional[Dict[str, str]] = None
+
+ eicu_rc = 0
+ mimic_rc = 0
+
+ if not args.skip_eicu:
+ eicu_cmd = build_eicu_command(args)
+ eicu_rc, _, eicu_summary = run_command("eicu", eicu_cmd, cwd=REPO_ROOT)
+ if eicu_summary is not None:
+ eicu_fields = parse_summary_fields(eicu_summary)
+
+ if not args.skip_mimic:
+ mimic_cmd = build_mimic_command(args)
+ mimic_rc, _, mimic_summary = run_command("mimic", mimic_cmd, cwd=REPO_ROOT)
+ if mimic_summary is not None:
+ mimic_fields = parse_summary_fields(mimic_summary)
+
+ print_combined_summary(eicu_fields, mimic_fields)
+
+ if eicu_rc != 0 or mimic_rc != 0:
+ failed = []
+ if eicu_rc != 0:
+ failed.append(f"eICU rc={eicu_rc}")
+ if mimic_rc != 0:
+ failed.append(f"MIMIC-IV rc={mimic_rc}")
+ raise SystemExit("One or more runs failed: " + ", ".join(failed))
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
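The contract between the child scripts and this orchestrator is a single `ABLATION_SUMMARY` / `MIMIC_SUMMARY` line of whitespace-delimited `key=value` tokens. That parsing path can be exercised in isolation (the pattern and helper below mirror `SUMMARY_PATTERNS` and `parse_summary_fields` from the script; the metric values are made up):

```python
import re

MIMIC_SUMMARY_RE = re.compile(r"^MIMIC_SUMMARY\s+(.*)$")

def parse_summary_fields(summary_body):
    # Split whitespace-delimited key=value tokens into a dict;
    # tokens without "=" are ignored rather than raising.
    fields = {}
    for token in summary_body.strip().split():
        if "=" in token:
            key, value = token.split("=", 1)
            fields[key.strip()] = value.strip()
    return fields

line = "MIMIC_SUMMARY loss=msle val_mae=1.2345 test_rmse=2.3456"
match = MIMIC_SUMMARY_RE.match(line)
fields = parse_summary_fields(match.group(1))
print(fields["val_mae"])  # 1.2345
```

Because values stay as strings, the combined summary table can print `NA` for missing keys without any numeric conversion.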
diff --git a/pyhealth/datasets/configs/eicu_tpc.yaml b/pyhealth/datasets/configs/eicu_tpc.yaml
new file mode 100644
index 000000000..d0dba8b06
--- /dev/null
+++ b/pyhealth/datasets/configs/eicu_tpc.yaml
@@ -0,0 +1,362 @@
+# Custom eICU dataset configuration for the TPC hourly length-of-stay project.
+#
+# This configuration extends the standard PyHealth eICU dataset mapping to
+# support the Temporal Pointwise Convolution (TPC) replication pipeline.
+# It exposes the tables and attributes needed for hourly ICU remaining
+# length-of-stay prediction, including time-series measurements, static
+# patient features, and diagnosis/history sources used by the custom task.
+#
+# Intended use:
+# - pyhealth.tasks.hourly_los.HourlyLOSEICU
+# - examples/eicu_hourly_los_tpc.py
+#
+# Notes:
+# - Table names are normalized to the lowercase names used by the custom
+# task pipeline.
+# - Several tables join through patient.csv on patientunitstayid in order
+# to recover uniquepid and visit-level context for task construction.
+# - This file is project-specific and should be kept aligned with the
+# expected table/attribute names used in the eICU TPC task code.
+version: "2.0"
+tables:
+ patient:
+ file_path: "patient.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ attributes:
+ - "patientunitstayid"
+ - "patienthealthsystemstayid"
+ - "gender"
+ - "age"
+ - "ethnicity"
+ - "hospitalid"
+ - "wardid"
+ - "apacheadmissiondx"
+ - "admissionheight"
+ - "admissionweight"
+ - "dischargeweight"
+ - "hospitaladmittime24"
+ - "hospitaladmitoffset"
+ - "hospitaladmitsource"
+ - "hospitaldischargeyear"
+ - "hospitaldischargeoffset"
+ - "hospitaldischargestatus"
+ - "hospitaldischargetime24"
+ - "unittype"
+ - "unitadmittime24"
+ - "unitadmitsource"
+ - "unitvisitnumber"
+ - "unitstaytype"
+ - "unitdischargeoffset"
+ - "unitdischargestatus"
+ - "unitdischargetime24"
+ - "unitdischargelocation"
+
+ hospital:
+ file_path: "hospital.csv"
+ patient_id: null
+ timestamp: null
+ attributes:
+ - "hospitalid"
+ - "numbedscategory"
+ - "teachingstatus"
+ - "region"
+
+ diagnosis:
+ file_path: "diagnosis.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "diagnosisoffset"
+ - "diagnosisstring"
+ - "icd9code"
+ - "diagnosispriority"
+
+ medication:
+ file_path: "medication.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "drugstartoffset"
+ - "drugstopoffset"
+ - "drugname"
+ - "drughiclseqno"
+ - "dosage"
+ - "routeadmin"
+ - "frequency"
+ - "loadingdose"
+ - "prn"
+ - "drugordercancelled"
+
+ treatment:
+ file_path: "treatment.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "treatmentoffset"
+ - "treatmentstring"
+ - "activeupondischarge"
+
+ lab:
+ file_path: "lab.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "labresultoffset"
+ - "labname"
+ - "labresult"
+ - "labresulttext"
+ - "labmeasurenamesystem"
+ - "labmeasurenameinterface"
+ - "labtypeid"
+
+ physicalexam:
+ file_path: "physicalExam.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "physicalexamoffset"
+ - "physicalexampath"
+ - "physicalexamtext"
+ - "physicalexamvalue"
+
+ admissiondx:
+ file_path: "admissionDx.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "admitdxenteredoffset"
+ - "admitdxpath"
+ - "admitdxname"
+ - "admitdxtext"
+
+ respiratorycharting:
+ file_path: "respiratoryCharting.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "respchartoffset"
+ - "respchartentryoffset"
+ - "respcharttypecat"
+ - "respchartvaluelabel"
+ - "respchartvalue"
+
+ nursecharting:
+ file_path: "nurseCharting.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "nursingchartoffset"
+ - "nursingchartentryoffset"
+ - "nursingchartcelltypecat"
+ - "nursingchartcelltypevallabel"
+ - "nursingchartcelltypevalname"
+ - "nursingchartvalue"
+
+ vitalperiodic:
+ file_path: "vitalPeriodic.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "observationoffset"
+ - "temperature"
+ - "sao2"
+ - "heartrate"
+ - "respiration"
+ - "cvp"
+ - "etco2"
+ - "systemicsystolic"
+ - "systemicdiastolic"
+ - "systemicmean"
+ - "pasystolic"
+ - "padiastolic"
+ - "pamean"
+ - "st1"
+ - "st2"
+ - "st3"
+ - "icp"
+
+ vitalaperiodic:
+ file_path: "vitalAperiodic.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "observationoffset"
+ - "noninvasivesystolic"
+ - "noninvasivediastolic"
+ - "noninvasivemean"
+ - "paop"
+ - "cardiacoutput"
+ - "cardiacinput"
+ - "svr"
+ - "svri"
+ - "pvr"
+ - "pvri"
+
+ pasthistory:
+ file_path: "pastHistory.csv"
+ patient_id: "uniquepid"
+ timestamp: null
+ join:
+ - file_path: "patient.csv"
+ "on": "patientunitstayid"
+ how: "inner"
+ columns:
+ - "uniquepid"
+ - "patienthealthsystemstayid"
+ - "unitvisitnumber"
+ - "hospitaladmitoffset"
+ - "hospitaldischargeoffset"
+ - "unitdischargeoffset"
+ - "hospitaldischargeyear"
+ - "hospitaldischargestatus"
+ attributes:
+ - "patientunitstayid"
+ - "pasthistoryoffset"
+ - "pasthistoryenteredoffset"
+ - "pasthistorynotetype"
+ - "pasthistorypath"
+ - "pasthistoryvalue"
+ - "pasthistoryvaluetext"
+
+
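All eICU event tables above carry `*offset` columns measured in minutes relative to unit admission rather than absolute timestamps, which is why every table sets `timestamp: null` and joins back to `patient.csv`. A hedged sketch of the hourly binning this implies (the real `HourlyLOSEICU` logic may differ; the lab name and offsets are illustrative):

```python
def minutes_to_hour_bin(offset_minutes):
    # eICU offsets are minutes since ICU unit admission; an hourly LoS
    # task buckets each event into an integer hour index via floor division.
    return offset_minutes // 60

# (labname, labresultoffset) pairs -- hypothetical values
events = [("Hgb", 30), ("Hgb", 95), ("Hgb", 125)]
bins = [(name, minutes_to_hour_bin(off)) for name, off in events]
print(bins)  # [('Hgb', 0), ('Hgb', 1), ('Hgb', 2)]
```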
diff --git a/pyhealth/datasets/configs/mimic4_ehr_tpc.yaml b/pyhealth/datasets/configs/mimic4_ehr_tpc.yaml
new file mode 100644
index 000000000..7bc0a1ea0
--- /dev/null
+++ b/pyhealth/datasets/configs/mimic4_ehr_tpc.yaml
@@ -0,0 +1,173 @@
+# Custom MIMIC-IV dataset configuration for the TPC hourly length-of-stay project.
+#
+# This configuration defines the table mappings and joins required to load
+# MIMIC-IV data into PyHealth for the Temporal Pointwise Convolution (TPC)
+# replication pipeline. It is tailored to support hourly ICU remaining
+# length-of-stay prediction using time-series EHR data.
+#
+# Intended use:
+# - pyhealth.tasks.hourly_los.HourlyLOSEICU (MIMIC-IV variant)
+# - examples/mimic4_hourly_los_tpc.py
+#
+# Notes:
+# - Core ICU time-series data is sourced from `chartevents`, joined with
+# `d_items` to obtain human-readable labels and feature metadata.
+# - Laboratory features are sourced from `labevents`, joined with
+# `d_labitems` for standardized naming and categorization.
+# - Static and contextual patient information is derived from `patients`,
+# `admissions`, and `icustays`.
+# - Diagnosis and procedure codes are included for optional feature
+# engineering and cohort analysis.
+# - Timestamps are defined per table to enable temporal alignment and
+# hourly resampling within the task pipeline.
+# - This configuration is project-specific and should remain aligned with
+# the feature extraction logic implemented in the MIMIC-IV LoS task.
+version: "2.2"
+tables:
+ patients:
+ file_path: "hosp/patients.csv.gz"
+ patient_id: "subject_id"
+ timestamp: null
+ attributes:
+ - "gender"
+ - "anchor_age"
+ - "anchor_year"
+ - "anchor_year_group"
+ - "dod"
+
+ chartevents:
+ file_path: "icu/chartevents.csv.gz"
+ patient_id: "subject_id"
+ join:
+ - file_path: "icu/d_items.csv.gz"
+ "on": "itemid"
+ how: "inner"
+ columns:
+ - "label"
+ - "abbreviation"
+ - "linksto"
+ - "category"
+ - "unitname"
+ - "param_type"
+ timestamp: "charttime"
+ attributes:
+ - "hadm_id"
+ - "stay_id"
+ - "itemid"
+ - "label"
+ - "abbreviation"
+ - "linksto"
+ - "category"
+ - "unitname"
+ - "param_type"
+ - "value"
+ - "valuenum"
+ - "valueuom"
+ - "storetime"
+ - "warning"
+
+ admissions:
+ file_path: "hosp/admissions.csv.gz"
+ patient_id: "subject_id"
+ timestamp: "admittime"
+ attributes:
+ - "hadm_id"
+ - "admission_type"
+ - "admission_location"
+ - "insurance"
+ - "language"
+ - "marital_status"
+ - "race"
+ - "discharge_location"
+ - "dischtime"
+ - "hospital_expire_flag"
+
+ icustays:
+ file_path: "icu/icustays.csv.gz"
+ patient_id: "subject_id"
+ timestamp: "intime"
+ attributes:
+ - "hadm_id"
+ - "stay_id"
+ - "first_careunit"
+ - "last_careunit"
+ - "outtime"
+
+ diagnoses_icd:
+ file_path: "hosp/diagnoses_icd.csv.gz"
+ patient_id: "subject_id"
+ join:
+ - file_path: "hosp/admissions.csv.gz"
+ "on": "hadm_id"
+ how: "inner"
+ columns:
+ - "dischtime"
+ timestamp: "dischtime"
+ attributes:
+ - "hadm_id"
+ - "icd_code"
+ - "icd_version"
+ - "seq_num"
+
+ procedures_icd:
+ file_path: "hosp/procedures_icd.csv.gz"
+ patient_id: "subject_id"
+ join:
+ - file_path: "hosp/admissions.csv.gz"
+ "on": "hadm_id"
+ how: "inner"
+ columns:
+ - "dischtime"
+ timestamp: "dischtime"
+ attributes:
+ - "hadm_id"
+ - "icd_code"
+ - "icd_version"
+ - "seq_num"
+
+ prescriptions:
+ file_path: "hosp/prescriptions.csv.gz"
+ patient_id: "subject_id"
+ timestamp: "starttime"
+ attributes:
+ - "hadm_id"
+ - "drug"
+ - "ndc"
+ - "prod_strength"
+ - "dose_val_rx"
+ - "dose_unit_rx"
+ - "route"
+ - "stoptime"
+
+ labevents:
+ file_path: "hosp/labevents.csv.gz"
+ patient_id: "subject_id"
+ join:
+ - file_path: "hosp/d_labitems.csv.gz"
+ "on": "itemid"
+ how: "inner"
+ columns:
+ - "label"
+ - "fluid"
+ - "category"
+ timestamp: "charttime"
+ attributes:
+ - "hadm_id"
+ - "itemid"
+ - "label"
+ - "fluid"
+ - "category"
+ - "value"
+ - "valuenum"
+ - "valueuom"
+ - "flag"
+ - "storetime"
+
+ hcpcsevents:
+ file_path: "hosp/hcpcsevents.csv.gz"
+ patient_id: "subject_id"
+ timestamp: "chartdate"
+ attributes:
+ - "hcpcs_cd"
+ - "seq_num"
+ - "short_description"
\ No newline at end of file
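The `chartevents` entry above joins `d_items` on `itemid` with `how: "inner"` to attach human-readable labels. The effect of that join can be sketched with small stand-in frames (the itemids are real MIMIC-IV chart items, but the rows are invented for illustration):

```python
import pandas as pd

# Minimal stand-ins for icu/chartevents.csv.gz and icu/d_items.csv.gz.
chartevents = pd.DataFrame({
    "subject_id": [1, 1, 2],
    "itemid": [220045, 220210, 220045],
    "valuenum": [88.0, 18.0, 95.0],
})
d_items = pd.DataFrame({
    "itemid": [220045, 220210],
    "label": ["Heart Rate", "Respiratory Rate"],
})

# how="inner" matches the config: events whose itemid has no dictionary
# entry are dropped rather than kept with a null label.
merged = chartevents.merge(d_items, on="itemid", how="inner")
print(sorted(merged["label"].unique()))
```

The task code can then key time-series features by `label` (as in the `chartevents` feature list earlier in this script) instead of raw itemids.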
diff --git a/pyhealth/models/tpc.py b/pyhealth/models/tpc.py
new file mode 100644
index 000000000..547c3b255
--- /dev/null
+++ b/pyhealth/models/tpc.py
@@ -0,0 +1,810 @@
+"""
+Temporal Pointwise Convolution (TPC) model for hourly ICU remaining
+length-of-stay (LoS) regression in PyHealth.
+
+This module provides a dataset-backed PyHealth ``BaseModel`` implementation
+of the Temporal Pointwise Convolution (TPC) architecture described in:
+
+Rocheteau, E., Liò, P., and Hyland, S. (2021).
+"Temporal Pointwise Convolutional Networks for Length of Stay Prediction
+in the Intensive Care Unit."
+
+Overview:
+ TPC is designed for multivariate, irregularly sampled EHR time series.
+ It combines two complementary operations at each layer:
+
+ 1. Temporal convolution:
+ Feature-wise or shared causal convolutions over time, allowing
+ each clinical variable to learn temporal dynamics.
+
+ 2. Pointwise convolution:
+ Per-time-step feature mixing to capture cross-feature interactions
+ without temporal leakage.
+
+ In this PyHealth implementation, the task supplies:
+ - ``time_series``: per-sample hourly history encoded as [T, 3F]
+ in [value, mask, decay] order for each feature
+ - ``static``: optional static feature vector
+ - ``target_los_hours``: scalar regression target
+
+ The model consumes task-processed batch fields via ``forward(**kwargs)``
+ and returns the standard PyHealth output dictionary:
+ - ``loss``
+ - ``y_prob``
+ - ``y_true``
+ - ``logit``
+
+Key Components:
+ - ``TemporalConvBlock``: causal temporal convolution per feature
+ - ``PointwiseConvBlock``: per-time-step feature interaction layer
+ - ``TPCLayer``: combined temporal + pointwise block
+ - ``TPC``: full stacked regression model
+
+Implementation Notes:
+ - The public model contract is now fully dataset-backed through
+ ``BaseModel``.
+ - The model predicts one scalar remaining LoS value per sample.
+ - Nonlinearities use ``nn.Module`` variants to remain compatible with
+ PyHealth interpretability expectations.
+ - The input ``time_series`` field is internally split into value, mask,
+ and decay channels from a [T, 3F] representation.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pyhealth.models import BaseModel
+
+
+class TemporalConvBlock(nn.Module):
+ """Feature-wise or shared causal temporal convolution block.
+
+ This block applies a 1-D causal convolution over time for each feature.
+ It supports two modes:
+
+ 1. Feature-wise temporal convolution:
+ A separate ``nn.Conv1d`` is created for each feature so temporal
+ filters are not shared across features.
+
+ 2. Shared temporal convolution:
+ A single ``nn.Conv1d`` is reused for all features, allowing an
+ ablation against the feature-specific version.
+
+ Input:
+ x: Tensor of shape ``[B, T, F, C_in]``
+
+ Output:
+ y: Tensor of shape ``[B, T, F, C_out]``
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_features: int,
+ kernel_size: int,
+ dilation: int,
+ dropout: float = 0.0,
+ shared_temporal: bool = False,
+ ) -> None:
+ """Initialize the temporal convolution block.
+
+ Args:
+ in_channels: Number of per-feature input channels.
+ out_channels: Number of per-feature output channels.
+ num_features: Number of time-series features.
+ kernel_size: Temporal convolution kernel size.
+ dilation: Temporal dilation factor.
+ dropout: Dropout probability applied after convolution.
+ shared_temporal: Whether to share one temporal convolution across
+ all features.
+
+ Raises:
+ ValueError: If any required dimensional argument is invalid.
+ """
+ super().__init__()
+ if in_channels <= 0:
+ raise ValueError("in_channels must be positive")
+ if out_channels <= 0:
+ raise ValueError("out_channels must be positive")
+ if num_features <= 0:
+ raise ValueError("num_features must be positive")
+ if kernel_size <= 0:
+ raise ValueError("kernel_size must be positive")
+ if dilation <= 0:
+ raise ValueError("dilation must be positive")
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_features = num_features
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.shared_temporal = shared_temporal
+
+ if self.shared_temporal:
+ self.shared_conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ bias=True,
+ )
+ self.feature_convs = None
+ else:
+ self.shared_conv = None
+ self.feature_convs = nn.ModuleList(
+ [
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ bias=True,
+ )
+ for _ in range(num_features)
+ ]
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply causal temporal convolution to each feature.
+
+ Args:
+ x: Input tensor of shape ``[B, T, F, C_in]``.
+
+ Returns:
+ Output tensor of shape ``[B, T, F, C_out]``.
+
+ Raises:
+ ValueError: If the input tensor has incompatible shape.
+ """
+ if x.ndim != 4:
+ raise ValueError(
+ f"Expected x to have shape [B, T, F, C], got {tuple(x.shape)}"
+ )
+
+ _, _, num_features, in_channels = x.shape
+ if num_features != self.num_features:
+ raise ValueError(
+ f"Expected {self.num_features} features, got {num_features}"
+ )
+ if in_channels != self.in_channels:
+ raise ValueError(
+ f"Expected {self.in_channels} input channels, got {in_channels}"
+ )
+
+ outputs = []
+ left_pad = self.dilation * (self.kernel_size - 1)
+
+ for feat_idx in range(self.num_features):
+ feat_x = x[:, :, feat_idx, :].transpose(1, 2) # [B, C_in, T]
+ feat_x = F.pad(feat_x, (left_pad, 0))
+
+ if self.shared_temporal:
+ feat_y = self.shared_conv(feat_x)
+ else:
+ feat_y = self.feature_convs[feat_idx](feat_x)
+
+ feat_y = feat_y.transpose(1, 2) # [B, T, C_out]
+ outputs.append(feat_y.unsqueeze(2)) # [B, T, 1, C_out]
+
+ y = torch.cat(outputs, dim=2) # [B, T, F, C_out]
+ y = self.dropout(y)
+ return y
+
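The causality of `TemporalConvBlock` comes entirely from the asymmetric padding: left-padding by `dilation * (kernel_size - 1)` keeps the output length equal to `T` and guarantees that output step `t` only sees input steps at or before `t`. A small sketch of that index arithmetic (not the actual `Conv1d` call):

```python
def receptive_indices(t, kernel_size, dilation):
    # Input time steps contributing to output step t after left padding by
    # dilation * (kernel_size - 1); negative indices fall in the zero-padded
    # region. All indices are <= t, so no future information leaks in.
    return [t - dilation * i for i in range(kernel_size - 1, -1, -1)]

print(receptive_indices(10, kernel_size=3, dilation=2))  # [6, 8, 10]
```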
+
+class PointwiseConvBlock(nn.Module):
+ """Pointwise transformation applied independently at each time step.
+
+ This block is implemented as a linear layer operating on the flattened
+ per-time-step representation. It is equivalent in spirit to a per-time-step
+ 1x1 convolution across the feature/channel dimension.
+
+ Input:
+ x: Tensor of shape ``[B, T, D_in]``
+
+ Output:
+ y: Tensor of shape ``[B, T, D_out]``
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ dropout: float = 0.0,
+ ) -> None:
+ """Initialize the pointwise block.
+
+ Args:
+ input_dim: Input dimension per time step.
+ output_dim: Output dimension per time step.
+ dropout: Dropout probability applied after projection.
+
+ Raises:
+ ValueError: If either dimension is not positive.
+ """
+ super().__init__()
+ if input_dim <= 0:
+ raise ValueError("input_dim must be positive")
+ if output_dim <= 0:
+ raise ValueError("output_dim must be positive")
+
+ self.linear = nn.Linear(input_dim, output_dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply the pointwise transformation.
+
+ Args:
+ x: Input tensor of shape ``[B, T, D_in]``.
+
+ Returns:
+ Output tensor of shape ``[B, T, D_out]``.
+
+ Raises:
+ ValueError: If the input tensor shape is invalid.
+ """
+ if x.ndim != 3:
+ raise ValueError(
+ f"Expected x to have shape [B, T, D], got {tuple(x.shape)}"
+ )
+ y = self.linear(x)
+ y = self.dropout(y)
+ return y
+
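As the `PointwiseConvBlock` docstring notes, a per-time-step linear layer matches a kernel-size-1 `Conv1d` over the channel dimension. A small equivalence check (illustrative only; the dimensions are arbitrary):

```python
import torch
from torch import nn

d_in, d_out, T = 6, 4, 5
linear = nn.Linear(d_in, d_out)

# Build a 1x1 convolution carrying identical weights.
conv = nn.Conv1d(d_in, d_out, kernel_size=1)
conv.weight.data = linear.weight.data.unsqueeze(-1)  # [D_out, D_in, 1]
conv.bias.data = linear.bias.data

x = torch.randn(2, T, d_in)                       # [B, T, D_in]
y_linear = linear(x)                              # applied per time step
y_conv = conv(x.transpose(1, 2)).transpose(1, 2)  # Conv1d wants [B, C, T]

assert torch.allclose(y_linear, y_conv, atol=1e-5)
```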
+
+class TPCLayer(nn.Module):
+ """One Temporal Pointwise Convolution layer.
+
+ This layer combines:
+ - optional temporal convolution branch
+ - optional pointwise branch
+ - optional concatenative skip connections
+
+ Input:
+ x: Tensor of shape ``[B, T, F, C_in]``
+ decay: Tensor of shape ``[B, T, F]`` or ``None``
+ static: Tensor of shape ``[B, S]`` or ``None``
+
+ Output:
+ fused: Tensor of shape ``[B, T, F, C_out]``
+ """
+
+ def __init__(
+ self,
+ num_features: int,
+ in_channels: int,
+ temporal_channels: int,
+ pointwise_channels: int,
+ static_dim: int,
+ kernel_size: int,
+ dilation: int,
+ dropout: float = 0.0,
+ use_decay_in_pointwise: bool = True,
+ shared_temporal: bool = False,
+ use_temporal: bool = True,
+ use_pointwise: bool = True,
+ use_skip_connections: bool = True,
+ ) -> None:
+ """Initialize the TPC layer.
+
+ Args:
+ num_features: Number of time-series features.
+ in_channels: Number of input channels per feature.
+ temporal_channels: Number of temporal branch output channels.
+ pointwise_channels: Number of pointwise branch output channels.
+ static_dim: Static feature dimension.
+ kernel_size: Temporal kernel size.
+ dilation: Temporal dilation factor.
+ dropout: Dropout probability.
+ use_decay_in_pointwise: Whether decay indicators are concatenated
+ into the pointwise branch input.
+ shared_temporal: Whether temporal filters are shared across
+ features.
+ use_temporal: Whether to enable the temporal branch.
+ use_pointwise: Whether to enable the pointwise branch.
+ use_skip_connections: Whether to concatenate the prior input
+ representation into the layer output.
+
+ Raises:
+ ValueError: If layer dimensions are invalid or both branches
+ are disabled.
+ """
+ super().__init__()
+ if num_features <= 0:
+ raise ValueError("num_features must be positive")
+ if in_channels <= 0:
+ raise ValueError("in_channels must be positive")
+ if temporal_channels <= 0:
+ raise ValueError("temporal_channels must be positive")
+ if pointwise_channels <= 0:
+ raise ValueError("pointwise_channels must be positive")
+ if static_dim < 0:
+ raise ValueError("static_dim must be non-negative")
+
+ self.num_features = num_features
+ self.in_channels = in_channels
+ self.temporal_channels = temporal_channels
+ self.pointwise_channels = pointwise_channels
+ self.static_dim = static_dim
+ self.use_decay_in_pointwise = use_decay_in_pointwise
+ self.shared_temporal = shared_temporal
+ self.use_temporal = use_temporal
+ self.use_pointwise = use_pointwise
+ self.use_skip_connections = use_skip_connections
+
+ if not self.use_temporal and not self.use_pointwise:
+ raise ValueError(
+ "At least one of use_temporal or use_pointwise must be True"
+ )
+
+ self.output_channels = 0
+ if self.use_skip_connections:
+ self.output_channels += in_channels
+ if self.use_temporal:
+ self.output_channels += temporal_channels
+ if self.use_pointwise:
+ self.output_channels += pointwise_channels
+
+ if self.use_temporal:
+ self.temporal = TemporalConvBlock(
+ in_channels=in_channels,
+ out_channels=temporal_channels,
+ num_features=num_features,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ dropout=dropout,
+ shared_temporal=shared_temporal,
+ )
+ else:
+ self.temporal = None
+
+ if self.use_pointwise:
+ point_input_dim = (num_features * in_channels) + static_dim
+ if use_decay_in_pointwise:
+ point_input_dim += num_features
+
+ self.pointwise = PointwiseConvBlock(
+ input_dim=point_input_dim,
+ output_dim=pointwise_channels,
+ dropout=dropout,
+ )
+ else:
+ self.pointwise = None
+
+ self.activation = nn.ReLU()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ decay: Optional[torch.Tensor] = None,
+ static: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Apply the TPC layer.
+
+ Args:
+ x: Input tensor of shape ``[B, T, F, C_in]``.
+ decay: Optional decay tensor of shape ``[B, T, F]``.
+ static: Optional static tensor of shape ``[B, S]``.
+
+ Returns:
+ Fused tensor of shape ``[B, T, F, C_out]``.
+
+ Raises:
+ ValueError: If tensor shapes are incompatible.
+ """
+ if x.ndim != 4:
+ raise ValueError(
+ f"Expected x to have shape [B, T, F, C], got {tuple(x.shape)}"
+ )
+
+ bsz, seq_len, num_features, in_channels = x.shape
+ if num_features != self.num_features:
+ raise ValueError(
+ f"Expected {self.num_features} features, got {num_features}"
+ )
+ if in_channels != self.in_channels:
+ raise ValueError(
+ f"Expected {self.in_channels} input channels, got {in_channels}"
+ )
+
+ if decay is not None:
+ if decay.ndim != 3:
+ raise ValueError(
+ f"Expected decay to have shape [B, T, F], got {tuple(decay.shape)}"
+ )
+ if decay.shape[:3] != (bsz, seq_len, num_features):
+ raise ValueError(
+ "Expected decay shape "
+ f"{(bsz, seq_len, num_features)}, got {tuple(decay.shape)}"
+ )
+
+ if static is not None:
+ if static.ndim != 2:
+ raise ValueError(
+ f"Expected static to have shape [B, S], got {tuple(static.shape)}"
+ )
+ if static.shape[0] != bsz:
+ raise ValueError(
+ f"Expected static batch size {bsz}, got {static.shape[0]}"
+ )
+
+ parts_to_concat = []
+
+ if self.use_skip_connections:
+ parts_to_concat.append(x)
+
+ if self.use_temporal:
+ temp_out = self.temporal(x)
+ parts_to_concat.append(temp_out)
+
+ if self.use_pointwise:
+ flat_x = x.reshape(bsz, seq_len, num_features * in_channels)
+ point_parts = [flat_x]
+
+ if static is not None:
+ static_rep = static.unsqueeze(1).expand(-1, seq_len, -1)
+ point_parts.append(static_rep)
+
+ if self.use_decay_in_pointwise:
+ if decay is None:
+ raise ValueError(
+ "decay must be provided when use_decay_in_pointwise=True"
+ )
+ point_parts.append(decay)
+
+ point_in = torch.cat(point_parts, dim=-1)
+ point_out = self.pointwise(point_in)
+ point_broadcast = point_out.unsqueeze(2).expand(
+ -1, -1, num_features, -1
+ )
+ parts_to_concat.append(point_broadcast)
+
+ fused = torch.cat(parts_to_concat, dim=-1)
+ fused = self.activation(fused)
+ return fused
+
+
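With skip connections enabled, each layer's `output_channels` is `in_channels + temporal_channels + pointwise_channels`, so the channel count grows additively as layers stack; this is what drives `in_channels = layer.output_channels` in the model constructor. A quick sketch of the progression, assuming the defaults of 2 input channels (value + decay) and 8 channels per branch:

```python
# Channel bookkeeping for stacked TPC layers with concatenative skips.
temporal_channels = pointwise_channels = 8
in_channels = 2  # value + decay

widths = []
for _ in range(3):  # num_layers = 3
    out_channels = in_channels + temporal_channels + pointwise_channels
    widths.append(out_channels)
    in_channels = out_channels  # next layer consumes the fused output

print(widths)  # [18, 34, 50]
```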
+class TPC(BaseModel):
+ """Temporal Pointwise Convolution model for scalar LoS regression.
+
+ This is a true dataset-backed PyHealth ``BaseModel`` implementation.
+
+ Expected task contract:
+ - input field ``time_series`` containing per-sample history encoded as
+ ``[T, 3F]`` with interleaved [value, mask, decay] channels
+ - input field ``static`` containing optional static features
+ - output field ``target_los_hours`` declared as ``"regression"``
+
+ The model predicts one scalar remaining length-of-stay value for each
+ sample using the full observed history in that sample.
+
+ Args:
+ dataset: PyHealth ``SampleDataset`` produced by ``dataset.set_task()``.
+ input_dim: Number of base time-series features ``F`` before expansion
+ into [value, mask, decay].
+ static_dim: Static feature dimension.
+ temporal_channels: Temporal branch output channels per layer.
+ pointwise_channels: Pointwise branch output channels per layer.
+ num_layers: Number of stacked TPC layers.
+ kernel_size: Temporal convolution kernel size.
+ fc_dim: Hidden dimension in the final regression head.
+ dropout: Dropout probability.
+ loss_name: Loss function name. Use ``"msle"`` for mean squared
+ logarithmic error or ``"mse"`` for mean squared error.
+ use_decay_in_pointwise: Whether to inject decay channels into the
+ pointwise branch.
+ positive_output: Whether to enforce non-negative predictions with
+ ``Softplus``.
+ shared_temporal: Whether temporal filters are shared across features.
+ use_temporal: Whether to enable the temporal branch.
+ use_pointwise: Whether to enable the pointwise branch.
+ use_skip_connections: Whether to use concatenative skip connections.
+
+ Notes:
+ - This model follows the standard PyHealth ``forward(**kwargs)``
+ contract.
+ - The output head size is derived from
+ ``BaseModel.get_output_size()``.
+ """
+
+ def __init__(
+ self,
+ dataset,
+ input_dim: int,
+ static_dim: int = 0,
+ temporal_channels: int = 8,
+ pointwise_channels: int = 8,
+ num_layers: int = 3,
+ kernel_size: int = 3,
+ fc_dim: int = 32,
+ dropout: float = 0.1,
+ loss_name: str = "msle",
+ use_decay_in_pointwise: bool = True,
+ positive_output: bool = True,
+ shared_temporal: bool = False,
+ use_temporal: bool = True,
+ use_pointwise: bool = True,
+ use_skip_connections: bool = True,
+ ) -> None:
+ """Initialize the TPC model."""
+ super().__init__(dataset=dataset)
+
+ if input_dim <= 0:
+ raise ValueError("input_dim must be positive")
+ if static_dim < 0:
+ raise ValueError("static_dim must be non-negative")
+ if temporal_channels <= 0:
+ raise ValueError("temporal_channels must be positive")
+ if pointwise_channels <= 0:
+ raise ValueError("pointwise_channels must be positive")
+ if num_layers <= 0:
+ raise ValueError("num_layers must be positive")
+ if kernel_size <= 0:
+ raise ValueError("kernel_size must be positive")
+ if fc_dim <= 0:
+ raise ValueError("fc_dim must be positive")
+ if loss_name not in {"msle", "mse"}:
+ raise ValueError("loss_name must be one of {'msle', 'mse'}")
+
+ self.label_key = self.label_keys[0]
+
+ required_feature_keys = {"time_series", "static"}
+ missing = required_feature_keys.difference(set(self.feature_keys))
+ if missing:
+ raise ValueError(
+ "TPC requires task input_schema to contain the feature keys "
+ f"{sorted(required_feature_keys)}; missing {sorted(missing)}"
+ )
+
+ self.input_dim = input_dim
+ self.static_dim = static_dim
+ self.temporal_channels = temporal_channels
+ self.pointwise_channels = pointwise_channels
+ self.num_layers = num_layers
+ self.kernel_size = kernel_size
+ self.fc_dim = fc_dim
+ self.dropout = dropout
+ self.loss_name = loss_name
+ self.use_decay_in_pointwise = use_decay_in_pointwise
+ self.positive_output = positive_output
+ self.shared_temporal = shared_temporal
+ self.use_temporal = use_temporal
+ self.use_pointwise = use_pointwise
+ self.use_skip_connections = use_skip_connections
+
+ layers = []
+ in_channels = 2 # value + decay
+
+ for layer_idx in range(num_layers):
+ layer = TPCLayer(
+ num_features=input_dim,
+ in_channels=in_channels,
+ temporal_channels=temporal_channels,
+ pointwise_channels=pointwise_channels,
+ static_dim=static_dim,
+ kernel_size=kernel_size,
+ dilation=layer_idx + 1,
+ dropout=dropout,
+ use_decay_in_pointwise=use_decay_in_pointwise,
+ shared_temporal=shared_temporal,
+ use_temporal=use_temporal,
+ use_pointwise=use_pointwise,
+ use_skip_connections=use_skip_connections,
+ )
+ layers.append(layer)
+ in_channels = layer.output_channels
+
+ self.layers = nn.ModuleList(layers)
+
+ final_input_dim = (input_dim * in_channels) + static_dim
+ self.final_fc1 = nn.Linear(final_input_dim, fc_dim)
+ self.final_fc2 = nn.Linear(fc_dim, self.get_output_size())
+
+ self.relu = nn.ReLU()
+ self.softplus = nn.Softplus()
+
+ def _unpack_feature_value(self, key: str, feature: Any) -> torch.Tensor:
+ """Extract the processor 'value' tensor for a feature key.
+
+ PyHealth passes either:
+ - a raw tensor, or
+ - a tuple aligned with the processor schema
+
+ Args:
+ key: Feature key in the task schema.
+ feature: Batch feature value from ``kwargs``.
+
+ Returns:
+ The tensor corresponding to the processor's ``value`` field.
+
+ Raises:
+ ValueError: If the processor schema does not contain ``value``.
+ """
+ if isinstance(feature, torch.Tensor):
+ return feature
+
+ if not isinstance(feature, (tuple, list)):
+ raise ValueError(
+ f"Expected feature '{key}' to be a Tensor or tuple/list, "
+ f"got {type(feature).__name__}"
+ )
+
+ schema = self.dataset.input_processors[key].schema()
+ if "value" not in schema:
+ raise ValueError(
+ f"Processor schema for feature '{key}' does not contain 'value': "
+ f"{schema}"
+ )
+ return feature[schema.index("value")]
+
+ def _split_value_mask_decay(
+ self,
+ time_series: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Split a batch time series tensor into values, masks, and decay.
+
+ Supported input layouts:
+ - ``[B, T, 3F]`` with interleaved [value, mask, decay]
+ - ``[B, T, F, 3]`` with final channel order [value, mask, decay]
+
+ Args:
+ time_series: Batch time series tensor.
+
+ Returns:
+ Tuple ``(values, masks, decay)`` each with shape ``[B, T, F]``.
+
+ Raises:
+ ValueError: If the tensor shape is incompatible.
+ """
+ if time_series.ndim == 3:
+ batch_size, seq_len, feat_dim = time_series.shape
+ if feat_dim % 3 != 0:
+ raise ValueError(
+ "Expected time_series last dimension divisible by 3 for "
+ f"[value, mask, decay], got {feat_dim}"
+ )
+ num_features = feat_dim // 3
+ if num_features != self.input_dim:
+ raise ValueError(
+ f"Expected input_dim={self.input_dim}, got {num_features}"
+ )
+
+            # Strided slices select channel 0/1/2 of each interleaved
+            # [value, mask, decay] triple, yielding [B, T, F] tensors.
+            values = time_series[:, :, 0::3]
+            masks = time_series[:, :, 1::3]
+            decay = time_series[:, :, 2::3]
+            return values, masks, decay
+
+ if time_series.ndim == 4:
+ batch_size, seq_len, num_features, channels = time_series.shape
+ if channels != 3:
+ raise ValueError(
+ "Expected time_series channel dimension size 3 for "
+ f"[value, mask, decay], got {channels}"
+ )
+ if num_features != self.input_dim:
+ raise ValueError(
+ f"Expected input_dim={self.input_dim}, got {num_features}"
+ )
+ values = time_series[:, :, :, 0]
+ masks = time_series[:, :, :, 1]
+ decay = time_series[:, :, :, 2]
+ return values, masks, decay
+
+ raise ValueError(
+ "Expected time_series to have shape [B, T, 3F] or [B, T, F, 3], "
+ f"got {tuple(time_series.shape)}"
+ )
+
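The two layouts supported above are views of the same interleaved memory: reshaping `[B, T, 3F]` to `[B, T, F, 3]` groups each consecutive [value, mask, decay] triple per feature. A toy sketch (hypothetical sizes, not the model's code path):

```python
import torch

B, T, num_feats = 2, 4, 3
ts = torch.arange(B * T * 3 * num_feats, dtype=torch.float32)
ts = ts.reshape(B, T, 3 * num_feats)  # interleaved [B, T, 3F]

# Channel 0/1/2 of each interleaved triple.
values = ts[:, :, 0::3]
masks = ts[:, :, 1::3]
decay = ts[:, :, 2::3]
assert values.shape == (B, T, num_feats)

# Same memory viewed as [B, T, F, 3]: the last axis is [value, mask, decay].
alt = ts.reshape(B, T, num_feats, 3)
assert torch.equal(values, alt[..., 0])
assert torch.equal(masks, alt[..., 1])
assert torch.equal(decay, alt[..., 2])
```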
+ def forward(self, **kwargs) -> dict[str, torch.Tensor]:
+ """Run the TPC model forward under the PyHealth BaseModel contract.
+
+ Expected kwargs:
+ - ``time_series``: tensor or processor tuple containing the
+ [value, mask, decay] history representation
+ - ``static``: static feature tensor or processor tuple
+ - ``target_los_hours``: regression targets
+
+ Returns:
+ Dictionary containing:
+ - ``loss``: scalar regression loss
+ - ``y_prob``: prepared output probabilities/identity values
+ - ``y_true``: ground-truth labels
+ - ``logit``: raw model outputs
+
+ Raises:
+ ValueError: If required inputs are missing or malformed.
+ """
+ if "time_series" not in kwargs:
+ raise ValueError("Missing required batch field 'time_series'")
+ if "static" not in kwargs:
+ raise ValueError("Missing required batch field 'static'")
+ if self.label_key not in kwargs:
+ raise ValueError(
+ f"Missing required label field '{self.label_key}' in batch"
+ )
+
+ time_series = self._unpack_feature_value("time_series", kwargs["time_series"])
+ static = self._unpack_feature_value("static", kwargs["static"])
+
+ if not isinstance(time_series, torch.Tensor):
+ raise ValueError("'time_series' value must be a Tensor after unpacking")
+ if not isinstance(static, torch.Tensor):
+ raise ValueError("'static' value must be a Tensor after unpacking")
+
+ time_series = time_series.float().to(self.device)
+ static = static.float().to(self.device)
+
+ if time_series.ndim not in {3, 4}:
+ raise ValueError(
+ "Expected time_series batch tensor to have 3 or 4 dimensions, "
+ f"got {tuple(time_series.shape)}"
+ )
+
+ if static.ndim != 2:
+ raise ValueError(
+ f"Expected static to have shape [B, S], got {tuple(static.shape)}"
+ )
+ if static.shape[1] != self.static_dim:
+ raise ValueError(
+ f"Expected static_dim={self.static_dim}, got {static.shape[1]}"
+ )
+
+ x_values, _, x_decay = self._split_value_mask_decay(time_series)
+
+ x = torch.stack([x_values, x_decay], dim=-1) # [B, T, F, 2]
+
+ for layer in self.layers:
+ x = layer(x, decay=x_decay, static=static)
+
+        batch_size = x.shape[0]
+        last_x = x[:, -1, :, :].reshape(batch_size, -1)
+
+        # static was validated above as a [B, static_dim] tensor, and the
+        # regression head always expects it, so concatenate unconditionally.
+        last_x = torch.cat([last_x, static], dim=-1)
+
+ hidden = self.relu(self.final_fc1(last_x))
+ logits = self.final_fc2(hidden)
+
+ if self.positive_output:
+ logits = self.softplus(logits)
+
+ y_true = kwargs[self.label_key].float().to(self.device)
+ if y_true.ndim == 1:
+ y_true = y_true.unsqueeze(-1)
+ elif y_true.ndim > 2:
+ raise ValueError(
+ f"Expected y_true to have shape [B] or [B, 1], got {tuple(y_true.shape)}"
+ )
+
+ if self.loss_name == "msle":
+ loss = F.mse_loss(torch.log1p(logits), torch.log1p(y_true))
+ elif self.loss_name == "mse":
+ loss = F.mse_loss(logits, y_true)
+ else:
+ raise ValueError(
+ f"Unsupported loss_name '{self.loss_name}'. Expected 'msle' or 'mse'."
+ )
+
+ return {
+ "loss": loss,
+ "y_prob": self.prepare_y_prob(logits),
+ "y_true": y_true,
+ "logit": logits,
+ }
\ No newline at end of file
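The default `msle` loss above is mean squared error computed in `log1p` space, which down-weights errors on long stays relative to plain MSE. A small illustration with toy numbers:

```python
import torch
import torch.nn.functional as F

# Two stays, each predicted 10 hours short of the truth.
pred = torch.tensor([[10.0], [100.0]])
true = torch.tensor([[20.0], [110.0]])

mse = F.mse_loss(pred, true)                             # 100.0 either way
msle = F.mse_loss(torch.log1p(pred), torch.log1p(true))

# MSE treats both 10-hour errors identically, while MSLE penalises the
# error on the short stay far more than the same error on the long stay.
assert mse.item() == 100.0
assert msle.item() < mse.item()
```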
diff --git a/pyhealth/tasks/hourly_los.py b/pyhealth/tasks/hourly_los.py
new file mode 100644
index 000000000..db0002c2c
--- /dev/null
+++ b/pyhealth/tasks/hourly_los.py
@@ -0,0 +1,1060 @@
+"""
+Hourly ICU remaining length-of-stay (LoS) task implementation for PyHealth.
+
+This module defines the ``HourlyLOSEICU`` task, which constructs supervised
+samples for predicting the remaining ICU length of stay at hourly intervals
+from multivariate EHR time-series data.
+
+The implementation is designed to align with the preprocessing strategy
+described in:
+
+Rocheteau, E., Liò, P., and Hyland, S. (2021).
+"Temporal Pointwise Convolutional Networks for Length of Stay Prediction
+in the Intensive Care Unit."
+
+Overview:
+ For each ICU stay, this task:
+
+ 1. Extracts observations from both pre-ICU and ICU time windows
+ (typically up to 24 hours before ICU admission).
+ 2. Buckets observations into hourly intervals.
+ 3. Retains the most recent measurement within each hour.
+ 4. Applies forward-filling to handle missing values.
+ 5. Computes decay features to represent time since last observation.
+ 6. Removes pre-ICU rows after feature construction.
+ 7. Emits one supervised sample per prediction hour.
+
+Task Contract:
+ Declared input schema:
+ - ``time_series``: processor-backed timeseries field encoding
+ per-hour [value, mask, decay] channels for each feature
+ - ``static``: processor-backed tensor field for numeric and
+ one-hot encoded static features
+
+ Declared output schema:
+ - ``target_los_hours``: scalar regression target for remaining
+ ICU length of stay at the current prediction hour
+
+Metadata retained in each sample for analysis/debugging:
+ - ``target_los_sequence``
+ - ``feature_names``
+ - ``history_hours``
+ - ``categorical_static_raw``
+ - ``diagnosis_raw``
+
+Implementation Notes:
+ - The public training label is ``target_los_hours``.
+ - ``target_los_sequence`` is retained only as auxiliary metadata and is
+ not part of the formal output schema.
+ - The generated ``time_series`` representation uses [value, mask, decay]
+ channel order for each feature, which the TPC model unpacks internally.
+"""
+
+from __future__ import annotations
+
+import math
+from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional, Tuple
+
+from .base_task import BaseTask
+
+
+class HourlyLOSEICU(BaseTask):
+ """Hourly remaining ICU length-of-stay regression task.
+
+ This task builds hourly time-series samples for ICU remaining length-of-stay
+ prediction from either eICU-style or MIMIC-IV-style patient records.
+
+ For each ICU stay, the task:
+ 1. Extracts observations from up to ``pre_icu_hours`` before ICU
+ admission through ICU discharge.
+ 2. Buckets measurements into hourly bins.
+ 3. Keeps the most recent measurement within each hour.
+ 4. Forward-fills missing values.
+ 5. Computes decay features based on hours since the last true
+ observation.
+ 6. Drops pre-ICU rows after forward-fill so the final sequence is
+ ICU-only.
+ 7. Produces one supervised sample per prediction hour with a scalar
+ remaining-LoS label.
+
+ Declared sample fields:
+ Inputs:
+ - ``time_series``: shape [T, 3F] with per-feature
+ [value, mask, decay]
+ - ``static``: encoded static feature vector
+
+ Outputs:
+ - ``target_los_hours``: scalar remaining ICU LoS target
+
+ Notes:
+ - This task is intended to serve as the formal task/model contract for
+ the TPC ``BaseModel`` implementation.
+ - Categorical static features are one-hot encoded with task-local
+ vocabularies.
+ - Diagnosis strings may be retained as metadata when requested.
+ """
+
+ task_name: str = "HourlyLOSEICU"
+
+ def __init__(
+ self,
+ diagnosis_tables: Optional[List[str]] = None,
+ include_diagnoses: bool = False,
+ diagnosis_time_limit_hours: int = 5,
+ time_series_tables: Optional[List[str]] = None,
+ time_series_features: Optional[Dict[str, List[str]]] = None,
+ numeric_static_features: Optional[List[str]] = None,
+ categorical_static_features: Optional[List[str]] = None,
+ min_history_hours: int = 5,
+ max_hours: int = 48,
+ pre_icu_hours: int = 24,
+ ) -> None:
+ """Initialize the hourly ICU LoS task.
+
+ Args:
+ diagnosis_tables: Diagnosis-related tables to inspect for diagnosis
+ extraction, primarily for eICU.
+ include_diagnoses: Whether to extract and attach diagnosis strings
+ as metadata.
+ diagnosis_time_limit_hours: Maximum diagnosis time from ICU
+ admission to include, in hours.
+ time_series_tables: Tables containing time-series observations.
+ time_series_features: Mapping from table name to the list of
+ feature names to extract from that table.
+ numeric_static_features: Static numeric features to encode directly.
+ categorical_static_features: Static categorical features to one-hot
+ encode using task-local vocabularies.
+ min_history_hours: Minimum history length required before a sample
+ is emitted.
+ max_hours: Maximum ICU hours to keep from a stay.
+ pre_icu_hours: Number of hours before ICU admission to include
+ during extraction and forward-fill before cropping back to
+ ICU-only rows.
+ """
+ self.diagnosis_tables = diagnosis_tables or []
+ self.include_diagnoses = include_diagnoses
+ self.diagnosis_time_limit_hours = diagnosis_time_limit_hours
+ self.time_series_tables = time_series_tables or ["lab"]
+ self.time_series_features = time_series_features or {"lab": ["-basos"]}
+ self.numeric_static_features = numeric_static_features or []
+ self.categorical_static_features = categorical_static_features or []
+ self.min_history_hours = min_history_hours
+ self.max_hours = max_hours
+ self.pre_icu_hours = pre_icu_hours
+
+ self.static_vocab: Dict[str, Dict[Any, int]] = {
+ feature: {} for feature in self.categorical_static_features
+ }
+
+ # Formal task/model contract.
+ self.input_schema = {
+ "time_series": "tensor",
+ "static": "tensor",
+ }
+
+ self.output_schema = {
+ "target_los_hours": "regression",
+ }
+
+ def pre_filter(self, global_event_df):
+ """Optionally filter the global event dataframe before task processing.
+
+ Args:
+ global_event_df: The global event dataframe prepared by the dataset.
+
+ Returns:
+ The input dataframe unchanged.
+ """
+ return global_event_df
+
+ def _split_hierarchical_diagnosis(self, raw: Any) -> List[str]:
+ """Split a hierarchical diagnosis string into cumulative prefix tokens.
+
+ Example:
+ ``"a | b | c" -> ["a", "a | b", "a | b | c"]``
+
+ Args:
+ raw: Raw diagnosis string or other value.
+
+ Returns:
+ A list of cumulative hierarchical diagnosis prefixes.
+ """
+ if raw is None:
+ return []
+
+ s = str(raw).strip()
+ if not s:
+ return []
+
+ parts = [p.strip().lower() for p in s.split("|") if p.strip()]
+ if not parts:
+ return []
+
+ prefixes = []
+ current = []
+ for part in parts:
+ current.append(part)
+ prefixes.append(" | ".join(current))
+ return prefixes
+
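The cumulative-prefix behaviour documented above can be reproduced standalone; the sketch below mirrors the same logic outside the task class (hypothetical helper name, illustrative input):

```python
def split_hierarchical(raw):
    """Turn 'a | b | c' into ['a', 'a | b', 'a | b | c']."""
    parts = [p.strip().lower() for p in str(raw or "").split("|") if p.strip()]
    prefixes, current = [], []
    for part in parts:
        current.append(part)
        prefixes.append(" | ".join(current))
    return prefixes

assert split_hierarchical("Cardiac | Arrhythmia | AFib") == [
    "cardiac",
    "cardiac | arrhythmia",
    "cardiac | arrhythmia | afib",
]
assert split_hierarchical(None) == []
```

Emitting every prefix lets a downstream encoder share statistics across diagnoses that agree on the leading levels of the hierarchy.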
+ def _extract_eicu_diagnoses(self, patient) -> List[str]:
+ """Extract hierarchical diagnosis tokens for an eICU patient.
+
+ Only diagnoses within ``diagnosis_time_limit_hours`` are included.
+
+ Args:
+ patient: PyHealth patient object.
+
+ Returns:
+ A sorted list of unique diagnosis tokens.
+ """
+ if not self.include_diagnoses:
+ return []
+
+ diagnosis_tokens = set()
+
+ for table in self.diagnosis_tables:
+ try:
+ events = patient.get_events(table)
+ except Exception:
+ events = []
+
+ for event in events:
+ attr = event.attr_dict
+
+ offset_minutes = None
+ for offset_key in [
+ "pasthistoryoffset",
+ "admitdxenteredoffset",
+ "diagnosisoffset",
+ ]:
+ offset_minutes = self._safe_float(attr.get(offset_key))
+ if offset_minutes is not None:
+ break
+
+ if offset_minutes is None:
+ continue
+
+ offset_hours = offset_minutes / 60.0
+ if offset_hours > self.diagnosis_time_limit_hours:
+ continue
+
+ raw_candidates = [
+ attr.get("pasthistorypath"),
+ attr.get("admitdxpath"),
+ attr.get("diagnosisstring"),
+ attr.get("diagnosispath"),
+ ]
+
+ raw_value = None
+ for candidate in raw_candidates:
+ if candidate is not None and str(candidate).strip():
+ raw_value = candidate
+ break
+
+ if raw_value is None:
+ continue
+
+ for token in self._split_hierarchical_diagnosis(raw_value):
+ diagnosis_tokens.add(token)
+
+ return sorted(diagnosis_tokens)
+
+ def _safe_float(self, value: Any) -> Optional[float]:
+ """Convert a value to a finite float if possible.
+
+ Args:
+ value: Input value.
+
+ Returns:
+ A finite float, or ``None`` if conversion fails or the value is
+ NaN/inf.
+ """
+ try:
+ if value is None:
+ return None
+ x = float(value)
+ if math.isnan(x) or math.isinf(x):
+ return None
+ return x
+ except Exception:
+ return None
+
+ def _safe_datetime(self, value: Any) -> Optional[datetime]:
+ """Convert a value to a datetime if possible.
+
+ Supports ISO strings and several common datetime formats.
+
+ Args:
+ value: Input value.
+
+ Returns:
+ A datetime object, or ``None`` if parsing fails.
+ """
+ if value is None:
+ return None
+ if isinstance(value, datetime):
+ return value
+
+ s = str(value).strip()
+ if not s:
+ return None
+
+ s = s.replace("Z", "+00:00")
+ try:
+ return datetime.fromisoformat(s)
+ except Exception:
+ pass
+
+ fmts = [
+ "%Y-%m-%d %H:%M:%S",
+ "%Y-%m-%d %H:%M:%S.%f",
+ "%Y-%m-%d",
+ ]
+ for fmt in fmts:
+ try:
+ return datetime.strptime(s, fmt)
+ except Exception:
+ continue
+ return None
+
+ def _norm_name(self, x: Any) -> str:
+ """Normalize a feature name to a lowercase stripped string.
+
+ Args:
+ x: Raw feature name.
+
+ Returns:
+ Normalized feature name string.
+ """
+ if x is None:
+ return ""
+ return str(x).strip().lower()
+
+ def _encode_static(self, attr: Dict[str, Any]) -> List[float]:
+ """Encode static numeric and categorical features into a flat vector.
+
+ Numeric features are inserted directly after safe float conversion.
+ Categorical features are one-hot encoded with a task-local vocabulary.
+
+ Args:
+ attr: Source attribute dictionary.
+
+ Returns:
+ Encoded static feature vector.
+ """
+ numeric = []
+ categorical = []
+
+ for feature_name in self.numeric_static_features:
+ val = self._safe_float(attr.get(feature_name))
+ if val is None:
+ val = 0.0
+ numeric.append(float(val))
+
+ for feature_name in self.categorical_static_features:
+ raw = attr.get(feature_name)
+ category = "__MISSING__" if raw is None else str(raw)
+
+ vocab = self.static_vocab.setdefault(feature_name, {})
+ if category not in vocab:
+ vocab[category] = len(vocab)
+
+ one_hot = [0.0] * len(vocab)
+ one_hot[vocab[category]] = 1.0
+ categorical.extend(one_hot)
+
+ return numeric + categorical
+
+ def _build_feature_index(self) -> Tuple[List[str], Dict[str, int]]:
+ """Build the normalized feature list and index mapping.
+
+ Returns:
+ A tuple of:
+ - normalized feature names in deterministic order
+ - mapping from feature name to feature index
+
+ Raises:
+ ValueError: If no time-series features are defined.
+ """
+ names = []
+ for table in self.time_series_tables:
+ names.extend(self.time_series_features.get(table, []))
+
+ normalized_names = [self._norm_name(name) for name in names]
+
+ if len(normalized_names) == 0:
+ raise ValueError("No time-series features defined.")
+
+ if len(set(normalized_names)) != len(normalized_names):
+ seen = set()
+ deduped = []
+ for name in normalized_names:
+ if name not in seen:
+ seen.add(name)
+ deduped.append(name)
+ normalized_names = deduped
+
+ return normalized_names, {name: i for i, name in enumerate(normalized_names)}
+
+ def _combine_value_mask_decay(
+ self,
+ filled: List[List[float]],
+ mask: List[List[float]],
+ decay: List[List[float]],
+ ) -> List[List[float]]:
+ """Interleave filled values, masks, and decay as [value, mask, decay].
+
+ Args:
+ filled: Forward-filled values, shape [T, F].
+ mask: Observation mask, shape [T, F].
+ decay: Decay features, shape [T, F].
+
+ Returns:
+ Combined feature matrix of shape [T, 3F].
+ """
+ combined = []
+ for hour_idx in range(len(filled)):
+ row = []
+ for feat_idx in range(len(filled[hour_idx])):
+ row.append(filled[hour_idx][feat_idx])
+ row.append(mask[hour_idx][feat_idx])
+ row.append(decay[hour_idx][feat_idx])
+ combined.append(row)
+ return combined
+
+ def _build_feature_names(self, normalized_names: List[str]) -> List[str]:
+ """Build output feature names for value, mask, and decay channels.
+
+ Args:
+ normalized_names: Base time-series feature names.
+
+ Returns:
+ Expanded feature names in
+ ``[name_val, name_mask, name_decay]`` order.
+ """
+ feature_names = []
+ for name in normalized_names:
+ feature_names.extend(
+ [f"{name}_val", f"{name}_mask", f"{name}_decay"]
+ )
+ return feature_names
+
+ def _make_hourly_tensor(
+ self,
+ observations: List[Tuple[int, int, float, float]],
+ usable_hours: int,
+ num_features: int,
+ ) -> List[List[float]]:
+ """Build an hourly [value, mask, decay] tensor from observations.
+
+ Each observation tuple is:
+ ``(hour_index, feature_index, value, precise_offset_hours)``
+
+ For each hour and feature, the most recent measurement within that hour
+ is retained. Missing values are forward-filled. Mask indicates whether
+ the value was observed in that exact hour. Decay follows ``0.75 ** j``
+ where ``j`` is the number of hours since the last real observation.
+
+ Args:
+ observations: Bucketed observations.
+ usable_hours: Number of hours in the timeline being built.
+ num_features: Number of time-series features.
+
+ Returns:
+ Combined [T, 3F] tensor as a Python list of rows.
+ """
+ latest_vals = [[None] * num_features for _ in range(usable_hours)]
+ latest_time = [[-float("inf")] * num_features for _ in range(usable_hours)]
+
+ for hour_idx, feat_idx, value, precise_offset in observations:
+ if hour_idx < 0 or hour_idx >= usable_hours:
+ continue
+ if precise_offset >= latest_time[hour_idx][feat_idx]:
+ latest_time[hour_idx][feat_idx] = precise_offset
+ latest_vals[hour_idx][feat_idx] = value
+
+ filled = []
+ mask = []
+ decay = []
+
+ last_val = [0.0] * num_features
+ last_seen = [None] * num_features
+
+ for hour_idx in range(usable_hours):
+ filled_row = []
+ mask_row = []
+ decay_row = []
+
+ for feat_idx in range(num_features):
+ value = latest_vals[hour_idx][feat_idx]
+
+ if value is not None:
+ last_val[feat_idx] = value
+ last_seen[feat_idx] = hour_idx
+ filled_row.append(value)
+ mask_row.append(1.0)
+ decay_row.append(1.0)
+ else:
+ filled_row.append(last_val[feat_idx])
+ mask_row.append(0.0)
+
+ if last_seen[feat_idx] is None:
+ decay_row.append(0.0)
+ else:
+ gap = hour_idx - last_seen[feat_idx]
+ decay_row.append(float(0.75 ** gap))
+
+ filled.append(filled_row)
+ mask.append(mask_row)
+ decay.append(decay_row)
+
+ return self._combine_value_mask_decay(filled, mask, decay)
+
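The forward-fill and decay rule above (`0.75 ** j`, with `j` hours since the last real observation) can be checked on a toy timeline. A sketch assuming a single feature observed at hours 0 and 3:

```python
observed = {0: 5.0, 3: 7.0}  # hour -> measured value
filled, mask, decay = [], [], []
last_val, last_seen = 0.0, None

for hour in range(5):
    if hour in observed:
        last_val, last_seen = observed[hour], hour
        filled.append(last_val)
        mask.append(1.0)
        decay.append(1.0)  # 0.75 ** 0
    else:
        filled.append(last_val)  # carry the last value forward
        mask.append(0.0)
        decay.append(0.0 if last_seen is None else 0.75 ** (hour - last_seen))

print(filled)  # [5.0, 5.0, 5.0, 7.0, 7.0]
print(decay)   # [1.0, 0.75, 0.5625, 1.0, 0.75]
```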
+ def _make_cropped_hourly_tensor(
+ self,
+ observations: List[Tuple[int, int, float, float]],
+ total_hours: float,
+ num_features: int,
+ ) -> List[List[float]]:
+ """Build an extended timeline tensor and crop it to ICU-only rows.
+
+ The timeline begins ``pre_icu_hours`` before ICU admission, allowing
+ pre-ICU observations to participate in hourly bucketing and forward-fill.
+ After the extended tensor is built, the leading pre-ICU rows are removed.
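+
+    For example, with ``pre_icu_hours=2`` and ``total_hours=3.0`` the
+    extended tensor spans 5 hourly rows; the first two (pre-ICU) rows
+    are dropped and the remaining 3 ICU rows are returned.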
+
+ Args:
+ observations: Observations indexed on an extended timeline whose
+ zero point is ICU admission minus ``self.pre_icu_hours``.
+ total_hours: ICU stay length in hours.
+ num_features: Number of time-series features.
+
+ Returns:
+ Cropped ICU-only [T, 3F] tensor as a Python list of rows.
+ """
+ usable_hours = int(min(total_hours, self.max_hours))
+ extended_hours = self.pre_icu_hours + usable_hours
+
+ full_ts = self._make_hourly_tensor(
+ observations=observations,
+ usable_hours=extended_hours,
+ num_features=num_features,
+ )
+
+ return full_ts[self.pre_icu_hours : self.pre_icu_hours + usable_hours]
+
+ def _make_samples_for_stay(
+ self,
+ patient,
+ visit_id: Any,
+ total_hours: float,
+ static_attr: Dict[str, Any],
+ observations: List[Tuple[int, int, float, float]],
+ normalized_feature_names: List[str],
+ diagnosis_raw: Optional[List[str]] = None,
+ ) -> List[Dict[str, Any]]:
+ """Create hourly supervised samples for a single ICU stay.
+
+ Each emitted sample uses history through hour ``t`` and predicts the
+ scalar remaining length of stay at that hour.
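+
+    For example, with ``total_hours=4.0`` and ``min_history_hours=2``,
+    samples are emitted at hours 2, 3, and 4 with remaining-LoS targets
+    of 2.0, 1.0, and 0.0 hours respectively.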
+
+ Args:
+ patient: PyHealth patient object.
+ visit_id: Stay or visit identifier.
+ total_hours: Total ICU stay length in hours.
+ static_attr: Static attributes to encode.
+ observations: Extended-timeline observations.
+ normalized_feature_names: Base feature names.
+ diagnosis_raw: Optional diagnosis tokens.
+
+ Returns:
+ A list of task samples, one for each prediction hour from
+ ``min_history_hours`` through ``usable_hours``.
+ """
+ samples = []
+
+ if total_hours is None or total_hours < self.min_history_hours:
+ return samples
+
+ usable_hours = int(min(total_hours, self.max_hours))
+ if usable_hours < self.min_history_hours:
+ return samples
+
+ static_vec = self._encode_static(static_attr)
+ num_features = len(normalized_feature_names)
+
+ time_series = self._make_cropped_hourly_tensor(
+ observations=observations,
+ total_hours=total_hours,
+ num_features=num_features,
+ )
+ feature_names = self._build_feature_names(normalized_feature_names)
+
+ raw_cats = {}
+ for feature_name in self.categorical_static_features:
+ raw_cats[feature_name] = (
+ str(static_attr.get(feature_name))
+ if static_attr.get(feature_name) is not None
+ else "__MISSING__"
+ )
+
+ for history_hours in range(self.min_history_hours, usable_hours + 1):
+ remaining_hours = max(total_hours - history_hours, 0.0)
+
+ target_los_sequence = [
+ float(max(total_hours - hour_idx, 0.0))
+ for hour_idx in range(1, history_hours + 1)
+ ]
+
+ samples.append(
+ {
+ "patient_id": str(patient.patient_id),
+ "visit_id": str(visit_id) if visit_id is not None else "unknown",
+ "time_series": time_series[:history_hours],
+ "static": [float(x) for x in static_vec],
+ "target_los_hours": float(remaining_hours),
+ "target_los_sequence": [float(x) for x in target_los_sequence],
+ "history_hours": int(history_hours),
+ }
+ )
+
+ return samples
+
+ def _build_eicu_samples(self, patient) -> List[Dict[str, Any]]:
+ """Build hourly LoS samples for an eICU patient.
+
+ This method extracts time-series observations using eICU offset fields,
+ shifts them onto an extended timeline that includes pre-ICU hours,
+ and creates ICU-only samples after cropping.
+
+ Args:
+ patient: PyHealth patient object.
+
+ Returns:
+ List of generated task samples.
+ """
+ samples = []
+
+ try:
+ patient_events = patient.get_events("patient")
+ except Exception:
+ patient_events = []
+
+ if not patient_events:
+ return samples
+
+ anchor_attr = patient_events[0].attr_dict
+ total_minutes = self._safe_float(anchor_attr.get("unitdischargeoffset"))
+ if total_minutes is None:
+ return samples
+
+ total_hours = total_minutes / 60.0
+ if total_hours < self.min_history_hours:
+ return samples
+
+ usable_hours = int(min(total_hours, self.max_hours))
+ if usable_hours < self.min_history_hours:
+ return samples
+
+ normalized_names, feature_index = self._build_feature_index()
+ if not normalized_names:
+ return samples
+
+ extra_features = ["time in the icu", "time of day"]
+ for feature_name in extra_features:
+ if feature_name not in normalized_names:
+ normalized_names.append(feature_name)
+
+ feature_index = {
+ feature_name: idx for idx, feature_name in enumerate(normalized_names)
+ }
+
+ observations = []
+
+ for table in self.time_series_tables:
+ try:
+ events = patient.get_events(table)
+ except Exception:
+ events = []
+
+ schema = self._get_eicu_table_schema(table)
+ if schema is None:
+ continue
+
+ name_key, value_key, offset_key = schema
+
+ for event in events:
+ attr = event.attr_dict
+
+ if name_key is None and value_key is None:
+ minutes = self._safe_float(attr.get(offset_key))
+ if minutes is None:
+ continue
+
+ offset_hours = minutes / 60.0
+                    # Floor (rather than truncate toward zero) so fractional
+                    # pre-ICU offsets bucket into the correct earlier hour.
+                    extended_hour = int(offset_hours // 1) + self.pre_icu_hours
+
+ for raw_name, raw_val in attr.items():
+ norm_name = self._norm_name(raw_name)
+ if norm_name not in feature_index:
+ continue
+
+ value = self._safe_float(raw_val)
+ if value is None:
+ continue
+
+ feat_idx = feature_index[norm_name]
+ observations.append(
+ (extended_hour, feat_idx, value, offset_hours)
+ )
+
+ time_in_icu_idx = feature_index.get("time in the icu")
+ if time_in_icu_idx is not None:
+ observations.append(
+ (
+ extended_hour,
+ time_in_icu_idx,
+ offset_hours,
+ offset_hours,
+ )
+ )
+
+ time_of_day = None
+ admit_time_str = anchor_attr.get("hospitaladmittime24")
+ if admit_time_str:
+ try:
+ h, m, s = map(int, admit_time_str.split(":"))
+ base_hour = h + m / 60.0 + s / 3600.0
+ time_of_day = (base_hour + offset_hours) % 24
+ except Exception:
+ pass
+
+ time_of_day_idx = feature_index.get("time of day")
+ if time_of_day_idx is not None and time_of_day is not None:
+ observations.append(
+ (
+ extended_hour,
+ time_of_day_idx,
+ time_of_day,
+ offset_hours,
+ )
+ )
+
+ continue
+
+ name = self._norm_name(attr.get(name_key))
+ if name not in feature_index:
+ continue
+
+ value = self._safe_float(attr.get(value_key))
+ minutes = self._safe_float(attr.get(offset_key))
+ if value is None or minutes is None:
+ continue
+
+ offset_hours = minutes / 60.0
+                # Floor (rather than truncate toward zero) so fractional
+                # pre-ICU offsets bucket into the correct earlier hour.
+                extended_hour = int(offset_hours // 1) + self.pre_icu_hours
+ feat_idx = feature_index[name]
+ observations.append((extended_hour, feat_idx, value, offset_hours))
+
+ time_in_icu_idx = feature_index.get("time in the icu")
+ if time_in_icu_idx is not None:
+ observations.append(
+ (
+ extended_hour,
+ time_in_icu_idx,
+ offset_hours,
+ offset_hours,
+ )
+ )
+
+ time_of_day = None
+ admit_time_str = anchor_attr.get("hospitaladmittime24")
+ if admit_time_str:
+ try:
+ h, m, s = map(int, admit_time_str.split(":"))
+ base_hour = h + m / 60.0 + s / 3600.0
+ time_of_day = (base_hour + offset_hours) % 24
+ except Exception:
+ pass
+
+ time_of_day_idx = feature_index.get("time of day")
+ if time_of_day_idx is not None and time_of_day is not None:
+ observations.append(
+ (
+ extended_hour,
+ time_of_day_idx,
+ time_of_day,
+ offset_hours,
+ )
+ )
+
+ visit_id = (
+ anchor_attr.get("patientunitstayid")
+ or anchor_attr.get("visit_id")
+ or patient.patient_id
+ )
+
+ static_attr = dict(anchor_attr)
+ diagnosis_raw = self._extract_eicu_diagnoses(patient)
+
+ samples.extend(
+ self._make_samples_for_stay(
+ patient=patient,
+ visit_id=visit_id,
+ total_hours=total_hours,
+ static_attr=static_attr,
+ observations=observations,
+ normalized_feature_names=normalized_names,
+ diagnosis_raw=diagnosis_raw,
+ )
+ )
+
+ return samples
+
+ def _get_eicu_table_schema(
+ self,
+ table: str,
+ ) -> Optional[Tuple[Optional[str], Optional[str], str]]:
+ """Return the schema tuple for supported eICU time-series tables.
+
+ Args:
+ table: Table name.
+
+ Returns:
+ A tuple of ``(name_key, value_key, offset_key)``, or ``None`` if the
+ table is unsupported.
+ """
+ table = str(table).strip().lower()
+
+ schema = {
+ "lab": ("labname", "labresult", "labresultoffset"),
+ "respiratorycharting": (
+ "respchartvaluelabel",
+ "respchartvalue",
+ "respchartoffset",
+ ),
+ "nursecharting": (
+ "nursingchartcelltypevallabel",
+ "nursingchartvalue",
+ "nursingchartoffset",
+ ),
+ "vitalperiodic": (None, None, "observationoffset"),
+ "vitalaperiodic": (None, None, "observationoffset"),
+ }
+
+ return schema.get(table)
+
+ def _get_mimic_table_schema(
+ self,
+ table: str,
+ ) -> Optional[Tuple[str, Optional[str], str]]:
+ """Return the schema tuple for supported MIMIC-IV time-series tables.
+
+ Args:
+ table: Table name.
+
+ Returns:
+ A tuple of ``(name_key, value_key, time_key)``, or ``None`` if the
+ table is unsupported.
+ """
+ table = str(table).strip().lower()
+
+ schema = {
+ "labevents": ("label", None, "timestamp"),
+ "chartevents": ("label", None, "timestamp"),
+ }
+
+ return schema.get(table)
+
+ def _build_mimic_samples(self, patient) -> List[Dict[str, Any]]:
+ """Build hourly LoS samples for a MIMIC-IV patient.
+
+ This method extracts time-series observations from up to
+ ``pre_icu_hours`` before ICU admission through ICU discharge, shifts
+ them onto an extended timeline, and creates ICU-only samples after
+ cropping.
+
+ Args:
+ patient: PyHealth patient object.
+
+ Returns:
+ List of generated task samples.
+ """
+ samples = []
+
+ try:
+ patient_rows = patient.get_events("patients")
+ except Exception:
+ patient_rows = []
+
+ try:
+ admission_rows = patient.get_events("admissions")
+ except Exception:
+ admission_rows = []
+
+ try:
+ icu_rows = patient.get_events("icustays")
+ except Exception:
+ icu_rows = []
+
+ if not icu_rows:
+ return samples
+
+ patient_static = patient_rows[0].attr_dict if patient_rows else {}
+
+ admissions_by_hadm = {}
+ for admission in admission_rows:
+ hadm_id = admission.attr_dict.get("hadm_id")
+ if hadm_id is not None and hadm_id not in admissions_by_hadm:
+ admissions_by_hadm[hadm_id] = admission
+
+ normalized_names, feature_index = self._build_feature_index()
+ if not normalized_names:
+ return samples
+
+ extra_features = ["time in the icu", "time of day"]
+ for feature_name in extra_features:
+ if feature_name not in normalized_names:
+ normalized_names.append(feature_name)
+
+ feature_index = {
+ feature_name: idx for idx, feature_name in enumerate(normalized_names)
+ }
+
+ for icu_event in icu_rows:
+ icu_attr = dict(icu_event.attr_dict)
+ hadm_id = icu_attr.get("hadm_id")
+ stay_id = icu_attr.get("stay_id")
+
+ intime = getattr(icu_event, "timestamp", None)
+ outtime = self._safe_datetime(icu_attr.get("outtime"))
+
+ if intime is None or outtime is None:
+ continue
+
+ total_hours = (outtime - intime).total_seconds() / 3600.0
+ if total_hours < self.min_history_hours:
+ continue
+
+ static_attr = dict(patient_static)
+ if hadm_id in admissions_by_hadm:
+ static_attr.update(admissions_by_hadm[hadm_id].attr_dict)
+ static_attr.update(icu_attr)
+
+ observations = []
+ pre_icu_start = intime - timedelta(hours=self.pre_icu_hours)
+
+ for table in self.time_series_tables:
+ try:
+ events = patient.get_events(table)
+ except Exception:
+ events = []
+
+ schema = self._get_mimic_table_schema(table)
+ if schema is None:
+ continue
+
+ name_key, _, _ = schema
+
+ for event in events:
+ attr = event.attr_dict
+
+ event_hadm_id = attr.get("hadm_id")
+ if (
+ hadm_id is not None
+ and event_hadm_id is not None
+ and str(event_hadm_id) != str(hadm_id)
+ ):
+ continue
+
+ name = self._norm_name(attr.get(name_key))
+ if name not in feature_index:
+ continue
+
+ value = self._safe_float(attr.get("valuenum"))
+ if value is None:
+ value = self._safe_float(attr.get("value"))
+ if value is None:
+ continue
+
+ event_time = getattr(event, "timestamp", None)
+ if event_time is None:
+ event_time = self._safe_datetime(attr.get("storetime"))
+ if event_time is None:
+ event_time = self._safe_datetime(attr.get("charttime"))
+ if event_time is None:
+ continue
+
+ if event_time < pre_icu_start or event_time > outtime:
+ continue
+
+ offset_hours = (
+ event_time - intime
+ ).total_seconds() / 3600.0
+                    # Floor (rather than truncate toward zero) so events in
+                    # the pre-ICU window bucket into the correct earlier hour.
+                    extended_hour = int(offset_hours // 1) + self.pre_icu_hours
+ feat_idx = feature_index[name]
+ observations.append((extended_hour, feat_idx, value, offset_hours))
+
+ time_in_icu_idx = feature_index.get("time in the icu")
+ if time_in_icu_idx is not None:
+ observations.append(
+ (
+ extended_hour,
+ time_in_icu_idx,
+ offset_hours,
+ offset_hours,
+ )
+ )
+
+ time_of_day_idx = feature_index.get("time of day")
+ if time_of_day_idx is not None:
+ time_of_day = (
+ event_time.hour
+ + event_time.minute / 60.0
+ + event_time.second / 3600.0
+ )
+ observations.append(
+ (
+ extended_hour,
+ time_of_day_idx,
+ time_of_day,
+ offset_hours,
+ )
+ )
+
+ samples.extend(
+ self._make_samples_for_stay(
+ patient=patient,
+ visit_id=stay_id or hadm_id or patient.patient_id,
+ total_hours=total_hours,
+ static_attr=static_attr,
+ observations=observations,
+ normalized_feature_names=normalized_names,
+ )
+ )
+
+ return samples
+
+ def __call__(self, patient):
+ """Dispatch task construction based on patient contents.
+
+ Args:
+ patient: PyHealth patient object.
+
+ Returns:
+ A list of generated task samples for the patient.
+ """
+ try:
+ if patient.get_events("patient"):
+ return self._build_eicu_samples(patient)
+ except Exception:
+ pass
+
+ try:
+ if patient.get_events("icustays"):
+ return self._build_mimic_samples(patient)
+ except Exception:
+ pass
+
+ return []
\ No newline at end of file
diff --git a/tests/models/test_tpc.py b/tests/models/test_tpc.py
new file mode 100644
index 000000000..4155c8788
--- /dev/null
+++ b/tests/models/test_tpc.py
@@ -0,0 +1,336 @@
+"""
+Unit tests for the Temporal Pointwise Convolution (TPC) PyHealth model.
+
+This module contains fast synthetic tests for validating the dataset-backed
+``BaseModel`` implementation of TPC. The tests are designed to run quickly
+without requiring real eICU or MIMIC-IV data.
+
+Overview:
+ The test suite checks:
+
+ 1. Model instantiation under a dataset-backed BaseModel contract.
+ 2. Forward-pass correctness and required output keys.
+ 3. Output shapes for scalar regression.
+ 4. Forward-pass behavior across major ablation variants.
+ 5. Gradient propagation through the model during backpropagation.
+ 6. Failure behavior when required batch fields are missing.
+
+Implementation Notes:
+ - Uses only small synthetic tensors for speed and reproducibility.
+ - Uses a minimal mocked dataset/processor contract to exercise the model's
+ true BaseModel-facing path.
+ - Does not depend on real datasets.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, List
+
+import pytest
+import torch
+import torch.nn as nn
+
+import pyhealth.models.tpc as tpc_module
+from pyhealth.models.tpc import TPC
+
+
+@dataclass
+class DummyProcessor:
+ """Minimal processor stub exposing a schema with a ``value`` field."""
+
+ names: List[str]
+
+ def schema(self) -> List[str]:
+ """Return the processor schema."""
+ return self.names
+
+
+class DummyDataset:
+ """Minimal dataset stub matching the fields TPC expects from BaseModel.
+
+ This object is intentionally lightweight. It provides only the fields
+ needed by the rewritten TPC model and the patched BaseModel methods in
+ this test module.
+ """
+
+ def __init__(self) -> None:
+ """Initialize the dummy dataset."""
+ self.input_schema = {
+ "time_series": "timeseries",
+ "static": "tensor",
+ }
+ self.output_schema = {
+ "target_los_hours": "regression",
+ }
+ self.input_processors: Dict[str, DummyProcessor] = {
+ "time_series": DummyProcessor(["value"]),
+ "static": DummyProcessor(["value"]),
+ }
+ self.output_processors = {}
+
+
+@pytest.fixture()
+def patch_basemodel(monkeypatch):
+ """Patch BaseModel helpers so TPC can be tested with synthetic inputs only.
+
+ The project rubric requires fast synthetic model tests, and the current
+ test environment does not rely on the full trainer/runtime stack.
+ This fixture patches the minimum BaseModel behaviors the TPC model uses:
+ - dataset-backed initialization
+ - output size resolution
+ - regression loss resolution
+ - prediction preparation
+ """
+
+ def fake_init(self, dataset):
+ nn.Module.__init__(self)
+ self.dataset = dataset
+ self.feature_keys = list(dataset.input_schema.keys())
+ self.label_keys = list(dataset.output_schema.keys())
+ self.__dict__["device"] = torch.device("cpu")
+
+        class DummyOutputProcessor:
+            def size(self):
+                return 1
+
+        self.dataset.output_processors = {
+            self.label_keys[0]: DummyOutputProcessor()
+        }
+
+ monkeypatch.setattr(tpc_module.BaseModel, "__init__", fake_init)
+ monkeypatch.setattr(tpc_module.BaseModel, "get_output_size", lambda self: 1)
+ monkeypatch.setattr(
+ tpc_module.BaseModel,
+ "get_loss_function",
+ lambda self: nn.MSELoss(),
+ )
+ monkeypatch.setattr(
+ tpc_module.BaseModel,
+ "prepare_y_prob",
+ lambda self, logit: logit,
+ )
+ monkeypatch.setattr(
+ tpc_module.BaseModel,
+ "device",
+ property(lambda self: torch.device("cpu")),
+ )
+
+
+def make_batch(
+ batch_size: int = 2,
+ seq_len: int = 6,
+ num_features: int = 5,
+ static_dim: int = 3,
+) -> Dict[str, torch.Tensor | tuple[torch.Tensor]]:
+ """Create a small synthetic batch for BaseModel-style testing.
+
+ The ``time_series`` tensor uses the task/model contract shape [B, T, 3F]
+ with interleaved [value, mask, decay] channels per feature.
+
+ Args:
+ batch_size: Batch size.
+ seq_len: Sequence length.
+ num_features: Number of base time-series features.
+ static_dim: Static feature dimension.
+
+ Returns:
+ A batch dictionary compatible with ``TPC.forward(**kwargs)``.
+ """
+ values = torch.randn(batch_size, seq_len, num_features)
+ masks = torch.randint(0, 2, (batch_size, seq_len, num_features)).float()
+ decay = torch.rand(batch_size, seq_len, num_features)
+
+ pieces = []
+ for feature_idx in range(num_features):
+ pieces.append(values[:, :, feature_idx].unsqueeze(-1))
+ pieces.append(masks[:, :, feature_idx].unsqueeze(-1))
+ pieces.append(decay[:, :, feature_idx].unsqueeze(-1))
+ time_series = torch.cat(pieces, dim=-1) # [B, T, 3F]
+
+ static = torch.randn(batch_size, static_dim)
+ target = torch.rand(batch_size, 1) * 10.0
+
+ return {
+ "time_series": (time_series,),
+ "static": (static,),
+ "target_los_hours": target,
+ }
+
+
+def test_tpc_init_requires_expected_feature_keys(patch_basemodel) -> None:
+ """Test that TPC instantiates with the expected dataset-backed contract."""
+ dataset = DummyDataset()
+
+ model = TPC(
+ dataset=dataset,
+ input_dim=5,
+ static_dim=3,
+ temporal_channels=4,
+ pointwise_channels=4,
+ num_layers=2,
+ kernel_size=3,
+ fc_dim=8,
+ )
+
+ assert model.label_key == "target_los_hours"
+ assert set(model.feature_keys) == {"time_series", "static"}
+
+
+def test_tpc_forward_returns_required_basemodel_keys(patch_basemodel) -> None:
+ """Test that forward returns the required BaseModel output dictionary."""
+ dataset = DummyDataset()
+ batch = make_batch()
+
+ model = TPC(
+ dataset=dataset,
+ input_dim=5,
+ static_dim=3,
+ temporal_channels=4,
+ pointwise_channels=4,
+ num_layers=2,
+ kernel_size=3,
+ fc_dim=8,
+ )
+
+ outputs = model(**batch)
+
+ assert set(outputs.keys()) == {"loss", "y_prob", "y_true", "logit"}
+ assert torch.is_tensor(outputs["loss"])
+ assert outputs["y_prob"].shape == (2, 1)
+ assert outputs["y_true"].shape == (2, 1)
+ assert outputs["logit"].shape == (2, 1)
+ assert torch.isfinite(outputs["logit"]).all()
+
+
+def test_tpc_forward_accepts_tensor_inputs_without_processor_tuple(
+ patch_basemodel,
+) -> None:
+ """Test that forward accepts raw tensors as well as processor tuples."""
+ dataset = DummyDataset()
+ batch = make_batch()
+
+ raw_batch = {
+ "time_series": batch["time_series"][0],
+ "static": batch["static"][0],
+ "target_los_hours": batch["target_los_hours"],
+ }
+
+ model = TPC(
+ dataset=dataset,
+ input_dim=5,
+ static_dim=3,
+ temporal_channels=4,
+ pointwise_channels=4,
+ num_layers=2,
+ kernel_size=3,
+ fc_dim=8,
+ )
+
+ outputs = model(**raw_batch)
+ assert outputs["logit"].shape == (2, 1)
+ assert torch.isfinite(outputs["logit"]).all()
+
+
+@pytest.mark.parametrize(
+ "model_kwargs",
+ [
+ {"use_temporal": True, "use_pointwise": True},
+ {"use_temporal": True, "use_pointwise": False},
+ {"use_temporal": False, "use_pointwise": True},
+ {"use_temporal": True, "use_pointwise": True, "shared_temporal": True},
+ {"use_temporal": True, "use_pointwise": True, "use_skip_connections": False},
+ ],
+)
+def test_tpc_ablation_variants_forward(
+ patch_basemodel,
+ model_kwargs: dict,
+) -> None:
+ """Test forward pass across major ablation variants."""
+ dataset = DummyDataset()
+ batch = make_batch()
+
+ model = TPC(
+ dataset=dataset,
+ input_dim=5,
+ static_dim=3,
+ temporal_channels=4,
+ pointwise_channels=4,
+ num_layers=2,
+ kernel_size=3,
+ fc_dim=8,
+ **model_kwargs,
+ )
+
+ outputs = model(**batch)
+ assert outputs["logit"].shape == (2, 1)
+ assert torch.isfinite(outputs["logit"]).all()
+ assert torch.isfinite(outputs["loss"]).all()
+
+
+def test_tpc_requires_at_least_one_branch(patch_basemodel) -> None:
+ """Test that TPC rejects configuration with both branches disabled."""
+ dataset = DummyDataset()
+
+ with pytest.raises(ValueError):
+ TPC(
+ dataset=dataset,
+ input_dim=5,
+ static_dim=3,
+ temporal_channels=4,
+ pointwise_channels=4,
+ num_layers=2,
+ kernel_size=3,
+ fc_dim=8,
+ use_temporal=False,
+ use_pointwise=False,
+ )
+
+
+def test_tpc_missing_label_raises(patch_basemodel) -> None:
+ """Test that missing label data raises a clear error."""
+ dataset = DummyDataset()
+ batch = make_batch()
+ batch.pop("target_los_hours")
+
+ model = TPC(
+ dataset=dataset,
+ input_dim=5,
+ static_dim=3,
+ temporal_channels=4,
+ pointwise_channels=4,
+ num_layers=2,
+ kernel_size=3,
+ fc_dim=8,
+ )
+
+ with pytest.raises(ValueError, match="Missing required label field"):
+ model(**batch)
+
+
+def test_tpc_backward_pass(patch_basemodel) -> None:
+ """Test that gradients propagate through the dataset-backed model path."""
+ dataset = DummyDataset()
+ batch = make_batch()
+
+ model = TPC(
+ dataset=dataset,
+ input_dim=5,
+ static_dim=3,
+ temporal_channels=4,
+ pointwise_channels=4,
+ num_layers=2,
+ kernel_size=3,
+ fc_dim=8,
+ )
+
+ outputs = model(**batch)
+ loss = outputs["loss"]
+ loss.backward()
+
+ has_grad = any(
+ parameter.grad is not None
+ for parameter in model.parameters()
+ if parameter.requires_grad
+ )
+ assert has_grad
\ No newline at end of file
diff --git a/tests/tasks/test_hourly_los.py b/tests/tasks/test_hourly_los.py
new file mode 100644
index 000000000..475ba428d
--- /dev/null
+++ b/tests/tasks/test_hourly_los.py
@@ -0,0 +1,289 @@
+"""
+Unit tests for the Hourly ICU length-of-stay (LoS) task.
+
+This module contains synthetic unit tests for validating the behavior of the
+``HourlyLOSEICU`` task implementation. The tests verify correct construction
+of hourly time-series features, target generation, and dataset-specific
+handling for both eICU- and MIMIC-IV-style inputs.
+
+Overview:
+ The test suite checks:
+
+ 1. Formal task schema:
+       - ``time_series`` is declared as a tensor input
+ - ``static`` is declared as a tensor input
+ - ``target_los_hours`` is declared as a regression output
+
+ 2. Hourly time-series construction:
+ - Latest observation within each hour is retained
+ - Forward-filling of missing values
+ - Correct decay feature computation (0.75 ** j)
+
+ 3. Pre-ICU handling:
+ - Inclusion of pre-ICU observations during processing
+ - Proper cropping of pre-ICU rows after feature construction
+
+ 4. Sample generation:
+ - eICU-style patient processing (offset-based timestamps)
+ - MIMIC-IV-style patient processing (datetime-based timestamps)
+ - Presence and correctness of expected output fields
+
+Implementation Notes:
+ - Tests use lightweight synthetic data for speed and reproducibility.
+ - No dependency on real eICU or MIMIC-IV datasets.
+ - Designed to validate core preprocessing logic independent of model code.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+
+from pyhealth.tasks.hourly_los import HourlyLOSEICU
+
+
+class DummyEvent:
+ """Minimal event object used for task unit tests."""
+
+ def __init__(self, attr_dict, timestamp=None):
+ """Initialize the dummy event.
+
+ Args:
+ attr_dict: Event attribute dictionary.
+ timestamp: Optional event timestamp.
+ """
+ self.attr_dict = attr_dict
+ self.timestamp = timestamp
+
+
+class DummyPatient:
+ """Minimal patient object used for task unit tests."""
+
+ def __init__(self, patient_id, tables):
+ """Initialize the dummy patient.
+
+ Args:
+ patient_id: Patient identifier.
+ tables: Mapping from table name to event list.
+ """
+ self.patient_id = patient_id
+ self.tables = tables
+
+ def get_events(self, table):
+ """Return events for the requested table.
+
+ Args:
+ table: Table name.
+
+ Returns:
+ List of events for that table.
+ """
+ return self.tables.get(table, [])
+
+
+def test_hourly_los_declares_expected_schema() -> None:
+ """Test that the task declares the expected BaseModel-facing schema."""
+ task = HourlyLOSEICU(
+ time_series_tables=["lab"],
+ time_series_features={"lab": ["creatinine"]},
+ )
+
+ assert task.input_schema == {
+ "time_series": "tensor",
+ "static": "tensor",
+ }
+ assert task.output_schema == {
+ "target_los_hours": "regression",
+ }
+
+
+def test_make_hourly_tensor_keeps_latest_and_forward_fills() -> None:
+ """Test that the latest within-hour measurement is kept and gaps are filled."""
+ task = HourlyLOSEICU(
+ time_series_tables=["lab"],
+ time_series_features={"lab": ["creatinine"]},
+ min_history_hours=1,
+ max_hours=6,
+ )
+
+ observations = [
+ (0, 0, 1.0, 0.1),
+ (0, 0, 2.0, 0.9), # later in same hour, should win
+ (2, 0, 3.0, 2.2),
+ ]
+
+ time_series = task._make_hourly_tensor(
+ observations=observations,
+ usable_hours=4,
+ num_features=1,
+ )
+
+ assert time_series[0][0] == 2.0
+ assert time_series[0][1] == 1.0
+ assert time_series[1][0] == 2.0
+ assert time_series[1][1] == 0.0
+ assert time_series[2][0] == 3.0
+ assert time_series[2][1] == 1.0
+
+
+def test_make_hourly_tensor_decay_behavior() -> None:
+ """Test that decay follows the expected ``0.75 ** j`` rule."""
+ task = HourlyLOSEICU(
+ time_series_tables=["lab"],
+ time_series_features={"lab": ["creatinine"]},
+ min_history_hours=1,
+ max_hours=6,
+ )
+
+ observations = [
+ (0, 0, 5.0, 0.2),
+ ]
+
+ time_series = task._make_hourly_tensor(
+ observations=observations,
+ usable_hours=4,
+ num_features=1,
+ )
+
+ assert time_series[0][2] == 1.0
+ assert abs(time_series[1][2] - 0.75) < 1e-6
+ assert abs(time_series[2][2] - (0.75 ** 2)) < 1e-6
+ assert abs(time_series[3][2] - (0.75 ** 3)) < 1e-6
+
+
+def test_cropped_hourly_tensor_removes_pre_icu_rows() -> None:
+ """Test that pre-ICU rows are removed after extended-timeline fill."""
+ task = HourlyLOSEICU(
+ time_series_tables=["lab"],
+ time_series_features={"lab": ["creatinine"]},
+ min_history_hours=1,
+ max_hours=6,
+ pre_icu_hours=2,
+ )
+
+ observations = [
+ (0, 0, 10.0, -2.0),
+ (1, 0, 11.0, -1.0),
+ (2, 0, 12.0, 0.0),
+ ]
+
+ time_series = task._make_cropped_hourly_tensor(
+ observations=observations,
+ total_hours=3.0,
+ num_features=1,
+ )
+
+ assert len(time_series) == 3
+ assert time_series[0][0] == 12.0
+
+
+def test_eicu_patient_generates_samples() -> None:
+ """Test eICU-style patient sample generation."""
+ task = HourlyLOSEICU(
+ time_series_tables=["lab"],
+ time_series_features={"lab": ["creatinine"]},
+ numeric_static_features=["age"],
+ categorical_static_features=["gender"],
+ min_history_hours=2,
+ max_hours=6,
+ )
+
+ patient_event = DummyEvent(
+ {
+ "patientunitstayid": "stay1",
+ "unitdischargeoffset": 240.0,
+ "age": 65,
+ "gender": "Male",
+ "hospitaladmittime24": "08:00:00",
+ }
+ )
+
+ lab_events = [
+ DummyEvent(
+ {
+ "labname": "creatinine",
+ "labresult": 1.2,
+ "labresultoffset": 0.0,
+ }
+ ),
+ DummyEvent(
+ {
+ "labname": "creatinine",
+ "labresult": 1.4,
+ "labresultoffset": 120.0,
+ }
+ ),
+ ]
+
+ patient = DummyPatient(
+ patient_id="p1",
+ tables={
+ "patient": [patient_event],
+ "lab": lab_events,
+ },
+ )
+
+ samples = task(patient)
+
+ assert len(samples) > 0
+ sample = samples[0]
+
+ assert "time_series" in sample
+ assert "static" in sample
+ assert "target_los_hours" in sample
+ assert "target_los_sequence" in sample
+ assert isinstance(sample["time_series"], list)
+ assert isinstance(sample["static"], list)
+ assert isinstance(sample["target_los_hours"], float)
+ assert isinstance(sample["target_los_sequence"], list)
+ assert len(sample["target_los_sequence"]) == sample["history_hours"]
+
+
+def test_mimic_patient_generates_samples() -> None:
+ """Test MIMIC-IV-style patient sample generation."""
+ task = HourlyLOSEICU(
+ time_series_tables=["labevents"],
+ time_series_features={"labevents": ["creatinine"]},
+ min_history_hours=2,
+ max_hours=6,
+ pre_icu_hours=2,
+ )
+
+ intime = datetime(2020, 1, 1, 10, 0, 0)
+ outtime = intime + timedelta(hours=4)
+
+ icu_event = DummyEvent(
+ {
+ "hadm_id": "hadm1",
+ "stay_id": "stay1",
+ "outtime": outtime.isoformat(),
+ },
+ timestamp=intime,
+ )
+
+ lab_event = DummyEvent(
+ {
+ "hadm_id": "hadm1",
+ "label": "creatinine",
+ "valuenum": 1.5,
+ },
+ timestamp=intime + timedelta(hours=1),
+ )
+
+ patient = DummyPatient(
+ patient_id="p2",
+ tables={
+ "patients": [DummyEvent({})],
+ "admissions": [DummyEvent({"hadm_id": "hadm1"})],
+ "icustays": [icu_event],
+ "labevents": [lab_event],
+ },
+ )
+
+ samples = task(patient)
+
+ assert len(samples) > 0
+ sample = samples[0]
+ assert "time_series" in sample
+ assert "static" in sample
+ assert "target_los_hours" in sample
+ assert isinstance(sample["target_los_hours"], float)
\ No newline at end of file