diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..ee4a444f6 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -207,6 +207,7 @@ Available Tasks Base Task In-Hospital Mortality (MIMIC-IV) + In-Hospital Mortality Temporal (MIMIC-IV) MIMIC-III ICD-9 Coding Cardiology Detection COVID-19 CXR Classification diff --git a/docs/api/tasks/pyhealth.tasks.InHospitalMortalityTemporalMIMIC4.rst b/docs/api/tasks/pyhealth.tasks.InHospitalMortalityTemporalMIMIC4.rst new file mode 100644 index 000000000..53148305a --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.InHospitalMortalityTemporalMIMIC4.rst @@ -0,0 +1,7 @@ +RST: pyhealth.tasks.InHospitalMortalityTemporalMIMIC4 +==================================================== + +.. autoclass:: pyhealth.tasks.in_hospital_mortality_temporal_mimic4.InHospitalMortalityTemporalMIMIC4 + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/mortality_prediction/mortality_mimic4_temporal_emdot.py b/examples/mortality_prediction/mortality_mimic4_temporal_emdot.py new file mode 100644 index 000000000..1de1894d2 --- /dev/null +++ b/examples/mortality_prediction/mortality_mimic4_temporal_emdot.py @@ -0,0 +1,163 @@ +""" +EMDOT Temporal Evaluation for In-Hospital Mortality on MIMIC-IV + +Reproduces the EMDOT framework from Zhou et al. (CHIL 2023) to show that +random train/test splits overestimate performance vs temporal splits. + +We compare two training regimes across deployment years: + - All-historical: train on everything up to year t, test on the rest + - Sliding window: train on only the last `window` years before t + +We also ablate over window sizes (1, 2, 3, 5 years) to see how the +recency vs sample-size tradeoff plays out. All experiments use logistic +regression with bag-of-codes features. + +Run: + python examples/mortality_prediction/mortality_mimic4_temporal_emdot.py +""" + +import time +import numpy as np +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import roc_auc_score +from sklearn.preprocessing import MultiLabelBinarizer + +from pyhealth.datasets import MIMIC4Dataset, split_by_sample +from pyhealth.tasks import InHospitalMortalityTemporalMIMIC4 + + +def temporal_split(samples, deployment_year, regime="all_historical", window=3): + # split into in-period (<=t) and out-period (>t) + in_period = [s for s in samples if s["admission_year"] <= deployment_year] + out_period = [s for s in samples if s["admission_year"] > deployment_year] + + if regime == "all_historical": + train = in_period + elif regime == "sliding_window": + train = [s for s in in_period + if s["admission_year"] >= deployment_year - window] + else: + raise ValueError(f"Unknown regime: {regime}") + return train, out_period + + +def encode_features(train_samples, test_samples): + # concat all codes into one list per sample, then binarize + def get_codes(s): + return s.get("conditions", []) + s.get("procedures", []) + s.get("drugs", []) + + mlb = MultiLabelBinarizer(sparse_output=False) + X_train = mlb.fit_transform([get_codes(s) for s in train_samples]) + X_test = mlb.transform([get_codes(s) for s in test_samples]) + y_train = np.array([s["mortality"] for s in train_samples]) + y_test = np.array([s["mortality"] for s in test_samples]) + return X_train, X_test, y_train, y_test + + +def run_emdot(samples, regime, window=3, min_train=50, + deployment_years=range(2012, 2020), seed=42): + # loop over deployment years, train LR, get AUROC + results = {} + for t in deployment_years: + train_samp, test_samp = temporal_split(samples, t, regime=regime, window=window) + + if len(train_samp) < min_train or len(test_samp) == 0: + print(f" t={t}: skipping (train={len(train_samp)}, test={len(test_samp)})") + continue + if len(set(s["mortality"] for s in train_samp)) < 2: + print(f" t={t}: skipping (only one class)") + continue + + X_train, X_test, y_train, y_test = encode_features(train_samp, test_samp) + lr = LogisticRegression(max_iter=1000, solver="lbfgs", + class_weight="balanced", random_state=seed) + lr.fit(X_train, y_train) + auroc = roc_auc_score(y_test, lr.predict_proba(X_test)[:, 1]) + results[t] = auroc + print(f" t={t} | train={len(train_samp):5d} | test={len(test_samp):5d} | " + f"pos_rate={y_test.mean():.3f} | AUROC={auroc:.4f}") + return results + + +if __name__ == "__main__": + t0 = time.perf_counter() + + # load dataset + print("=" * 60) + print("Step 1: Loading MIMIC-IV dataset") + print("=" * 60) + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + cache_dir="../benchmark_cache/mimic4_temporal/", + ) + base_dataset.stats() + + # apply task + print("\n" + "=" * 60) + print("Step 2: Applying temporal mortality task") + print("=" * 60) + task = InHospitalMortalityTemporalMIMIC4() + sample_dataset = base_dataset.set_task(task, num_workers=4) + all_samples = list(sample_dataset) + years = [s["admission_year"] for s in all_samples] + print(f"Total samples: {len(all_samples)}") + print(f"Year range: {min(years)}-{max(years)}") + print(f"Mortality rate: {np.mean([s['mortality'] for s in all_samples]):.3f}") + + # random split baseline + print("\n" + "=" * 60) + print("Step 3: Random split baseline (time-agnostic)") + print("=" * 60) + train_ds, _, test_ds = split_by_sample( + sample_dataset, ratios=[0.7, 0.1, 0.2], seed=42 + ) + train_rand = [sample_dataset[i] for i in train_ds.indices] + test_rand = [sample_dataset[i] for i in test_ds.indices] + X_tr, X_te, y_tr, y_te = encode_features(train_rand, test_rand) + lr = LogisticRegression(max_iter=1000, solver="lbfgs", + class_weight="balanced", random_state=42) + lr.fit(X_tr, y_tr) + baseline_auroc = roc_auc_score(y_te, lr.predict_proba(X_te)[:, 1]) + print(f"Random split AUROC: {baseline_auroc:.4f}") + + # all-historical + print("\n" + "=" * 60) + print("Step 4: All-historical regime") + print("=" * 60) + results_ah = run_emdot(all_samples, regime="all_historical") + + # sliding window (default w=3) + print("\n" + "=" * 60) + print("Step 5: Sliding window regime (w=3)") + print("=" * 60) + results_sw = run_emdot(all_samples, regime="sliding_window") + + # window size ablation + # trying different window sizes to see how it affects things + print("\n" + "=" * 60) + print("Step 6: Window size ablation") + print("=" * 60) + ablation = {} + for w in [1, 2, 3, 5]: + print(f"\n --- w={w} ---") + ablation[w] = run_emdot(all_samples, regime="sliding_window", window=w) + + # print results + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + print(f"\nRandom split AUROC: {baseline_auroc:.4f}") + print(f"\nBy deployment year:") + print("Year\tAll-Hist\tSliding(w=3)") + for t in range(2012, 2020): + ah = results_ah.get(t, float('nan')) + sw = results_sw.get(t, float('nan')) + print(f"{t}\t{ah:.4f}\t\t{sw:.4f}") + + print(f"\nWindow ablation (mean AUROC):") + for w, res in ablation.items(): + if res: + print(f" w={w}: {np.mean(list(res.values())):.4f}") + + print(f"\nDone in {time.perf_counter() - t0:.1f}s") diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..8997d9c35 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -23,6 +23,7 @@ drug_recommendation_omop_fn, ) from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 +from .in_hospital_mortality_temporal_mimic4 import InHospitalMortalityTemporalMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, LengthOfStayPredictionMIMIC3, diff --git a/pyhealth/tasks/in_hospital_mortality_temporal_mimic4.py b/pyhealth/tasks/in_hospital_mortality_temporal_mimic4.py new file mode 100644 index 000000000..80cef9ece --- /dev/null +++ b/pyhealth/tasks/in_hospital_mortality_temporal_mimic4.py @@ -0,0 +1,125 @@ +from typing import Any, Dict, List + +from .base_task import BaseTask + + +class InHospitalMortalityTemporalMIMIC4(BaseTask): + """In-ICU mortality prediction on MIMIC-IV with temporal (EMDOT-style) evaluation. + + Each sample is tagged with its admission year so callers can partition data + chronologically to simulate real-world deployment conditions, following the + EMDOT framework (Zhou et al., 2023). Supports both all-historical and + sliding window training regimes. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for input data, which includes: + - conditions: A sequence of diagnosis ICD codes. + - procedures: A sequence of procedure ICD codes. + - drugs: A sequence of prescribed drug names. + output_schema (Dict[str, str]): The schema for output data, which includes: + - mortality: A binary indicator of in-hospital mortality. + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import InHospitalMortalityTemporalMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/2.2", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... ) + >>> task = InHospitalMortalityTemporalMIMIC4() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "InHospitalMortalityTemporalMIMIC4" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + } + output_schema: Dict[str, str] = {"mortality": "binary"} + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Generates binary mortality samples tagged with admission year. + + Admissions with no conditions OR no procedures OR no drugs are excluded. + Patients under 18 years old (anchor_age) are excluded. + + Args: + patient (Any): A PyHealth Patient object. + + Returns: + List[Dict[str, Any]]: A list of dicts, each containing: + - 'patient_id': MIMIC-IV subject_id. + - 'admission_id': MIMIC-IV hadm_id. + - 'conditions': ICD codes from diagnoses_icd. + - 'procedures': ICD codes from procedures_icd. + - 'drugs': Drug names from prescriptions. + - 'mortality': binary label (1 = died in hospital, 0 = survived). + - 'admission_year': int year of admission for temporal splits. + """ + demographics = patient.get_events(event_type="patients") + assert len(demographics) == 1 + demo = demographics[0] + if int(demo.anchor_age) < 18: + return [] + + # compute date shift to recover real calendar years + # anchor_year_group is like "2017 - 2019", take midpoint + anchor_year = int(demo.anchor_year) + group = getattr(demo, "anchor_year_group", None) + if group and " - " in str(group): + parts = str(group).split(" - ") + real_anchor = (int(parts[0]) + int(parts[1])) // 2 + year_shift = anchor_year - real_anchor + else: + year_shift = 0 + + admissions = patient.get_events(event_type="admissions") + if len(admissions) == 0: + return [] + + samples = [] + for admission in admissions: + filter = ("hadm_id", "==", admission.hadm_id) + + conditions = [] + for event in patient.get_events( + event_type="diagnoses_icd", filters=[filter] + ): + assert event.icd_version in ("9", "10") + conditions.append(f"{event.icd_version}_{event.icd_code}") + if len(conditions) == 0: + continue + + procedures = [] + for event in patient.get_events( + event_type="procedures_icd", filters=[filter] + ): + assert event.icd_version in ("9", "10") + procedures.append(f"{event.icd_version}_{event.icd_code}") + if len(procedures) == 0: + continue + + prescriptions = patient.get_events( + event_type="prescriptions", filters=[filter] + ) + drugs = [event.drug for event in prescriptions] + if len(drugs) == 0: + continue + + if admission.hospital_expire_flag is None: + continue + mortality = int(admission.hospital_expire_flag) + + samples.append({ + "patient_id": patient.patient_id, + "admission_id": admission.hadm_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "mortality": mortality, + "admission_year": admission.timestamp.year - year_shift, + }) + + return samples \ No newline at end of file diff --git a/tests/core/test_mimic4_mortality_temporal.py b/tests/core/test_mimic4_mortality_temporal.py new file mode 100644 index 000000000..0f9376413 --- /dev/null +++ b/tests/core/test_mimic4_mortality_temporal.py @@ -0,0 +1,329 @@ +"""Tests for InHospitalMortalityTemporalMIMIC4 task. + +Uses synthetic mock patients (not demo/real data). Each test builds +dummy Patient objects and calls the task directly. +""" + +import unittest +from collections import Counter +from datetime import datetime + +from pyhealth.tasks import InHospitalMortalityTemporalMIMIC4 + + +# -- mock classes to simulate PyHealth Patient/Event objects -- + +class DummyDemographics: + def __init__(self, anchor_age, anchor_year=2150, + anchor_year_group="2017 - 2019"): + self.anchor_age = anchor_age + self.anchor_year = anchor_year + self.anchor_year_group = anchor_year_group + + +class DummyAdmission: + def __init__(self, hadm_id, timestamp, hospital_expire_flag): + self.hadm_id = hadm_id + self.timestamp = timestamp + self.hospital_expire_flag = hospital_expire_flag + + +class DummyDiagnosis: + def __init__(self, hadm_id, icd_code, icd_version): + self.hadm_id = hadm_id + self.icd_code = icd_code + self.icd_version = icd_version + + +class DummyProcedure: + def __init__(self, hadm_id, icd_code, icd_version): + self.hadm_id = hadm_id + self.icd_code = icd_code + self.icd_version = icd_version + + +class DummyPrescription: + def __init__(self, hadm_id, drug): + self.hadm_id = hadm_id + self.drug = drug + + +class DummyPatient: + def __init__(self, patient_id, demographics, admissions, + diagnoses, procedures, prescriptions): + self.patient_id = patient_id + self._demographics = demographics + self._admissions = admissions + self._diagnoses = diagnoses + self._procedures = procedures + self._prescriptions = prescriptions + + def get_events(self, event_type, filters=None, **kwargs): + if event_type == "patients": + return self._demographics + elif event_type == "admissions": + return self._admissions + elif event_type == "diagnoses_icd": + events = self._diagnoses + elif event_type == "procedures_icd": + events = self._procedures + elif event_type == "prescriptions": + events = self._prescriptions + else: + return [] + + # apply hadm_id filter if provided + if filters: + for field, op, value in filters: + if op == "==": + events = [ + e for e in events + if getattr(e, field, None) == value + ] + return events + + +# -- helpers to build patients quickly -- + +def make_patient(patient_id, age, admissions_data, + anchor_year=2150, + anchor_year_group="2017 - 2019"): + """Build a DummyPatient from a compact spec. + + admissions_data is a list of dicts like: + {"hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], "px": [("5A19", "10")], + "drugs": ["Insulin"]} + """ + demographics = [DummyDemographics( + age, anchor_year, anchor_year_group + )] + admissions = [] + diagnoses = [] + procedures = [] + prescriptions = [] + + for a in admissions_data: + ts = datetime(a["year"], 3, 15) + flag = "1" if a.get("died", False) else "0" + admissions.append( + DummyAdmission(a["hadm_id"], ts, flag) + ) + for code, ver in a.get("dx", []): + diagnoses.append( + DummyDiagnosis(a["hadm_id"], code, ver) + ) + for code, ver in a.get("px", []): + procedures.append( + DummyProcedure(a["hadm_id"], code, ver) + ) + for drug in a.get("drugs", []): + prescriptions.append( + DummyPrescription(a["hadm_id"], drug) + ) + + return DummyPatient( + patient_id, demographics, admissions, + diagnoses, procedures, prescriptions, + ) + + +class TestTemporalMortalityMIMIC4(unittest.TestCase): + + def setUp(self): + self.task = InHospitalMortalityTemporalMIMIC4() + + # -- schema -- + + def test_task_name_and_schemas(self): + t = InHospitalMortalityTemporalMIMIC4 + self.assertEqual( + t.task_name, "InHospitalMortalityTemporalMIMIC4" + ) + self.assertEqual(t.input_schema, { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + }) + self.assertEqual(t.output_schema, {"mortality": "binary"}) + + # -- sample processing -- + + def test_basic_sample_keys(self): + patient = make_patient("p1", 45, [{ + "hadm_id": "100", "year": 2150, "died": False, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": ["Metformin"], + }]) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + s = samples[0] + for k in ["patient_id", "admission_id", "conditions", + "procedures", "drugs", "mortality", + "admission_year"]: + self.assertIn(k, s) + + def test_multiple_admissions(self): + patient = make_patient("p1", 50, [ + {"hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], "px": [("5A19", "10")], + "drugs": ["Insulin"]}, + {"hadm_id": "2", "year": 2152, "died": True, + "dx": [("I50", "10")], "px": [("02HV", "10")], + "drugs": ["Furosemide"]}, + ]) + samples = self.task(patient) + self.assertEqual(len(samples), 2) + self.assertEqual(samples[0]["admission_id"], "1") + self.assertEqual(samples[1]["admission_id"], "2") + + # -- label generation -- + + def test_mortality_label_died(self): + patient = make_patient("p1", 60, [{ + "hadm_id": "1", "year": 2150, "died": True, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": ["Insulin"], + }]) + samples = self.task(patient) + self.assertEqual(samples[0]["mortality"], 1) + + def test_mortality_label_survived(self): + patient = make_patient("p1", 60, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": ["Insulin"], + }]) + samples = self.task(patient) + self.assertEqual(samples[0]["mortality"], 0) + + # -- feature extraction -- + + def test_admission_year_deshifted(self): + # anchor_year=2150, group="2017 - 2019" -> midpoint 2018 + # shift = 2150 - 2018 = 132 + # admission in shifted year 2151 -> real year 2151 - 132 = 2019 + patient = make_patient("p1", 40, [{ + "hadm_id": "1", "year": 2151, "died": False, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": ["Insulin"], + }], anchor_year=2150, anchor_year_group="2017 - 2019") + samples = self.task(patient) + self.assertEqual(samples[0]["admission_year"], 2019) + + def test_icd_codes_have_version_prefix(self): + patient = make_patient("p1", 40, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10"), ("4019", "9")], + "px": [("5A19", "10")], + "drugs": ["Insulin"], + }]) + samples = self.task(patient) + self.assertIn("10_E10", samples[0]["conditions"]) + self.assertIn("9_4019", samples[0]["conditions"]) + + def test_drugs_are_drug_names(self): + patient = make_patient("p1", 40, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": ["Insulin", "Metformin"], + }]) + samples = self.task(patient) + self.assertEqual( + samples[0]["drugs"], ["Insulin", "Metformin"] + ) + + # -- edge cases -- + + def test_minor_excluded(self): + patient = make_patient("p1", 17, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": ["Insulin"], + }]) + samples = self.task(patient) + self.assertEqual(samples, []) + + def test_age_18_included(self): + patient = make_patient("p1", 18, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": ["Insulin"], + }]) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + + def test_no_diagnoses_skipped(self): + patient = make_patient("p1", 40, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [], + "px": [("5A19", "10")], + "drugs": ["Insulin"], + }]) + samples = self.task(patient) + self.assertEqual(samples, []) + + def test_no_procedures_skipped(self): + patient = make_patient("p1", 40, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], + "px": [], + "drugs": ["Insulin"], + }]) + samples = self.task(patient) + self.assertEqual(samples, []) + + def test_no_drugs_skipped(self): + patient = make_patient("p1", 40, [{ + "hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], + "px": [("5A19", "10")], + "drugs": [], + }]) + samples = self.task(patient) + self.assertEqual(samples, []) + + def test_no_admissions(self): + demographics = [DummyDemographics(40)] + patient = DummyPatient("p1", demographics, [], [], [], []) + samples = self.task(patient) + self.assertEqual(samples, []) + + def test_missing_expire_flag_skipped(self): + # admission with hospital_expire_flag=None should be skipped + demographics = [DummyDemographics(40)] + adm = DummyAdmission("1", datetime(2150, 1, 1), None) + dx = DummyDiagnosis("1", "E10", "10") + px = DummyProcedure("1", "5A19", "10") + rx = DummyPrescription("1", "Insulin") + patient = DummyPatient( + "p1", demographics, [adm], [dx], [px], [rx] + ) + samples = self.task(patient) + self.assertEqual(samples, []) + + def test_mixed_valid_and_invalid_admissions(self): + # one admission with all features, one missing procedures + patient = make_patient("p1", 50, [ + {"hadm_id": "1", "year": 2150, "died": False, + "dx": [("E10", "10")], "px": [("5A19", "10")], + "drugs": ["Insulin"]}, + {"hadm_id": "2", "year": 2151, "died": False, + "dx": [("I10", "10")], "px": [], + "drugs": ["Lisinopril"]}, + ]) + samples = self.task(patient) + # only the first admission should produce a sample + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["admission_id"], "1") + + +if __name__ == "__main__": + unittest.main()