Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ API Reference
models/pyhealth.models.SafeDrug
models/pyhealth.models.MoleRec
models/pyhealth.models.Deepr
models/pyhealth.models.DuETT
models/pyhealth.models.EHRMamba
models/pyhealth.models.JambaEHR
models/pyhealth.models.ContraWR
Expand Down
14 changes: 14 additions & 0 deletions docs/api/models/pyhealth.models.DuETT.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
pyhealth.models.DuETT
===================================

The separate callable DuETTLayer and the complete DuETT model.

.. autoclass:: pyhealth.models.DuETTLayer
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pyhealth.models.DuETT
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ Available Tasks
COVID-19 CXR Classification <tasks/pyhealth.tasks.COVID19CXRClassification>
DKA Prediction (MIMIC-IV) <tasks/pyhealth.tasks.dka>
Drug Recommendation <tasks/pyhealth.tasks.drug_recommendation>
ICU Mortality DuETT (MIMIC-IV) <tasks/pyhealth.tasks.ICUMortalityDuETTMIMIC4>
Length of Stay Prediction <tasks/pyhealth.tasks.length_of_stay_prediction>
Medical Transcriptions Classification <tasks/pyhealth.tasks.MedicalTranscriptionsClassification>
Mortality Prediction (Next Visit) <tasks/pyhealth.tasks.mortality_prediction>
Expand Down
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.ICUMortalityDuETTMIMIC4.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.ICUMortalityDuETTMIMIC4
=======================================

.. autoclass:: pyhealth.tasks.icu_mortality_duett_mimic4.ICUMortalityDuETTMIMIC4
:members:
:undoc-members:
:show-inheritance:
255 changes: 255 additions & 0 deletions examples/mortality_prediction/mortality_mimic4_duett.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
"""DuETT for ICU Mortality Prediction on MIMIC-IV with Ablation Study.

This example demonstrates:
1. Loading MIMIC-IV data with the DuETT-specific mortality task
2. Training the DuETT model for ICU mortality prediction
3. Running an ablation study over model hyperparameters
4. Comparing ROC-AUC and PR-AUC across configurations

Paper: Labach et al. 2023. DuETT: Dual Event Time Transformer for
Electronic Health Records. ML4H 2023, PMLR 219:295-315.
https://proceedings.mlr.press/v219/labach23a.html

Ablation Study Design:
- Primary: Vary d_embedding (64, 128, 256), dropout (0.1, 0.3, 0.5),
and layer depth (1x1, 2x2) to compare ROC-AUC and PR-AUC.
- Secondary: Vary n_time_bins (12, 24, 48) for preprocessing.

Reported Results (MIMIC-IV 3.1, dev=True ~1000-patient subset, 816
patient samples; 20 epochs, CPU, lr=1e-4, batch=64):

Configuration Params ROC-AUC PR-AUC
--------------------------------------------------
Small (d=64) 111,425 0.5974 0.0312
Medium (d=128) 435,841 0.9221 0.1429 <-- best
Large (d=256) 1,723,649 0.8052 0.0625
Low dropout (0.1) 435,841 0.8312 0.0714
High dropout (0.5) 435,841 0.8831 0.1000
Deeper (2x2 layers) 832,385 0.7922 0.0588

Findings:
- Capacity sweet spot is d=128; d=64 underfits and d=256 overfits
with only ~650 training samples.
- Default dropout 0.3 outperforms both 0.1 (too weak) and 0.5 (too
aggressive) on this subset.
- Depth 1x1 is sufficient; 2x2 overfits at this data scale.
- Absolute numbers are lower than the paper's full-MIMIC-IV results
because this uses the dev=True ~1000-patient subset; the pattern
across configurations matches paper-reported capacity behavior.

Usage:
# Set the path to your MIMIC-IV root, then run:
export MIMIC_ROOT=/path/to/mimic-iv/3.1
python mortality_mimic4_duett.py

# Full-data mode (disable the dev=True subset of ~1000 patients):
export MIMIC_DEV=0
python mortality_mimic4_duett.py

Notes:
- Requires MIMIC-IV access (PhysioNet credentialing).
- Use MIMIC-IV demo for testing: physionet.org/content/mimic-iv-demo/2.2/
- Designed for GPU (RTX 4060 Ti / Colab T4). CPU works but is slower.
- By default runs in dev mode (1000-patient subset) for tractability.
"""

import os

import torch
from pyhealth.datasets import (
MIMIC4Dataset,
get_dataloader,
split_by_patient,
)
from pyhealth.models import DuETT
from pyhealth.tasks import ICUMortalityDuETTMIMIC4
from pyhealth.trainer import Trainer


def run_experiment(
sample_dataset,
train_dataset,
val_dataset,
test_dataset,
config,
device="cuda" if torch.cuda.is_available() else "cpu",
):
"""Train and evaluate a single DuETT configuration.

Args:
sample_dataset: Full SampleDataset (for model init).
train_dataset: Training split.
val_dataset: Validation split.
test_dataset: Test split.
config: Dict with d_embedding, n_event_layers, n_time_layers,
n_heads, dropout, fusion_method.
device: Compute device.

Returns:
Dict with config name and test metrics.
"""
train_loader = get_dataloader(
train_dataset, batch_size=64, shuffle=True
)
val_loader = get_dataloader(
val_dataset, batch_size=64, shuffle=False
)
test_loader = get_dataloader(
test_dataset, batch_size=64, shuffle=False
)

