Skip to content
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.patient_readmission.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.patient_readmission
=======================================

.. autoclass:: pyhealth.tasks.patient_readmission.PatientReadmissionPredictionEICU
:members:
:undoc-members:
:show-inheritance:
129 changes: 129 additions & 0 deletions pyhealth/tasks/patient_readmission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import Dict, List
from pyhealth.data import Event, Patient
from pyhealth.tasks import BaseTask

class PatientReadmissionPredictionEICU(BaseTask):
"""
Readmission prediction on the eICU dataset.

This task aims at predicting whether the patient will be readmitted into the ICU
during the same hospital stay based on clinical information from the current ICU
visit.

Features:
- using diagnosis table (ICD9CM and ICD10CM) as condition codes
- using physicalexam table as procedure codes
- using medication table as drugs codes

Attributes:
task_name (str): The name of the task.
input_schema (Dict[str, str]): The schema for the task input.
output_schema (Dict[str, str]): The schema for the task output.

Examples:
>>> from pyhealth.datasets import eICUDataset
>>> from pyhealth.tasks import ReadmissionPredictionEICU
>>> dataset = eICUDataset(
... root="/path/to/eicu-crd/2.0",
... tables=["diagnosis", "medication", "physicalexam"],
... )
>>> task = ReadmissionPredictionEICU(exclude_minors=True)
>>> sample_dataset = dataset.set_task(task)
"""

task_name: str = "ReadmissionPredictionEICU"
input_schema: Dict[str, str] = {
"conditions": "sequence",
"procedures": "sequence",
"drugs": "sequence",
}
output_schema: Dict[str, str] = {"readmission": "binary"}

def __init__(self, exclude_minors: bool = True, **kwargs) -> None:
"""Initializes the task object.

Args:
exclude_minors: Whether to exclude patients whose age is
less than 18. Defaults to True.
**kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`, e.g.
``code_mapping``.
"""
super().__init__(**kwargs)
self.exclude_minors = exclude_minors

def __call__(self, patient: Patient) -> List[Dict]:
"""
Generates binary classification data samples for a single patient.

Args:
patient (Patient): A patient object.

Returns:
List[Dict]: A list containing a dictionary for each patient visit with:
- 'visit_id': eICU patientunitstayid.
- 'patient_id': eICU uniquepid.
- 'conditions': Diagnosis codes from diagnosis table.
- 'procedures': Physical exam codes from physicalexam table.
- 'drugs': Drug names from medication table.
- 'readmission': binary label (1 if readmitted, 0 otherwise).
"""
patient_stays = patient.get_events(event_type="patient")
if len(patient_stays) < 2:
return []
sorted_stays = sorted(
patient_stays,
key=lambda s: (
int(getattr(s, "patienthealthsystemstayid", 0) or 0),
int(getattr(s, "unitvisitnumber", 0) or 0),
),
)
samples = []
for i in range(len(sorted_stays) - 1):
stay = sorted_stays[i]
next_stay = sorted_stays[i + 1]
if self.exclude_minors:
try:
age_str = str(getattr(stay, "age", "0")).replace(">", "").strip()
if int(age_str) < 18:
continue
except (ValueError, TypeError):
pass
stay_id = str(getattr(stay, "patientunitstayid", ""))
diagnoses = patient.get_events(
event_type = "diagnosis", filters = [("patientunitstayid", "==", stay_id)]
)
conditions = [
getattr(event, "icd9code", "") for event in diagnoses
if getattr(event, "icd9code", None)
]
physical_exams = patient.get_events(
event_type = "physicalexam", filters = [("patientunitstayid", "==", stay_id)]
)
procedures = [
getattr(event, "physicalexampath", "") for event in physical_exams
if getattr(event, "physicalexampath", None)
]
medications = patient.get_events(
event_type = "medication", filters = [("patientunitstayid", "==", stay_id)]
)
drugs = [
getattr(event, "drugname", "") for event in medications
if getattr(event, "drugname", None)
]
if len(conditions) == 0 and len(procedures) == 0 and len(drugs) == 0:
continue
current_hosp_id = getattr(stay, "patienthealthsystemstayid", None)
next_hosp_id = getattr(next_stay, "patienthealthsystemstayid", None)
readmission = int(current_hosp_id == next_hosp_id and current_hosp_id is not None)
samples.append(
{
"visit_id": stay_id,
"patient_id": patient.patient_id,
"conditions": [conditions],
"procedures": [procedures],
"drugs": [drugs],
"readmission": readmission,
}
)

