
feat: add DuETT dual event-time transformer + ICU mortality task#1013

Open
shubhamx64 wants to merge 1 commit into sunlabuiuc:master from shubhamx64:feature/duett-model

@shubhamx64

DuETT: Dual Event Time Transformer — Full Pipeline (Task + Model) for MIMIC-IV ICU Mortality

Contributor

  • Name: Shubham Srivastava
  • Email / NetID: ss253@illinois.edu
  • Course: CS 598 DL4H, Spring 2026 (UIUC)

Type of Contribution

Option 4: Full Pipeline (new task + new model, reusing the existing MIMIC4Dataset).

Paper

Labach, A.; Pokhrel, A.; Huang, X. S.; Zuberi, S.; Yi, S. E.; Volkovs, M.; Poutanen, T.; and Krishnan, R. G. 2023. DuETT: Dual Event Time Transformer for Electronic Health Records. In Proceedings of the 4th Machine Learning for Health Symposium, PMLR 219:295–315.

This PR implements DuETT natively in PyHealth.
No code is copied or wrapped from the released implementation.

What DuETT Does (High-Level)

Electronic health records are sparse, irregularly sampled multivariate time series: lab tests, vitals, and interventions each have their own cadence, and the absence of a measurement is itself informative. Standard sequence models flatten these observations into a single token stream and lose the (variable × time) structure.

DuETT instead treats the EHR as an explicit event-type × time matrix:

  1. Binning: Irregular observations are aggregated into uniform time windows. Each (variable, bin) cell retains both the mean value and an observation count, preserving the distinction between true zeros and missing measurements.
  2. Static fusion: Patient-level features (age, sex) are fused into the representation.
  3. Dual-axis attention: Alternating Transformer encoder layers attend across the event axis (cross-variable relationships at each timestep) and the time axis (temporal dynamics per variable).
  4. Pooling & classification: Three fusion options (rep_token, averaging, masked_embed) produce a patient-level embedding for downstream prediction.
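
The binning step (item 1) can be sketched in plain Python as follows. This is a minimal illustration, not the task's actual code: the function name `bin_events`, the `(variable index, time in hours, value)` tuple layout, the `[bin][variable]` orientation, and the use of 0.0 as the fill value for empty cells are all assumptions made for the example.

```python
from typing import List, Tuple


def bin_events(
    events: List[Tuple[int, float, float]],
    n_vars: int,
    n_bins: int,
    t_max: float,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Aggregate (variable, time, value) events into uniform time bins.

    Returns (values, counts), each indexed as [bin][variable].
    values holds the per-cell mean (0.0 where nothing was observed);
    counts holds the number of observations in that cell.
    """
    sums = [[0.0] * n_vars for _ in range(n_bins)]
    counts = [[0] * n_vars for _ in range(n_bins)]
    width = t_max / n_bins
    for var, t, value in events:
        if not 0.0 <= t <= t_max:
            continue  # event lies outside the observation window
        b = min(int(t / width), n_bins - 1)  # t == t_max goes in the last bin
        sums[b][var] += value
        counts[b][var] += 1
    values = [
        [s / c if c else 0.0 for s, c in zip(sum_row, count_row)]
        for sum_row, count_row in zip(sums, counts)
    ]
    return values, counts


# Hypothetical events: variable 0 measured twice early, variable 1 once with value 0.
events = [(0, 1.0, 4.0), (0, 5.0, 6.0), (1, 13.0, 0.0)]
values, counts = bin_events(events, n_vars=2, n_bins=4, t_max=24.0)
```

Here `values[2][1]` and `values[1][0]` are both 0.0, but `counts[2][1]` is 1 while `counts[1][0]` is 0, so the count channel is what separates a measured zero from a missing measurement.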

Changes Overview

New files (7)

| # | File | Purpose |
|---|------|---------|
| 1 | pyhealth/tasks/icu_mortality_duett_mimic4.py | Task: bins MIMIC-IV lab events into event-by-time tensors, extracts static features, produces ICU mortality label |
| 2 | pyhealth/models/duett.py | Model: DuETTLayer (dual-axis attention core) + DuETT (BaseModel wrapper with classification head) |
| 3 | tests/core/test_duett.py | 11 synthetic-data model tests: instantiation, forward pass, output shapes, gradient flow, multiclass, 3 fusion modes |
| 4 | tests/core/test_icu_mortality_duett_mimic4.py | 8 synthetic-data task tests: schemas, instantiation, binning shape/values, edge cases |
| 5 | docs/api/models/pyhealth.models.DuETT.rst | Model API docs |
| 6 | docs/api/tasks/pyhealth.tasks.ICUMortalityDuETTMIMIC4.rst | Task API docs |
| 7 | examples/mortality_prediction/mortality_mimic4_duett.py | End-to-end example + 6-config hyperparameter ablation |

Modified files (4, index updates only)

| # | File | Change |
|---|------|--------|
| 8 | pyhealth/models/__init__.py | Export DuETT, DuETTLayer |
| 9 | pyhealth/tasks/__init__.py | Export ICUMortalityDuETTMIMIC4 |
| 10 | docs/api/models.rst | Add DuETT to toctree |
| 11 | docs/api/tasks.rst | Add ICUMortalityDuETTMIMIC4 to toctree |

File Guide — Review Order

Recommended review path:

  1. pyhealth/tasks/icu_mortality_duett_mimic4.py — task logic + binning
  2. pyhealth/models/duett.py — model architecture (read DuETTLayer first, then DuETT)
  3. tests/core/test_duett.py + tests/core/test_icu_mortality_duett_mimic4.py — coverage
  4. examples/mortality_prediction/mortality_mimic4_duett.py — end-to-end usage + ablation
  5. .rst docs + index updates

Testing

  • 19 unit tests (11 model + 8 task), all synthetic-data, entire suite completes in ~5s
  • Tests cover: instantiation, forward pass, output shapes, gradient flow, multiclass head, all 3 fusion methods, input/output schemas, binning shape correctness, binning value correctness, empty-patient edge case, lab categories enum, custom vs default hyperparameters
  • No real datasets used in tests (rubric requirement)

Run:

pytest tests/core/test_duett.py tests/core/test_icu_mortality_duett_mimic4.py -v

Ablation Study

The example script runs a 6-configuration ablation varying model capacity, regularization, and depth. Results below are from MIMIC-IV 3.1 with dev=True (~1000 patients → 816 samples after cohort filtering), 20 epochs, CPU, lr=1e-4, batch_size=64:

| Configuration | Params | ROC-AUC | PR-AUC |
|---------------|--------|---------|--------|
| Small (d=64) | 111,425 | 0.597 | 0.031 |
| Medium (d=128) | 435,841 | 0.922 | 0.143 |
| Large (d=256) | 1,723,649 | 0.805 | 0.063 |
| Low dropout (0.1) | 435,841 | 0.831 | 0.071 |
| High dropout (0.5) | 435,841 | 0.883 | 0.100 |
| Deeper (2×2 layers) | 832,385 | 0.792 | 0.059 |

Findings:

  • Capacity: d=128 is the sweet spot (ROC-AUC 0.922). d=64 underfits (0.597) and d=256 overfits (0.805). With only ~650 training patients, the middle-capacity model wins decisively.
  • Dropout: the default of 0.3 (the Medium row, 0.922) outperforms both 0.1 and 0.5. The under-regularized model (0.1) drops to 0.831; the over-regularized model (0.5) drops to 0.883.
  • Depth: 1×1 is enough. Doubling to 2×2 layers overfits (0.792).
  • Absolute numbers are below the paper's full-MIMIC-IV results because of the dev-mode subset; the cross-configuration pattern (capacity sweet spot, dropout tuning direction) is consistent with the paper's reported behavior.

Environment: Local CPU training on Windows 11 (MIMIC_ROOT via env var, MIMIC_DEV=1).
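
For readers reproducing the run, the environment-variable setup might be read along these lines. This is only a sketch: the helper name `read_mimic_config`, the accepted truthy strings for `MIMIC_DEV`, and the example path are assumptions, and the example script may parse these variables differently.

```python
import os
from pathlib import Path
from typing import Mapping, Tuple


def read_mimic_config(env: Mapping[str, str] = os.environ) -> Tuple[Path, bool]:
    """Return (dataset root, dev flag) from MIMIC_ROOT and MIMIC_DEV."""
    root = env.get("MIMIC_ROOT")
    if not root:
        raise RuntimeError("Set MIMIC_ROOT to the local MIMIC-IV directory")
    # Assumed convention: "1", "true", or "yes" enables the dev subset.
    dev = env.get("MIMIC_DEV", "0").strip().lower() in {"1", "true", "yes"}
    return Path(root), dev


# Hypothetical values, passed explicitly so the example does not depend
# on the real environment.
root, dev = read_mimic_config({"MIMIC_ROOT": "/data/mimiciv/3.1", "MIMIC_DEV": "1"})
```

The returned flag would then be passed on as the `dev=True` setting mentioned above.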

Design Notes

  • Task does the binning — The task converts raw MIMIC-IV lab events into a fixed event-by-time tensor before the model sees them. The schema uses the existing TensorProcessor for values/counts/static/times and BinaryLabelProcessor for mortality — no new processor needed.
  • Model does not use EmbeddingModel — DuETT uses per-variable linear projections, which is integral to the architecture and incompatible with EmbeddingModel's single-projection approach. Documented in the class docstring.
  • Time-based positional encoding uses actual bin endpoint times via nn.Linear(1, d_embedding) projection (following the paper), not learned positional embeddings.
  • Native PyTorch only — uses nn.TransformerEncoderLayer (batch-first). No x-transformers or other external deps beyond what PyHealth already uses.
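
To make the notes above concrete, the following is a rough sketch of how per-variable projections, the `nn.Linear(1, d_embedding)` time encoding, and alternating attention over the two axes could fit together using only `nn.TransformerEncoderLayer`. It is not the code in pyhealth/models/duett.py: the class name `DualAxisBlock`, the `(batch, time, variable)` tensor layout, the time-then-event ordering, and all hyperparameter values are assumptions chosen for illustration.

```python
import torch
from torch import nn


class DualAxisBlock(nn.Module):
    """Hypothetical sketch: one time-axis layer followed by one event-axis layer."""

    def __init__(self, n_vars: int, d_model: int = 32, n_heads: int = 4,
                 dropout: float = 0.3) -> None:
        super().__init__()
        # One projection per variable; each sees (mean value, count) for its cell.
        self.var_proj = nn.ModuleList(nn.Linear(2, d_model) for _ in range(n_vars))
        # Continuous time encoding computed from the bin endpoint times.
        self.time_proj = nn.Linear(1, d_model)
        self.time_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=2 * d_model,
            dropout=dropout, batch_first=True)
        self.event_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=2 * d_model,
            dropout=dropout, batch_first=True)

    def forward(self, values: torch.Tensor, counts: torch.Tensor,
                times: torch.Tensor) -> torch.Tensor:
        # values, counts: (B, T, V); times: (B, T)
        B, T, V = values.shape
        cells = torch.stack([values, counts], dim=-1)  # (B, T, V, 2)
        x = torch.stack(
            [proj(cells[:, :, v]) for v, proj in enumerate(self.var_proj)],
            dim=2)  # (B, T, V, D)
        # Add the time encoding, broadcast across the variable axis.
        x = x + self.time_proj(times.unsqueeze(-1)).unsqueeze(2)
        D = x.shape[-1]
        # Time axis: treat each variable as a length-T sequence.
        x = self.time_layer(x.permute(0, 2, 1, 3).reshape(B * V, T, D))
        x = x.reshape(B, V, T, D).permute(0, 2, 1, 3)
        # Event axis: treat each time bin as a length-V sequence.
        x = self.event_layer(x.reshape(B * T, V, D))
        return x.reshape(B, T, V, D)


torch.manual_seed(0)
block = DualAxisBlock(n_vars=3)
values = torch.randn(2, 4, 3)
counts = torch.randint(0, 3, (2, 4, 3)).float()
times = torch.tensor([6.0, 12.0, 18.0, 24.0]).repeat(2, 1)  # bin endpoints, hours
out = block(values, counts, times)  # shape (2, 4, 3, 32)
```

Folding the axis that is not being attended over into the batch dimension is what lets a standard batch-first encoder layer serve both axes without any custom attention code.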

Out of scope

  • Self-supervised masked pretraining — mentioned in the paper as a separate stage, deferred.
  • PhysioNet/CinC 2012 benchmark — this PR focuses on MIMIC-IV; PhysioNet can be added in a follow-up.

Checklist

  • Rebased on latest master
  • 19 unit tests passing, < 10s total
  • Paper link in PR description
  • Example script under examples/
  • RST docs + index toctree entries
  • Google-style docstrings + type hints on all public methods
  • PEP8, 88-char line width
  • Native PyHealth implementation (no wrapping of authors' code)
  • "Allow edits by maintainers" enabled

Adds a full-pipeline contribution implementing DuETT (Labach et al.
2023, ML4H, PMLR 219) for MIMIC-IV ICU mortality prediction:

- pyhealth/tasks/icu_mortality_duett_mimic4.py: bins MIMIC-IV lab
  events into an event-by-time tensor, retains observation counts,
  emits the mortality label derived from hospital_expire_flag.
- pyhealth/models/duett.py: DuETTLayer (dual-axis attention over
  event and time dimensions) + DuETT(BaseModel) classifier.
- tests/core/test_duett.py and tests/core/test_icu_mortality_duett_mimic4.py:
  19 synthetic tests covering instantiation, forward pass, output
  shapes, gradient flow, schemas, binning correctness, and edge
  cases (~5s total).
- examples/mortality_prediction/mortality_mimic4_duett.py: end-to-end
  example with a 6-config hyperparameter ablation.
- RST docs + toctree entries in docs/api/models.rst, docs/api/tasks.rst,
  and the per-file docs/api/models/pyhealth.models.DuETT.rst and
  docs/api/tasks/pyhealth.tasks.ICUMortalityDuETTMIMIC4.rst pages.

Paper: https://proceedings.mlr.press/v219/labach23a.html