model = DuETT(
dataset=sample_dataset,
d_embedding=config["d_embedding"],
n_event_layers=config["n_event_layers"],
n_time_layers=config["n_time_layers"],
n_heads=config["n_heads"],
dropout=config["dropout"],
fusion_method=config.get("fusion_method", "rep_token"),
)

num_params = sum(p.numel() for p in model.parameters())
print(f"\n Config: {config['name']}")
print(f" Parameters: {num_params:,}")

trainer = Trainer(
model=model,
device=device,
metrics=["pr_auc", "roc_auc"],
)

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=20,
monitor="roc_auc",
optimizer_params={"lr": 1e-4},
)

results = trainer.evaluate(test_loader)
results["config"] = config["name"]
return results


if __name__ == "__main__":
# ---- Configuration ----
# Set the MIMIC_ROOT environment variable to your MIMIC-IV directory,
# or update the default path below.
MIMIC_ROOT = os.environ.get("MIMIC_ROOT", "/path/to/mimic-iv/2.2")
DEV_MODE = os.environ.get("MIMIC_DEV", "1") == "1"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_TIME_BINS = 24
INPUT_WINDOW_HOURS = 48

print("=" * 60)
print("DuETT ICU Mortality Prediction - Ablation Study")
print("=" * 60)

# ---- Step 1: Load MIMIC-IV ----
print(f"\n[1/4] Loading MIMIC-IV dataset (dev_mode={DEV_MODE})...")
base_dataset = MIMIC4Dataset(
ehr_root=MIMIC_ROOT,
ehr_tables=["patients", "admissions", "labevents"],
dev=DEV_MODE,
)

# ---- Step 2: Apply DuETT mortality task ----
print("\n[2/4] Applying DuETT mortality prediction task...")
task = ICUMortalityDuETTMIMIC4(
n_time_bins=N_TIME_BINS,
input_window_hours=INPUT_WINDOW_HOURS,
)
sample_dataset = base_dataset.set_task(task)
print(f" Total samples: {len(sample_dataset)}")

# ---- Step 3: Split dataset ----
print("\n[3/4] Splitting dataset (80/10/10)...")
train_ds, val_ds, test_ds = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
)
print(f" Train: {len(train_ds)}, Val: {len(val_ds)}, "
f"Test: {len(test_ds)}")

# ---- Step 4: Ablation Study ----
print("\n[4/4] Running ablation study...")

# Define ablation configurations
configs = [
{
"name": "Small (d=64)",
"d_embedding": 64,
"n_event_layers": 1,
"n_time_layers": 1,
"n_heads": 4,
"dropout": 0.3,
},
{
"name": "Medium (d=128)",
"d_embedding": 128,
"n_event_layers": 1,
"n_time_layers": 1,
"n_heads": 4,
"dropout": 0.3,
},
{
"name": "Large (d=256)",
"d_embedding": 256,
"n_event_layers": 1,
"n_time_layers": 1,
"n_heads": 4,
"dropout": 0.3,
},
{
"name": "Low dropout (0.1)",
"d_embedding": 128,
"n_event_layers": 1,
"n_time_layers": 1,
"n_heads": 4,
"dropout": 0.1,
},
{
"name": "High dropout (0.5)",
"d_embedding": 128,
"n_event_layers": 1,
"n_time_layers": 1,
"n_heads": 4,
"dropout": 0.5,
},
{
"name": "Deeper (2x2 layers)",
"d_embedding": 128,
"n_event_layers": 2,
"n_time_layers": 2,
"n_heads": 4,
"dropout": 0.3,
},
]

all_results = []
for config in configs:
try:
result = run_experiment(
sample_dataset, train_ds, val_ds, test_ds,
config, device=DEVICE,
)
all_results.append(result)
print(f" ROC-AUC: {result.get('roc_auc', 'N/A'):.4f}, "
f"PR-AUC: {result.get('pr_auc', 'N/A'):.4f}")
except Exception as e:
print(f" FAILED: {e}")
all_results.append({"config": config["name"], "error": str(e)})

# ---- Print Results Table ----
print("\n" + "=" * 60)
print("ABLATION RESULTS")
print("=" * 60)
print(f"{'Configuration':<25} {'ROC-AUC':>10} {'PR-AUC':>10}")
print("-" * 45)
for r in all_results:
if "error" in r:
print(f"{r['config']:<25} {'ERROR':>10} {'ERROR':>10}")
else:
roc = r.get("roc_auc", 0.0)
pr = r.get("pr_auc", 0.0)
print(f"{r['config']:<25} {roc:>10.4f} {pr:>10.4f}")
print("=" * 60)
1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .concare import ConCare, ConCareLayer
from .contrawr import ContraWR, ResBlock2D
from .deepr import Deepr, DeeprLayer
from .duett import DuETT, DuETTLayer
from .embedding import EmbeddingModel
from .gamenet import GAMENet, GAMENetLayer
from .jamba_ehr import JambaEHR, JambaLayer
Expand Down
Loading