return samples
105 changes: 105 additions & 0 deletions tests/test_patient_readmission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest
from datetime import datetime
from pyhealth.data import Patient, Event
from pyhealth.tasks import ReadmissionPredictionEICU

def test_readmission_prediction_eicu_task():
"""
Tests the ReadmissionPredictionEICU task using synthetic, in-memory patient data
to ensure tests complete in milliseconds.
"""
# 1. Initialize Task
task = ReadmissionPredictionEICU(exclude_minors = True)

# 2. Create a mock patient
patient = Patient(patient_id = "test_pat_001")

# Visit 1: ICU Stay 1 (in Hospital 1)
patient.add_event(Event(
event_type = "patient",
timestamp=datetime(2025, 1, 1),
patienthealthsystemstayid = "hosp_001",
patientunitstayid = "icu_001",
unitvisitnumber = 1,
age = "65"
))
patient.add_event(Event(
event_type = "diagnosis",
patientunitstayid = "icu_001",
icd9code = "428.0"
))
patient.add_event(Event(
event_type = "medication",
patientunitstayid = "icu_001",
drugname = "Aspirin"
))

# Visit 2: ICU Stay 2 (Readmitted to the SAME hospital, hosp_001)
patient.add_event(Event(
event_type = "patient",
timestamp = datetime(2025, 1, 10),
patienthealthsystemstayid = "hosp_001",
patientunitstayid = "icu_002",
unitvisitnumber = 2,
age = "65"
))
patient.add_event(Event(
event_type = "physicalexam",
patientunitstayid = "icu_002",
physicalexampath = "cardiovascular|murmur"
))

# Visit 3: ICU Stay 3 (Admitted to a DIFFERENT hospital, hosp_002)
patient.add_event(Event(
event_type = "patient",
timestamp = datetime(2025, 5, 1),
patienthealthsystemstayid = "hosp_002",
patientunitstayid = "icu_003",
unitvisitnumber = 1,
age = "65"
))
patient.add_event(Event(
event_type = "diagnosis",
patientunitstayid = "icu_003",
icd9code = "250.00"
))

# 3. Call the task
samples = task(patient)

# 4. Assertions
# With 3 ICU stays, we expect 2 samples (1->2, 2->3)
assert len(samples) == 2, "Task should generate exactly 2 samples"

# Check Sample 1 (icu_001 -> icu_002)
assert samples[0]["visit_id"] == "icu_001"
assert samples[0]["readmission"] == 1
assert "428.0" in samples[0]["conditions"][0]
assert "Aspirin" in samples[0]["drugs"][0]
assert len(samples[0]["procedures"][0]) == 0

# Check Sample 2 (icu_002 -> icu_003)
assert samples[1]["visit_id"] == "icu_002"
assert samples[1]["readmission"] == 0
assert "cardiovascular|murmur" in samples[1]["procedures"][0]

def test_exclude_minors():
"""Test that the task correctly excludes patients under 18."""
task = ReadmissionPredictionEICU(exclude_minors = True)
patient = Patient(patient_id = "test_minor")
for i in [1, 2]:
patient.add_event(Event(
event_type="patient",
timestamp=datetime(2025, 1, i),
patienthealthsystemstayid = "hosp_001",
patientunitstayid = f"icu_{i}",
unitvisitnumber = i,
age="10"
))
patient.add_event(Event(
event_type = "diagnosis",
patientunitstayid = f"icu_{i}",
icd9code = "test"
))
samples = task(patient)
assert len(samples) == 0, "Task should return 0 samples for minors when exclude_minors=True"