From 198528bbd297476b82f115385f019114de00642b Mon Sep 17 00:00:00 2001 From: dylan-g12 Date: Sun, 19 Apr 2026 12:26:55 -0500 Subject: [PATCH 1/4] Added patient_readmission.py and test_patient_readmission.py --- .github/patient_readmission.py | 129 ++++++++++++++++++++++++++++ .github/test_patient_readmission.py | 105 ++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 .github/patient_readmission.py create mode 100644 .github/test_patient_readmission.py diff --git a/.github/patient_readmission.py b/.github/patient_readmission.py new file mode 100644 index 000000000..c6292028f --- /dev/null +++ b/.github/patient_readmission.py @@ -0,0 +1,129 @@ +from typing import Dict, List +from pyhealth.data import Event, Patient +from pyhealth.tasks import BaseTask + +class ReadmissionPredictionEICU(BaseTask): + """ + Readmission prediction on the eICU dataset. + + This task aims at predicting whether the patient will be readmitted into the ICU + during the same hospital stay based on clinical information from the current ICU + visit. + + Features: + - using diagnosis table (ICD9CM and ICD10CM) as condition codes + - using physicalexam table as procedure codes + - using medication table as drugs codes + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import eICUDataset + >>> from pyhealth.tasks import ReadmissionPredictionEICU + >>> dataset = eICUDataset( + ... root="/path/to/eicu-crd/2.0", + ... tables=["diagnosis", "medication", "physicalexam"], + ... ) + >>> task = ReadmissionPredictionEICU(exclude_minors=True) + >>> sample_dataset = dataset.set_task(task) + """ + + task_name: str = "ReadmissionPredictionEICU" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + } + output_schema: Dict[str, str] = {"readmission": "binary"} + + def __init__(self, exclude_minors: bool = True, **kwargs) -> None: + """Initializes the task object. + + Args: + exclude_minors: Whether to exclude patients whose age is + less than 18. Defaults to True. + **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`, e.g. + ``code_mapping``. + """ + super().__init__(**kwargs) + self.exclude_minors = exclude_minors + + def __call__(self, patient: Patient) -> List[Dict]: + """ + Generates binary classification data samples for a single patient. + + Args: + patient (Patient): A patient object. + + Returns: + List[Dict]: A list containing a dictionary for each patient visit with: + - 'visit_id': eICU patientunitstayid. + - 'patient_id': eICU uniquepid. + - 'conditions': Diagnosis codes from diagnosis table. + - 'procedures': Physical exam codes from physicalexam table. + - 'drugs': Drug names from medication table. + - 'readmission': binary label (1 if readmitted, 0 otherwise). + """ + patient_stays = patient.get_events(event_type="patient") + if len(patient_stays) < 2: + return [] + sorted_stays = sorted( + patient_stays, + key=lambda s: ( + int(getattr(s, "patienthealthsystemstayid", 0) or 0), + int(getattr(s, "unitvisitnumber", 0) or 0), + ), + ) + samples = [] + for i in range(len(sorted_stays) - 1): + stay = sorted_stays[i] + next_stay = sorted_stays[i + 1] + if self.exclude_minors: + try: + age_str = str(getattr(stay, "age", "0")).replace(">", "").strip() + if int(age_str) < 18: + continue + except (ValueError, TypeError): + pass + stay_id = str(getattr(stay, "patientunitstayid", "")) + diagnoses = patient.get_events( + event_type = "diagnosis", filters = [("patientunitstayid", "==", stay_id)] + ) + conditions = [ + getattr(event, "icd9code", "") for event in diagnoses + if getattr(event, "icd9code", None) + ] + physical_exams = patient.get_events( + event_type = "physicalexam", filters = [("patientunitstayid", "==", stay_id)] + ) + procedures = [ + getattr(event, "physicalexampath", "") for event in physical_exams + if getattr(event, "physicalexampath", None) + ] + medications = patient.get_events( + event_type = "medication", filters = [("patientunitstayid", "==", stay_id)] + ) + drugs = [ + getattr(event, "drugname", "") for event in medications + if getattr(event, "drugname", None) + ] + if len(conditions) == 0 and len(procedures) == 0 and len(drugs) == 0: + continue + current_hosp_id = getattr(stay, "patienthealthsystemstayid", None) + next_hosp_id = getattr(next_stay, "patienthealthsystemstayid", None) + readmission = int(current_hosp_id == next_hosp_id and current_hosp_id is not None) + samples.append( + { + "visit_id": stay_id, + "patient_id": patient.patient_id, + "conditions": [conditions], + "procedures": [procedures], + "drugs": [drugs], + "readmission": readmission, + } + ) + + return samples \ No newline at end of file diff --git a/.github/test_patient_readmission.py b/.github/test_patient_readmission.py new file mode 100644 index 000000000..d85921979 --- /dev/null +++ b/.github/test_patient_readmission.py @@ -0,0 +1,105 @@ +import pytest +from datetime import datetime +from pyhealth.data import Patient, Event +from pyhealth.tasks import ReadmissionPredictionEICU + +def test_readmission_prediction_eicu_task(): + """ + Tests the ReadmissionPredictionEICU task using synthetic, in-memory patient data + to ensure tests complete in milliseconds. + """ + # 1. Initialize Task + task = ReadmissionPredictionEICU(exclude_minors = True) + + # 2. Create a mock patient + patient = Patient(patient_id = "test_pat_001") + + # Visit 1: ICU Stay 1 (in Hospital 1) + patient.add_event(Event( + event_type = "patient", + timestamp=datetime(2025, 1, 1), + patienthealthsystemstayid = "hosp_001", + patientunitstayid = "icu_001", + unitvisitnumber = 1, + age = "65" + )) + patient.add_event(Event( + event_type = "diagnosis", + patientunitstayid = "icu_001", + icd9code = "428.0" + )) + patient.add_event(Event( + event_type = "medication", + patientunitstayid = "icu_001", + drugname = "Aspirin" + )) + + # Visit 2: ICU Stay 2 (Readmitted to the SAME hospital, hosp_001) + patient.add_event(Event( + event_type = "patient", + timestamp = datetime(2025, 1, 10), + patienthealthsystemstayid = "hosp_001", + patientunitstayid = "icu_002", + unitvisitnumber = 2, + age = "65" + )) + patient.add_event(Event( + event_type = "physicalexam", + patientunitstayid = "icu_002", + physicalexampath = "cardiovascular|murmur" + )) + + # Visit 3: ICU Stay 3 (Admitted to a DIFFERENT hospital, hosp_002) + patient.add_event(Event( + event_type = "patient", + timestamp = datetime(2025, 5, 1), + patienthealthsystemstayid = "hosp_002", + patientunitstayid = "icu_003", + unitvisitnumber = 1, + age = "65" + )) + patient.add_event(Event( + event_type = "diagnosis", + patientunitstayid = "icu_003", + icd9code = "250.00" + )) + + # 3. Call the task + samples = task(patient) + + # 4. Assertions + # With 3 ICU stays, we expect 2 samples (1->2, 2->3) + assert len(samples) == 2, "Task should generate exactly 2 samples" + + # Check Sample 1 (icu_001 -> icu_002) + assert samples[0]["visit_id"] == "icu_001" + assert samples[0]["readmission"] == 1 + assert "428.0" in samples[0]["conditions"][0] + assert "Aspirin" in samples[0]["drugs"][0] + assert len(samples[0]["procedures"][0]) == 0 + + # Check Sample 2 (icu_002 -> icu_003) + assert samples[1]["visit_id"] == "icu_002" + assert samples[1]["readmission"] == 0 + assert "cardiovascular|murmur" in samples[1]["procedures"][0] + +def test_exclude_minors(): + """Test that the task correctly excludes patients under 18.""" + task = ReadmissionPredictionEICU(exclude_minors = True) + patient = Patient(patient_id = "test_minor") + for i in [1, 2]: + patient.add_event(Event( + event_type="patient", + timestamp=datetime(2025, 1, i), + patienthealthsystemstayid = "hosp_001", + patientunitstayid = f"icu_{i}", + unitvisitnumber = i, + age="10" + )) + patient.add_event(Event( + event_type = "diagnosis", + patientunitstayid = f"icu_{i}", + icd9code = "test" + )) + samples = task(patient) + assert len(samples) == 0, "Task should return 0 samples for minors when exclude_minors=True" \ No newline at end of file From c980c2d7e10e9fe1e4595da2dd3d9697bd6b5aef Mon Sep 17 00:00:00 2001 From: dylan-g12 Date: Sun, 19 Apr 2026 12:53:10 -0500 Subject: [PATCH 2/4] Changed file locations of patient_readmission.py from .github to PyHealth/pyhealth/tasks and test_patient_readmission.py from .github to PyHealth/tests --- {.github => pyhealth/tasks}/patient_readmission.py | 0 {.github => tests}/test_patient_readmission.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {.github => pyhealth/tasks}/patient_readmission.py (100%) rename {.github => tests}/test_patient_readmission.py (100%) diff --git a/.github/patient_readmission.py b/pyhealth/tasks/patient_readmission.py similarity index 100% rename from .github/patient_readmission.py rename to pyhealth/tasks/patient_readmission.py diff --git a/.github/test_patient_readmission.py b/tests/test_patient_readmission.py similarity index 100% rename from .github/test_patient_readmission.py rename to tests/test_patient_readmission.py From 16269b5852e330eb40d908404b6bb45c115059e5 Mon Sep 17 00:00:00 2001 From: dylan-g12 Date: Sun, 19 Apr 2026 13:00:27 -0500 Subject: [PATCH 3/4] Added pyhealth.tasks.patient_readmission.rst and fixed name of class in patient_readmission.py to not copy class names --- docs/api/tasks/pyhealth.tasks.patient_readmission.rst | 4 ++++ pyhealth/tasks/patient_readmission.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 docs/api/tasks/pyhealth.tasks.patient_readmission.rst diff --git a/docs/api/tasks/pyhealth.tasks.patient_readmission.rst b/docs/api/tasks/pyhealth.tasks.patient_readmission.rst new file mode 100644 index 000000000..7a460216a --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.patient_readmission.rst @@ -0,0 +1,4 @@ +.. autoclass:: pyhealth.tasks.patient_readmission.PatientReadmissionPredictionEICU + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/pyhealth/tasks/patient_readmission.py b/pyhealth/tasks/patient_readmission.py index c6292028f..928935258 100644 --- a/pyhealth/tasks/patient_readmission.py +++ b/pyhealth/tasks/patient_readmission.py @@ -2,7 +2,7 @@ from pyhealth.data import Event, Patient from pyhealth.tasks import BaseTask -class ReadmissionPredictionEICU(BaseTask): +class PatientReadmissionPredictionEICU(BaseTask): """ Readmission prediction on the eICU dataset. From fc2eff420c474b09927b6fdce13b165d1928b2d6 Mon Sep 17 00:00:00 2001 From: dylan-g12 Date: Sun, 19 Apr 2026 13:01:48 -0500 Subject: [PATCH 4/4] Fix: pyhealth.tasks.patient_readmission.rst was missing necessary lines pyhealth.tasks.BaseTask ======================================= --- docs/api/tasks/pyhealth.tasks.patient_readmission.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/api/tasks/pyhealth.tasks.patient_readmission.rst b/docs/api/tasks/pyhealth.tasks.patient_readmission.rst index 7a460216a..d28177b31 100644 --- a/docs/api/tasks/pyhealth.tasks.patient_readmission.rst +++ b/docs/api/tasks/pyhealth.tasks.patient_readmission.rst @@ -1,3 +1,6 @@ +pyhealth.tasks.patient_readmission +======================================= + .. autoclass:: pyhealth.tasks.patient_readmission.PatientReadmissionPredictionEICU :members: :undoc-members: