diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..709fdb274 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Dynamic Survival Analysis diff --git a/docs/api/tasks/pyhealth.tasks.dynamic_survival.rst b/docs/api/tasks/pyhealth.tasks.dynamic_survival.rst new file mode 100644 index 000000000..16ec59d94 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.dynamic_survival.rst @@ -0,0 +1,68 @@ +DynamicSurvivalTask +=================== + +This module implements a dynamic survival analysis task for early event prediction. + +The task follows the anchor-based discrete-time survival formulation proposed in: + +Yèche et al. (2024), *Dynamic Survival Analysis for Early Event Prediction*. + +Key Features +------------ +- Multiple anchors per patient +- Discrete-time hazard prediction +- Support for censoring +- Configurable observation windows and anchor strategies + +Output Format +------------- +Each processed sample contains: + +- **patient_id**: unique patient identifier +- **visit_id**: unique anchor-based visit ID +- **x**: input features (temporal sequence) +- **y**: hazard label vector (0/1) +- **mask**: indicates valid risk set: + - 1 = patient is at risk at this timestep + - 0 = timestep excluded (post-event or post-censoring) + +Usage Example +------------- + +.. 
code-block:: python + + from pyhealth.tasks.dynamic_survival import DynamicSurvivalTask + + # Minimal dataset wrapper (MockDataset or a real PyHealth dataset) + class MockDataset: + def __init__(self): + self.patients = {} + + dataset = MockDataset() + + task = DynamicSurvivalTask( + dataset=dataset, + observation_window=24, + horizon=24, + anchor_strategy="fixed", + ) + + # Apply to a patient object + samples = task(patient) + +Example Output +-------------- + +Each sample: + +- x: shape (T, d) +- y: shape (horizon,) +- mask: shape (horizon,) + +API Reference +------------- + +.. autoclass:: pyhealth.tasks.dynamic_survival.DynamicSurvivalTask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/dynamic_survival_ablation.py b/examples/dynamic_survival_ablation.py new file mode 100644 index 000000000..e485fe13d --- /dev/null +++ b/examples/dynamic_survival_ablation.py @@ -0,0 +1,164 @@ +# Authors: Skyler Lehto (lehto2@illinois.edu), +# Ryan Bradley (ryancb3@illinois.edu), +# Weonah Choi (weonahc2@illinois.edu) +# Paper: Dynamic Survival Analysis for Early Event Prediction (Yèche et al., 2024) +# Link: https://arxiv.org/abs/2403.12818 +# Description: Ablation study for observation window size on synthetic patients. + +""" +Ablation Study: Effect of Observation Window Length + +We vary observation window sizes (12, 24, 48 hours) and +measure performance using masked BCE and MSE. + +This demonstrates how task configuration (NOT model complexity) +impacts predictive performance. 
+""" + +import sys +import os +from datetime import datetime, timedelta + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +import numpy as np +import torch +import torch.nn as nn + +from pyhealth.tasks.dynamic_survival import DynamicSurvivalTask +from examples.synthetic_dataset import generate_synthetic_dataset +from examples.mock_ehr import MockEvent, MockVisit, MockPatient, MockDataset + + +# Convert synthetic dict → MockPatient + +def convert_to_mock_patients(patients_dict): + base_time = datetime(2025, 1, 1) + + mock_patients = [] + + for p in patients_dict: + visits_data = [] + + for v in p["visits"]: + visits_data.append({ + "time": base_time + timedelta(days=v["time"]), + "diagnosis": ["0000"], # dummy code for vocab + }) + + death_time = None + if p.get("outcome_time") is not None: + death_time = base_time + timedelta(days=p["outcome_time"]) + + mock_patients.append( + MockPatient( + pid=p["patient_id"], + visits_data=visits_data, + death_time=death_time, + ) + ) + + return mock_patients + + +# Model + +class SimpleModel(nn.Module): + def __init__(self, input_dim=2, hidden_dim=8, horizon=24): + super().__init__() + self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True) + self.fc = nn.Linear(hidden_dim, horizon) + + def forward(self, x): + _, h = self.rnn(x) + return torch.sigmoid(self.fc(h.squeeze(0))) + + +# Utils + +def prepare_batch(samples): + X, Y, M = [], [], [] + + for s in samples: + X.append(s["x"]) + Y.append(s["y"]) + M.append(s["mask"]) + + if len(X) == 0: + raise ValueError("No valid samples generated.") + + max_len = max(len(x) for x in X) + + X_pad = [] + for x in X: + pad = np.zeros((max_len - len(x), x.shape[1])) + X_pad.append(np.vstack([x, pad])) + + return ( + torch.tensor(np.array(X_pad), dtype=torch.float32), + torch.tensor(np.array(Y), dtype=torch.float32), + torch.tensor(np.array(M), dtype=torch.float32), + ) + + +def train_and_eval(samples): + model = SimpleModel() + optimizer = 
torch.optim.Adam(model.parameters(), lr=0.01) + + X, Y, M = prepare_batch(samples) + + for _ in range(5): + pred = model(X) + loss = -(Y * torch.log(pred + 1e-8) + + (1 - Y) * torch.log(1 - pred + 1e-8)) + loss = (loss * M).sum() / M.sum() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + pred = model(X) + + bce = -(Y * torch.log(pred + 1e-8) + + (1 - Y) * torch.log(1 - pred + 1e-8)) + bce = (bce * M).sum() / M.sum() + + mse = ((pred - Y) ** 2 * M).sum() / M.sum() + + return {"bce": bce.item(), "mse": mse.item()} + + +# Main Experiment + +def main(): + patients_raw = generate_synthetic_dataset(50) + patients = convert_to_mock_patients(patients_raw) + dataset = MockDataset(patients) + + windows = [12, 24, 48] + results = {} + + print("\n=== Ablation Results ===") + + for w in windows: + task = DynamicSurvivalTask( + dataset=dataset, + observation_window=w, + horizon=24, + ) + + samples = dataset.set_task(task) + + if len(samples) == 0: + print(f"Skipping window={w}, no samples") + continue + + score = train_and_eval(samples) + results[w] = score + + print(f"Window={w} | BCE={score['bce']:.4f} | MSE={score['mse']:.4f}") + + +if __name__ == "__main__": + main() diff --git a/examples/mimic_dynamic_survival_gru.py b/examples/mimic_dynamic_survival_gru.py new file mode 100644 index 000000000..ba70a34b7 --- /dev/null +++ b/examples/mimic_dynamic_survival_gru.py @@ -0,0 +1,525 @@ +# Authors: Skyler Lehto (lehto2@illinois.edu), +# Ryan Bradley (ryancb3@illinois.edu), +# Weonah Choi (weonahc2@illinois.edu) +# Paper: Dynamic Survival Analysis for Early Event Prediction (Yèche et al., 2024) +# Link: https://arxiv.org/abs/2403.12818 +# Description: GRU-based ablation study over anchor strategy, window size, and horizon. + +""" +Ablation Study for Dynamic Survival Task + +We evaluate: +- Anchor strategy (fixed vs single) +- Observation window size +- Prediction horizon + +We use a GRU model on synthetic patients. 
+ +Findings: +- Anchor strategy affects performance (fixed generally outperforms single). +- Observation window size shows no consistent monotonic trend. +- Prediction horizon changes task difficulty and performance. + +Results are printed to show how task configurations affect model performance. + +NOTE: +- Evaluation includes BCE, AuPRC, and C-index. +- C-index is computed if scikit-survival is available (optional dependency). +""" + +import random +from datetime import datetime, timedelta + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn + +from sklearn.metrics import average_precision_score + +try: + from sksurv.metrics import concordance_index_censored + HAS_SKSURV = True +except ImportError: + HAS_SKSURV = False + +from pyhealth.tasks.dynamic_survival import DynamicSurvivalTask + +# use import if running on real MIMIC +# from pyhealth.datasets import MIMIC3Dataset + +from examples.synthetic_dataset import generate_synthetic_dataset +from examples.mock_ehr import MockEvent, MockVisit, MockPatient, MockDataset + + +np.random.seed(42) +torch.manual_seed(42) +random.seed(42) + + +# Synthetic Patient Generator + +def generate_synthetic_patients(n=20, seed=42): + """Generate synthetic MockPatient objects for experiments. + + Args: + n: Number of patients to generate. + seed: Random seed for reproducibility. + + Returns: + List of MockPatient objects with randomized visits and death times. 
+ """ + random.seed(seed) + base_time = datetime(2025, 4, 1) + + patients = [] + + for i in range(n): + + num_visits = random.randint(5, 10) + visit_times = sorted(random.sample(range(1, 40), num_visits)) + + visits_data = [ + { + "time": base_time + timedelta(days=t), + "diagnosis": [str(random.randint(1000, 9999))] + } + for t in visit_times + ] + + if random.random() < 0.5: + death_time = base_time + timedelta( + days=max(visit_times) + random.randint(5, 15) + ) + else: + death_time = None + + patients.append( + MockPatient( + pid=f"P{i}", + visits_data=visits_data, + death_time=death_time, + ) + ) + + return patients + + +# Model + +class GRUModel(nn.Module): + def __init__(self, input_dim, hidden_dim=32, horizon=24): + super().__init__() + self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True) + self.fc = nn.Linear(hidden_dim, horizon) + + def forward(self, x): + _, h = self.rnn(x) + return torch.sigmoid(self.fc(h.squeeze(0))) + + +# Batch Function + +def prepare_batch(samples): + X, Y, M = [], [], [] + + for s in samples: + x = np.array(s["x"], dtype=np.float32) + y = np.array(s["y"], dtype=np.float32) + m = np.array(s["mask"], dtype=np.float32) + + X.append(x) + Y.append(y) + M.append(m) + + max_len = max(len(x) for x in X) + + X_pad = [] + for x in X: + pad = np.zeros((max_len - len(x), x.shape[1])) + X_pad.append(np.vstack([x, pad])) + + return ( + torch.tensor(np.array(X_pad), dtype=torch.float32), + torch.tensor(np.array(Y), dtype=torch.float32), + torch.tensor(np.array(M), dtype=torch.float32), + ) + + +# Prior Estimation Utilities for Bias Initialization(DSA Extension) + +def data_prior(samples): + """ + Estimate event probability using Maximum Likelihood Estimation (MLE). + + This computes the empirical event rate over all valid timesteps + (i.e., where mask > 0). + + Args: + samples (list): List of sample dictionaries containing "y" and "mask". + + Returns: + float: Estimated probability of event occurrence. 
+ """ + total = 0 + events = 0 + + for s in samples: + y = np.array(s["y"], dtype=np.float32) + m = np.array(s["mask"], dtype=np.float32) + + valid = m > 0 + events += y[valid].sum() + total += valid.sum() + + return events / total if total > 0 else 0.0 + + +def gaussian_sample_prior(mean=-2.0, std=0.5): + """ + Sample a prior probability from a Gaussian distribution in logit space. + + The sampled value is transformed via the sigmoid function to obtain + a valid probability in (0, 1). + + Args: + mean (float): Mean of the Gaussian (in logit space). + std (float): Standard deviation. + + Returns: + float: Sampled probability. + """ + z = np.random.normal(mean, std) + p = 1 / (1 + np.exp(-z)) # sigmoid transform + return float(p) + + +def bayesian_posterior_prior(samples): + """ + Estimate event probability using a Bayesian posterior with Beta(1,1) prior. + + This corresponds to a Laplace-smoothed estimate of the event rate. + + Args: + samples (list): List of sample dictionaries containing "y" and "mask". + + Returns: + float: Posterior mean probability. + """ + total = 0 + events = 0 + + for s in samples: + y = np.array(s["y"], dtype=np.float32) + m = np.array(s["mask"], dtype=np.float32) + + valid = m > 0 + events += y[valid].sum() + total += valid.sum() + + alpha = events + 1.0 + beta = (total - events) + 1.0 + + return alpha / (alpha + beta) + + +# Training Loop + +def train_model(samples, horizon, prior=None): + """Train a GRU model on survival samples using masked BCE loss. + + Args: + samples: List of sample dicts with keys "x", "y", "mask". + horizon: Prediction horizon (sets the output dimension). + prior: Optional event rate prior for Bayesian bias initialization. + + Returns: + Trained GRUModel instance. 
+ """ + X, Y, M = prepare_batch(samples) + + model = GRUModel(input_dim=X.shape[-1], horizon=horizon) + + # Bayesian initialization + if prior is not None: + p = prior + bias_init = torch.log(torch.tensor(p / (1 - p))) + + with torch.no_grad(): + model.fc.bias.fill_(bias_init) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + for epoch in range(5): + pred = model(X) + + loss = -(Y * torch.log(pred + 1e-8) + + (1 - Y) * torch.log(1 - pred + 1e-8)) + loss = (loss * M).sum() / M.sum() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return model + + +# Evaluation + +def evaluate(model, samples): + """Print masked BCE and MSE for a trained model on the given samples. + + Args: + model: Trained GRUModel. + samples: List of sample dicts with keys "x", "y", "mask". + """ + X, Y, M = prepare_batch(samples) + + with torch.no_grad(): + pred = model(X) + + bce = -(Y * torch.log(pred + 1e-8) + + (1 - Y) * torch.log(1 - pred + 1e-8)) + bce = (bce * M).sum() / M.sum() + + mse = ((pred - Y) ** 2 * M).sum() / M.sum() + + print(f"\nFinal Performance → BCE={bce.item():.4f} | MSE={mse.item():.4f}") + +def evaluate_3metrics(model, samples): + """Compute BCE, AuPRC, and C-index for a trained model. + + Args: + model: Trained GRUModel. + samples: List of sample dicts with keys "x", "y", "mask". + + Returns: + Tuple of (bce, auprc, cindex). auprc and cindex are None if + insufficient events exist to compute them. 
+ """ + X, Y, M = prepare_batch(samples) + + with torch.no_grad(): + pred = model(X) + + # BCE + loss = -(Y * torch.log(pred + 1e-8) + + (1 - Y) * torch.log(1 - pred + 1e-8)) + bce = (loss * M).sum() / M.sum() + + # AuPRC + y_true = Y[M > 0].cpu().numpy().flatten() + y_pred = pred[M > 0].cpu().numpy().flatten() + auprc = None if y_true.sum() == 0 else average_precision_score(y_true, y_pred) + + # C-index + times, risks, events = [], [], [] + + for i in range(len(Y)): + y_i = Y[i].cpu().numpy() + pred_i = pred[i].cpu().numpy() + m_i = M[i].cpu().numpy() + + event_idx = np.where(y_i > 0)[0] + valid_idx = np.where(m_i > 0)[0] + if len(valid_idx) == 0: + # Skip samples with no usable data (fully zeroed mask). + continue + + if len(event_idx) > 0: + # y is one-hot by construction (generate_survival_label sets exactly one + # index), so event_idx always has one element and [0] is the event time. + event_time = event_idx[0] + observed = True + else: + # No event within horizon, use last unmasked step as the censoring time + # (i.e., last time we know the patient was event-free). 
+ event_time = valid_idx[-1] + observed = False + + cumulative_risk = float(1.0 - np.prod(1.0 - pred_i)) + times.append(event_time) + risks.append(cumulative_risk) + events.append(observed) + + events_arr = np.array(events, dtype=bool) + times_arr = np.array(times) + risks_arr = np.array(risks) + + if len(times_arr) < 2 or not events_arr.any(): + cindex = None + try: + from sksurv.metrics import concordance_index_censored + + event_times = times_arr[events_arr] + + if (~events_arr).any(): + other_times = times_arr[~events_arr] + else: + other_times = times_arr[events_arr] + + if not (event_times.min() < other_times.max()): + cindex = None + else: + result = concordance_index_censored(events_arr, times_arr, risks_arr) + cindex = result[0] if isinstance(result, tuple) else result.concordance + + except ImportError: + cindex = None + except Exception: + cindex = None + return bce.item(), auprc, cindex + + +# RUN EXPERIMENT FUNCTION + +def run_experiment(dataset, horizon, window, anchor): + """Run one training/evaluation trial for a given task configuration. + + Args: + dataset: MockDataset instance. + horizon: Prediction horizon in time steps. + window: Observation window size in days. + anchor: Anchor strategy ("fixed" or "single"). + + Returns: + Tuple of (bce, mse), or None if no samples were generated. + """ + task = DynamicSurvivalTask( + dataset, + horizon=horizon, + observation_window=window, + anchor_strategy=anchor + ) + + samples = dataset.set_task(task) + + if len(samples) == 0: + return None + + model = train_model(samples, horizon=horizon) + + X, Y, M = prepare_batch(samples) + + with torch.no_grad(): + pred = model(X) + + bce = -(Y * torch.log(pred + 1e-8) + + (1 - Y) * torch.log(1 - pred + 1e-8)) + bce = (bce * M).sum() / M.sum() + + mse = ((pred - Y) ** 2 * M).sum() / M.sum() + + return bce.item(), mse.item() + + +def main(): + # Use synthetic patients so this script runs without a local MIMIC download. 
+ # To run on real MIMIC-III, replace with: + # dataset = MIMIC3Dataset(root="", tables=[...], dev=True) + patients = generate_synthetic_patients(20) + dataset = MockDataset(patients) + + # 1. Anchor Ablation + anchors = ["fixed", "single"] + bce_list, auprc_list, cindex_list = [], [], [] + + for anchor in anchors: + task = DynamicSurvivalTask( + dataset, horizon=30, observation_window=12, anchor_strategy=anchor + ) + samples = dataset.set_task(task) + if not samples: + print(f"{anchor} → no samples generated, skipping") + continue + model = train_model(samples, 30) + bce, auprc, cidx = evaluate_3metrics(model, samples) + bce_list.append(bce) + auprc_list.append(auprc) + cindex_list.append(cidx) + + df_anchor = pd.DataFrame({ + "Anchor": anchors, "BCE": bce_list, + "AuPRC": auprc_list, "C-index": cindex_list + }) + print("\n=== Results ===") + print("(C-index shown only if scikit-survival is installed)\n") + print("\n=== Anchor Ablation Results ===") + print(df_anchor.round(4)) + + # 2. Window Ablation + windows = [6, 12, 24] + bce_list, auprc_list, cindex_list = [], [], [] + + for w in windows: + task = DynamicSurvivalTask( + dataset, horizon=10, observation_window=w, anchor_strategy="fixed" + ) + samples = dataset.set_task(task) + if not samples: + print(f"window={w} → no samples, skipping") + continue + model = train_model(samples, 10) + bce, auprc, cidx = evaluate_3metrics(model, samples) + bce_list.append(bce) + auprc_list.append(auprc) + cindex_list.append(cidx) + + df_window = pd.DataFrame({ + "Window": windows, "BCE": bce_list, + "AuPRC": auprc_list, "C-index": cindex_list + }) + print("\n=== Window Ablation Results ===") + print(df_window.round(4)) + + # 3. 
Horizon Ablation + horizons = [5, 10, 20] + bce_list, auprc_list, cindex_list = [], [], [] + + for h in horizons: + task = DynamicSurvivalTask( + dataset, horizon=h, observation_window=12, anchor_strategy="fixed" + ) + samples = dataset.set_task(task) + if not samples: + print(f"horizon={h} → no samples, skipping") + continue + model = train_model(samples, h) + bce, auprc, cidx = evaluate_3metrics(model, samples) + bce_list.append(bce) + auprc_list.append(auprc) + cindex_list.append(cidx) + + df_horizon = pd.DataFrame({ + "Horizon": horizons, "BCE": bce_list, + "AuPRC": auprc_list, "C-index": cindex_list + }) + print("\n=== Horizon Ablation Results ===") + print(df_horizon.round(4)) + + # 4. Prior Ablation + task = DynamicSurvivalTask( + dataset, horizon=10, observation_window=12, anchor_strategy="fixed", + ) + samples = dataset.set_task(task) + priors = { + "No Prior(MLE)": None, + "Data": data_prior(samples), + "Bayesian": bayesian_posterior_prior(samples), + } + + results = [] + + for name, prior in priors.items(): + model = train_model(samples, horizon=10, prior=prior) + bce, auprc, cidx = evaluate_3metrics(model, samples) + results.append((name, bce, auprc, cidx)) + + df_prior = pd.DataFrame(results, columns=["Prior", "BCE", "AuPRC", "C-index"]) + + print("\n=== Prior Ablation Results (Extension)===") + print(df_prior.round(4)) + + +if __name__ == "__main__": + main() diff --git a/examples/mock_ehr.py b/examples/mock_ehr.py new file mode 100644 index 000000000..f5175c8f1 --- /dev/null +++ b/examples/mock_ehr.py @@ -0,0 +1,95 @@ +# Authors: Skyler Lehto (lehto2@illinois.edu), +# Ryan Bradley (ryancb3@illinois.edu), +# Weonah Choi (weonahc2@illinois.edu) +# Paper: Dynamic Survival Analysis for Early Event Prediction (Yèche et al., 2024) +# Link: https://arxiv.org/abs/2403.12818 +# Description: Shared mock EHR classes used by dynamic survival example scripts. + +"""Lightweight mock EHR objects for dynamic survival example scripts. 
+ +These classes stand in for real PyHealth patient/visit/event objects so +the example and ablation scripts can run without a local MIMIC download. +Imported by dynamic_survival_ablation.py and mimic_dynamic_survival_gru.py. +""" + + +class MockEvent: + """A single coded EHR event (diagnosis, procedure, or prescription).""" + + def __init__(self, code, timestamp, vocabulary): + """ + Args: + code: Clinical code string (e.g. ICD-9, NDC). + timestamp: Event datetime. + vocabulary: Vocabulary name (e.g. "ICD9CM"). + """ + self.code = code + self.timestamp = timestamp + self.vocabulary = vocabulary + + +class MockVisit: + """A single patient visit containing diagnosis events.""" + + def __init__(self, time, diagnosis=None): + """ + Args: + time: Visit datetime used as encounter_time. + diagnosis: Optional list of ICD-9 diagnosis code strings. + """ + self.encounter_time = time + self.event_list_dict = { + "DIAGNOSES_ICD": [ + MockEvent(c, time, "ICD9CM") for c in (diagnosis or []) + ], + "PROCEDURES_ICD": [], + "PRESCRIPTIONS": [], + } + + +class MockPatient: + """A patient with an ordered dict of MockVisit objects.""" + + def __init__(self, pid, visits_data, death_time=None): + """ + Args: + pid: Unique patient identifier string. + visits_data: List of dicts passed as kwargs to MockVisit. + death_time: Optional datetime of death; None if censored. + """ + self.patient_id = pid + self.visits = { + f"v{i}": MockVisit(**v) for i, v in enumerate(visits_data) + } + self.death_datetime = death_time + + +class MockDataset: + """Minimal dataset wrapper that applies a task to all patients.""" + + def __init__(self, patients): + """ + Args: + patients: List of MockPatient objects. + """ + self.patients = {p.patient_id: p for p in patients} + + def set_task(self, task): + """Apply task to every patient and return the collected samples. + + Called the same way as a real PyHealth dataset so example scripts + are structurally identical to production usage. 
+ + Args: + task: A callable task (e.g. DynamicSurvivalTask) that accepts + a patient object and returns a list of sample dicts. + + Returns: + List of sample dicts from all patients combined. + """ + samples = [] + for p in self.patients.values(): + out = task(p) + if out: + samples.extend(out) + return samples diff --git a/examples/synthetic_dataset.py b/examples/synthetic_dataset.py new file mode 100644 index 000000000..21ade984f --- /dev/null +++ b/examples/synthetic_dataset.py @@ -0,0 +1,61 @@ +# Authors: Skyler Lehto (lehto2@illinois.edu), +# Ryan Bradley (ryancb3@illinois.edu), +# Weonah Choi (weonahc2@illinois.edu) +# Paper: Dynamic Survival Analysis for Early Event Prediction (Yèche et al., 2024) +# Link: https://arxiv.org/abs/2403.12818 +# Description: Synthetic EHR patient generator for testing and ablation experiments. + +""" +Generates synthetic EHR-like patient trajectories for testing. + +Each patient contains: +- visits: list of timestamped events +- outcome_time OR censor_time +""" + +import random + + +def generate_synthetic_dataset(num_patients=50, seed=None): + """ + Generates a stochastic synthetic dataset for experiments/ablations. + + Characteristics: + - Randomized visit times + - Random event vs censoring + - More realistic variability than test dataset + + Args: + num_patients (int): number of patients + seed (int, optional): random seed for reproducibility + + Returns: + List of patient dicts, each with "patient_id", "visits", + "outcome_time", and "censor_time" keys. 
+ """ + if seed is not None: + random.seed(seed) + + patients = [] + + for i in range(num_patients): + num_events = random.randint(5, 15) + times = sorted(random.sample(range(1, 100), num_events)) + + visits = [{"time": t} for t in times] + + if random.random() > 0.5: + outcome_time = times[-1] + random.randint(5, 20) + censor_time = None + else: + outcome_time = None + censor_time = times[-1] + random.randint(5, 20) + + patients.append({ + "patient_id": f"p{i}", + "visits": visits, + "outcome_time": outcome_time, + "censor_time": censor_time, + }) + + return patients diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..2050454d9 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -14,6 +14,7 @@ from .covid19_cxr_classification import COVID19CXRClassification from .deid_ner import DeIDNERTask from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 +from .dynamic_survival import DynamicSurvivalTask from .drug_recommendation import ( DrugRecommendationEICU, DrugRecommendationMIMIC3, diff --git a/pyhealth/tasks/dynamic_survival.py b/pyhealth/tasks/dynamic_survival.py new file mode 100644 index 000000000..cbffbb105 --- /dev/null +++ b/pyhealth/tasks/dynamic_survival.py @@ -0,0 +1,752 @@ +# Authors: Skyler Lehto (lehto2@illinois.edu), +# Ryan Bradley (ryancb3@illinois.edu), +# Weonah Choi (weonahc2@illinois.edu) +# Paper: Dynamic Survival Analysis for Early Event Prediction (Yèche et al., 2024) +# Link: https://arxiv.org/abs/2403.12818 +# Description: Anchor-based discrete-time survival task for longitudinal EHR data. + +""" +Dynamic Survival Task for PyHealth. + +This module implements a dynamic survival prediction task using: +- Anchor-based sampling +- Observation windows +- Discrete-time survival labels + +The task converts longitudinal EHR data into sequence samples +for survival modeling. 
+""" + +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union, Tuple, Type + + +import numpy as np +from pyhealth.medcode import CrossMap +from pyhealth.tasks.base_task import BaseTask + + +DIAG_MAPPER = CrossMap("ICD9CM", "CCSCM") +PROC_MAPPER = CrossMap("ICD9PROC", "CCSPROC") +DRUG_MAPPER = CrossMap("NDC", "ATC") + + +def build_daily_time_series_from_df(patient) -> List[Dict[str, Any]]: + """ + Build daily time series from a dataframe-based patient. + + This function handles the PyHealth main dataset structure where + all events are stored in a single dataframe (patient.data_source). + It extracts relevant medical events and aggregates them into + daily time steps. + + Called by :func:`build_daily_time_series` when the patient object + exposes a ``data_source`` attribute. + + Args: + patient (Any): A PyHealth patient object with a dataframe + stored in `patient.data_source`. + + Returns: + List[Dict[str, Any]]: A list of daily aggregated visits where + each entry contains: + - time (int): day index + - diagnosis (List[str]) + - procedure (List[str]) + - drug (List[str]) + """ + df = patient.data_source + + events = [] + + for row in df.iter_rows(named=True): + timestamp = row.get("timestamp") + event_type = str(row.get("event_type")).lower() + + # Extract code based on event type + if event_type == "diagnoses_icd": + code = row.get("diagnoses_icd/icd9_code") + + elif event_type == "procedures_icd": + code = row.get("procedures_icd/icd9_code") + + elif event_type == "prescriptions": + code = row.get("prescriptions/ndc") + + else: + # Ignore non-medical tables (patients, admissions, icustays) + continue + + # Skip invalid rows + if timestamp is None or code is None: + continue + + events.append((timestamp, code, event_type)) + + if not events: + return [] + + # Sort events by time + events.sort(key=lambda x: x[0]) + first_time = events[0][0] + + # Map time -> codes + time_to_codes = defaultdict( + lambda: {"diagnosis": set(), 
"procedure": set(), "drug": set()} + ) + + for timestamp, code, event_type in events: + delta_day = (timestamp - first_time).days + + if event_type == "diagnoses_icd": + time_to_codes[delta_day]["diagnosis"].add(code) + + elif event_type == "procedures_icd": + time_to_codes[delta_day]["procedure"].add(code) + + elif event_type == "prescriptions": + time_to_codes[delta_day]["drug"].add(code) + + max_day = max(time_to_codes.keys()) + + visits = [] + current_diag, current_proc, current_drug = set(), set(), set() + + # Build cumulative daily visits + for day in range(max_day + 1): + if day in time_to_codes: + current_diag.update(time_to_codes[day]["diagnosis"]) + current_proc.update(time_to_codes[day]["procedure"]) + current_drug.update(time_to_codes[day]["drug"]) + + visits.append( + { + "time": day, + "diagnosis": list(current_diag), + "procedure": list(current_proc), + "drug": list(current_drug), + } + ) + + return visits + + +def build_daily_time_series(patient) -> List[Dict[str, Any]]: + """ + Convert patient events into a daily time series. + + Called by :meth:`DynamicSurvivalTask.__call__` and + :meth:`DynamicSurvivalTask.build_vocab` to normalise any patient + format into a flat list of daily visit dicts before feature encoding. + + Args: + patient: Patient object with visits and event lists. + + Returns: + List of daily aggregated visits. 
+ """ + events = [] + + if hasattr(patient, "data_source"): + return build_daily_time_series_from_df(patient) + + + if hasattr(patient, "get_visits"): + visits = patient.get_visits() + else: + visits = patient.visits.values() + + for visit in visits: + for table in [ + "DIAGNOSES_ICD", + "PROCEDURES_ICD", + "PRESCRIPTIONS", + ]: + for event in visit.event_list_dict.get(table, []): + timestamp = ( + event.timestamp + if event.timestamp is not None + else visit.encounter_time + ) + if timestamp is None: + continue + events.append((timestamp, event.code, event.vocabulary)) + + if not events: + return [] + + events.sort(key=lambda x: x[0]) + first_time = events[0][0] + + time_to_codes = defaultdict( + lambda: {"diagnosis": set(), "procedure": set(), "drug": set()} + ) + + for timestamp, code, vocab in events: + delta_day = (timestamp - first_time).days + + if vocab == "ICD9CM": + time_to_codes[delta_day]["diagnosis"].add(code) + elif vocab == "ICD9PROC": + time_to_codes[delta_day]["procedure"].add(code) + elif vocab == "NDC": + time_to_codes[delta_day]["drug"].add(code) + + max_day = max(time_to_codes.keys()) + + visits = [] + current_diag, current_proc, current_drug = set(), set(), set() + + for day in range(max_day + 1): + if day in time_to_codes: + current_diag.update(time_to_codes[day]["diagnosis"]) + current_proc.update(time_to_codes[day]["procedure"]) + current_drug.update(time_to_codes[day]["drug"]) + + visits.append( + { + "time": day, + "diagnosis": list(current_diag), + "procedure": list(current_proc), + "drug": list(current_drug), + } + ) + + return visits + + +class DynamicSurvivalEngine: + """Core engine for dynamic survival sample generation.""" + + def __init__( + self, + horizon: int = 24, + observation_window: int = 24, + anchor_interval: int = 12, + anchor_strategy: str = "fixed", + ): + """Initialize the engine with survival prediction parameters. + + Args: + horizon: Number of discrete time steps in the prediction window. 
+ observation_window: Look-back window width in days. + anchor_interval: Spacing between anchors under the fixed strategy. + anchor_strategy: "fixed" for evenly spaced anchors or "single" for + one anchor at the earliest valid prediction point. + """ + self.horizon = horizon + self.observation_window = observation_window + self.anchor_interval = anchor_interval + self.anchor_strategy = anchor_strategy + + def generate_anchors( + self, + event_times: List[int], + outcome_time: Optional[int], + censor_time: Optional[int] = None, + ) -> List[int]: + """Generate anchor time points for survival sample generation. + + Anchors define the time points from which prediction windows are + constructed. Under the 'fixed' strategy, anchors are placed at + regular intervals between the first observable time and the event + or censor time. Under the 'single' strategy, only one anchor is + placed at the event or censor time. + + Args: + event_times: Sorted list of visit timestamps in days relative + to the patient's first recorded event. + outcome_time: Day of the observed event, or None if censored. + censor_time: Day of censoring, or None if event was observed. + + Returns: + List of integer anchor times at which samples are generated. + Returns an empty list if no valid anchors can be constructed + (e.g. observation window exceeds available history). + """ + if not event_times: + return [] + + max_time = ( + outcome_time + if outcome_time is not None + else (censor_time if censor_time is not None else max(event_times)) + ) + + start_time = min(event_times) + self.observation_window + + # single: one anchor per patient at the earliest valid prediction point, + # giving the maximum delta. Ablates anchor density vs. fixed while + # testing the hardest (most distant) prediction scenario. 
+        if self.anchor_strategy == "single":
+            if start_time < max_time:
+                return [int(start_time)]
+            return []
+
+        # fixed (default): evenly spaced anchors across the valid range
+        if start_time >= max_time:
+            return []
+        anchors = list(range(int(start_time), int(max_time), self.anchor_interval))
+        return anchors if anchors else []
+
+
+    def generate_survival_label(
+        self,
+        anchor_time: int,
+        event_time: Optional[int],
+        censor_time: Optional[int] = None,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Generate a discrete-time survival label vector and risk mask.
+
+        For each horizon step k, y[k] indicates whether the event occurs
+        exactly k steps after the anchor. The mask encodes which steps
+        contribute to the survival likelihood: steps after the event or
+        after censoring are excluded.
+
+        Convention: censor_time is treated as the last observed event-free
+        step, so mask[delta] = 1 and mask[delta+1:] = 0, mirroring the
+        event case where the event step itself is included.
+
+        Args:
+            anchor_time: The anchor time point in days.
+            event_time: Observed event time in days, or None if censored.
+            censor_time: Censoring time in days, or None if event observed.
+
+        Returns:
+            Tuple of (y, mask), each a float64 array of shape (horizon,);
+            process_patient casts samples to float32 before emitting them.
+            y[k] = 1.0 if the event occurs at step k, else 0.0.
+            mask[k] = 1.0 if step k is in the risk set, else 0.0.
+        """
+        y = np.zeros(self.horizon, dtype=float)
+        mask = np.ones(self.horizon, dtype=float)
+
+        if event_time is not None:
+            delta = int(event_time - anchor_time)
+
+            if delta < 0:
+                mask[:] = 0
+            elif delta < self.horizon:
+                y[delta] = 1
+                mask[delta + 1:] = 0
+
+        elif censor_time is not None:
+            delta = int(censor_time - anchor_time)
+            if delta < self.horizon:
+                mask[max(0, delta + 1):] = 0
+
+        return y, mask
+
+
+    def process_patient(
+        self, patient: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
+        """Convert a single patient dictionary into survival samples.
+ + For each anchor time generated from the patient's visit history, + extracts the observation window, encodes features, and generates + the corresponding survival label and mask. + + Args: + patient: Dictionary with keys: + - patient_id (str): unique patient identifier. + - visits (List[Dict]): list of visit dicts, each with + 'time' (int) and 'feature' (np.ndarray) keys. + - outcome_time (Optional[int]): event time in days. + - censor_time (Optional[int]): censor time in days. + + Returns: + List of sample dictionaries, each containing: + - patient_id (str): patient identifier. + - visit_id (str): unique anchor-based sample identifier. + - x (np.ndarray): feature matrix of shape (T, d). + - y (np.ndarray): hazard label vector of shape (horizon,). + - mask (np.ndarray): risk set mask of shape (horizon,). + Returns an empty list if no valid anchors or sequences exist. + """ + samples = [] + + pid = patient.get("patient_id", "unknown") + visits = patient.get("visits", []) + event_time = patient.get("outcome_time") + censor_time = patient.get("censor_time") + + event_times = [v["time"] for v in visits if "time" in v] + anchors = self.generate_anchors(event_times, event_time, censor_time) + + for anchor in anchors: + obs_start = anchor - self.observation_window + seq = [] + + for visit in visits: + if obs_start <= visit["time"] < anchor: + if "feature" not in visit: + continue + + feat = np.concatenate( + [ + np.array( + [(visit["time"] - obs_start) / self.observation_window], + dtype=np.float32, + ), + np.array(visit["feature"], dtype=np.float32), + ] + ) + seq.append(feat) + + if not seq: + continue + + x = np.array(seq, dtype=np.float32) + y, mask = self.generate_survival_label( + anchor, event_time, censor_time + ) + + samples.append( + { + "patient_id": pid, + "visit_id": f"{pid}_{anchor}", + "x": x, + "y": y.astype(np.float32), + "mask": mask.astype(np.float32), + } + ) + + return samples + + +class DynamicSurvivalTask(BaseTask): + """PyHealth-compatible 
dynamic survival task for early event prediction. + + Implements the anchor-based discrete-time survival formulation from: + Yèche et al. (2024), *Dynamic Survival Analysis for Early Event Prediction*. + arXiv:2403.12818. + + Each patient is converted into one or more survival samples, one per + anchor time point. At each anchor, the model predicts a discrete-time + hazard sequence over a fixed prediction horizon. + + Attributes: + task_name (str): Identifier for this task, used by PyHealth internals. + input_schema (Dict[str, str]): Maps 'x' to 'tensor' processor. + output_schema (Dict[str, str]): Maps 'y' and 'mask' to 'tensor' processors. + use_diag (bool): Whether to include diagnosis codes in features. + use_proc (bool): Whether to include procedure codes in features. + use_drug (bool): Whether to include drug codes in features. + engine (DynamicSurvivalEngine): Core engine handling anchor generation, + label construction, and sample assembly. + diag_vocab (Dict[str, int]): Diagnosis code to index mapping. + proc_vocab (Dict[str, int]): Procedure code to index mapping. + drug_vocab (Dict[str, int]): Drug code to index mapping. + + Example: + >>> from pyhealth.tasks.dynamic_survival import DynamicSurvivalTask + >>> dataset = MockDataset() + >>> task = DynamicSurvivalTask( + ... dataset=dataset, + ... horizon=24, + ... observation_window=24, + ... anchor_strategy="fixed", + ... ) + >>> samples = task(patient) + """ + + task_name: str = "dynamic_survival" + + input_schema = { + "x": "tensor", + } + + output_schema = { + "y": "tensor", + "mask": "tensor", + } + + + def __init__( + self, + dataset, + horizon: int = 24, + observation_window: int = 24, + anchor_interval: int = 12, + anchor_strategy: str = "fixed", + use_diag: bool = True, + use_proc: bool = True, + use_drug: bool = True, + ): + """Initialize the task and build code vocabularies from the dataset. + + Args: + dataset: PyHealth dataset or MockDataset used to build vocabularies. 
+ horizon: Prediction horizon in discrete time steps. + observation_window: Look-back window width in days. + anchor_interval: Anchor spacing in days (fixed strategy only). + anchor_strategy: "fixed" or "single" anchor placement strategy. + use_diag: Include diagnosis codes in features. + use_proc: Include procedure codes in features. + use_drug: Include drug codes in features. + """ + super().__init__() + + self.use_diag = use_diag + self.use_proc = use_proc + self.use_drug = use_drug + + self.engine = DynamicSurvivalEngine( + horizon, observation_window, anchor_interval, anchor_strategy + ) + + self.diag_mapper = DIAG_MAPPER + self.proc_mapper = PROC_MAPPER + self.drug_mapper = DRUG_MAPPER + + self.diag_vocab, self.proc_vocab, self.drug_vocab = ( + self.build_vocab(dataset) + ) + + def build_vocab(self, dataset) -> Tuple[Dict, Dict, Dict]: + """Build code vocabularies from a dataset for feature encoding. + + Iterates over patients in the dataset to collect all unique + diagnosis, procedure, and drug codes, then constructs index + mappings used by encode_multi_hot(). + + Called by :meth:`__init__` immediately after the engine is + constructed, so vocabularies are ready before any patient is + processed. + + Args: + dataset: Dataset object supporting either iter_patients() + (PyHealth datasets) or a patients dict attribute + (MockDataset). + + Returns: + Tuple of (diag_vocab, proc_vocab, drug_vocab), each a dict + mapping code strings to integer indices. 
+ """ + diag_set, proc_set, drug_set = set(), set(), set() + + if hasattr(dataset, "iter_patients"): + patient_iter = dataset.iter_patients() + else: + patient_iter = dataset.patients.values() + + for patient in patient_iter: + visits = build_daily_time_series(patient) + + for visit in visits: + if self.use_diag: + diag_set.update(visit["diagnosis"]) + if self.use_proc: + proc_set.update(visit["procedure"]) + if self.use_drug: + drug_set.update(visit["drug"]) + + self.diag_vocab = {c: i for i, c in enumerate(sorted(diag_set))} + self.proc_vocab = {c: i for i, c in enumerate(sorted(proc_set))} + self.drug_vocab = {c: i for i, c in enumerate(sorted(drug_set))} + + return self.diag_vocab, self.proc_vocab, self.drug_vocab + + def encode_multi_hot(self, codes: List[str], vocab: Dict[str, int]) -> np.ndarray: + """Encode a list of codes as a multi-hot vector using a vocabulary. + + Called by :meth:`__call__` for each visit to convert diagnosis, + procedure, and drug code lists into fixed-length binary feature + vectors before concatenation into the sample's feature matrix. + + Args: + codes: List of code strings to encode (e.g. ICD codes, NDC codes). + vocab: Dictionary mapping code strings to integer indices. + + Returns: + Binary np.ndarray of shape (len(vocab),) where index i is 1.0 + if the corresponding code is present in codes, else 0.0. + Returns a zero vector if vocab is empty or no codes match. + """ + vec = np.zeros(len(vocab)) + for code in codes: + if code in vocab: + vec[vocab[code]] = 1 + return vec + + def __call__(self, patient) -> List[Dict[str, Any]]: + """ + Convert a patient into dynamic survival samples. + + This is the primary entry point called by ``dataset.set_task(task)`` + for every patient in the dataset. It dispatches to the appropriate + processing path based on the patient's type, then delegates to + :meth:`DynamicSurvivalEngine.process_patient`. + + This function supports three types of patient inputs: + 1. 
Mock patients with visit dictionaries + 2. Dict-style patients (used in tests) + 3. PyHealth dataframe-based patients + + For dataframe-based patients, mortality is extracted using + the 'expire_flag' field from patient-level events. + + Args: + patient (Any): Patient object. + + Returns: + List[Dict[str, Any]]: Survival samples. + """ + + if hasattr(patient, "visits") and isinstance(patient.visits, dict): + visits_list = list(patient.visits.values()) + + if len(visits_list) == 0: + return [] + + processed_visits = [] + start_time = visits_list[0].encounter_time + + for visit in visits_list: + features = [] + + if self.use_diag: + codes = [ + e.code + for e in visit.event_list_dict.get("DIAGNOSES_ICD", []) + ] + features.append( + self.encode_multi_hot(codes, self.diag_vocab) + ) + + if self.use_proc: + codes = [ + e.code + for e in visit.event_list_dict.get("PROCEDURES_ICD", []) + ] + features.append( + self.encode_multi_hot(codes, self.proc_vocab) + ) + + if self.use_drug: + codes = [ + e.code + for e in visit.event_list_dict.get("PRESCRIPTIONS", []) + ] + features.append( + self.encode_multi_hot(codes, self.drug_vocab) + ) + + x = ( + np.concatenate(features).astype(np.float32) + if features + else np.zeros(1, dtype=np.float32) + ) + + time_idx = (visit.encounter_time - start_time).days + + processed_visits.append( + { + "time": time_idx, + "feature": x, + } + ) + + death_time = getattr(patient, "death_datetime", None) + + if death_time: + outcome_time = (death_time - start_time).days + censor_time = None + else: + outcome_time = None + censor_time = processed_visits[-1]["time"] + + patient_dict = { + "patient_id": patient.patient_id, + "visits": processed_visits, + "outcome_time": outcome_time, + "censor_time": censor_time, + } + + return self.engine.process_patient(patient_dict) + + # Dict-style patient (tests) + if isinstance(patient, dict): + return self.engine.process_patient(patient) + + # PyHealth dataframe patient + visits_raw = 
build_daily_time_series(patient) + if not visits_raw: + return [] + + processed_visits = [] + + for visit in visits_raw: + features = [] + + if self.use_diag: + mapped = [ + m + for c in visit["diagnosis"] + for m in self.diag_mapper.map(c) + if m + ] + features.append( + self.encode_multi_hot(mapped, self.diag_vocab) + ) + + if self.use_proc: + mapped = [ + m + for c in visit["procedure"] + for m in self.proc_mapper.map(c) + if m + ] + features.append( + self.encode_multi_hot(mapped, self.proc_vocab) + ) + + if self.use_drug: + mapped = [ + m + for c in visit["drug"] + for m in self.drug_mapper.map(c) + if m + ] + features.append( + self.encode_multi_hot(mapped, self.drug_vocab) + ) + + x = ( + np.concatenate(features).astype(np.float32) + if features + else np.zeros(1, dtype=np.float32) + ) + + processed_visits.append( + { + "time": visit["time"], + "feature": x, + } + ) + + # Extract mortality signal + death_flag = False + + if hasattr(patient, "get_events"): + for event in patient.get_events(): + if event.event_type == "patients": + if event.attr_dict.get("expire_flag") == "1": + death_flag = True + break + + if death_flag: + outcome_time = processed_visits[-1]["time"] + censor_time = None + else: + outcome_time = None + censor_time = processed_visits[-1]["time"] + + patient_dict = { + "patient_id": getattr(patient, "patient_id", "unknown"), + "visits": processed_visits, + "outcome_time": outcome_time, + "censor_time": censor_time, + } + + return self.engine.process_patient(patient_dict) diff --git a/pyproject.toml b/pyproject.toml index 98f88d47b..d7ef74f8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "rdkit", "ogb>=1.3.5", "scikit-learn~=1.7.0", + "scikit-survival>=0.23.0", "networkx", "mne~=1.10.0", "urllib3~=2.5.0", diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..63e0b8e4f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +# IMPORTANT: +# Install PyTorch via conda BEFORE 
installing these packages: +# conda install pytorch cpuonly torchvision -c pytorch +# (Do NOT install torch via pip, it may cause DLL or compatibility issues) + +pyhealth==1.1.6 +numpy==1.26.4 +pandas==1.5.3 +scipy==1.10.1 +polars==1.39.3 +mne +scikit-learn +matplotlib diff --git a/tests/core/test_dynamic_survival.py b/tests/core/test_dynamic_survival.py new file mode 100644 index 000000000..ecf79afef --- /dev/null +++ b/tests/core/test_dynamic_survival.py @@ -0,0 +1,600 @@ +# Authors: Skyler Lehto (lehto2@illinois.edu), +# Ryan Bradley (ryancb3@illinois.edu), +# Weonah Choi (weonahc2@illinois.edu) +# Paper: Dynamic Survival Analysis for Early Event Prediction (Yèche et al., 2024) +# Link: https://arxiv.org/abs/2403.12818 +# Description: Unit tests for DynamicSurvivalTask using synthetic data. + +""" +Unit tests for DynamicSurvivalTask. + +This module verifies: +- Sample generation through dataset.set_task() +- Correct label behavior (event and censor) +- Feature extraction pipeline +- Edge cases (empty patient) +- Core engine functionality (anchors and labels) + +All tests use synthetic data and run quickly. 
+""" + +import json +import shutil +import tempfile +import unittest +from datetime import datetime, timedelta + +import numpy as np + +from pyhealth.tasks.dynamic_survival import DynamicSurvivalTask + + +# Mock Classes + +class MockEvent: + """Simple mock event object.""" + + def __init__(self, code, timestamp, vocabulary): + self.code = code + self.timestamp = timestamp + self.vocabulary = vocabulary + + +class MockVisit: + """Mock visit containing EHR events.""" + + def __init__(self, time, diagnosis=None, procedure=None, drug=None): + self.encounter_time = time + self.event_list_dict = { + "DIAGNOSES_ICD": [ + MockEvent(c, time, "ICD9CM") for c in (diagnosis or []) + ], + "PROCEDURES_ICD": [ + MockEvent(c, time, "ICD9PROC") for c in (procedure or []) + ], + "PRESCRIPTIONS": [ + MockEvent(c, time, "NDC") for c in (drug or []) + ], + } + + +class MockPatient: + """Mock patient object.""" + + def __init__(self, pid, visits_data, death_time=None): + self.patient_id = pid + self.visits = { + f"v{i}": MockVisit(**v) for i, v in enumerate(visits_data) + } + self.death_datetime = death_time + + +class MockDataset: + """Minimal dataset wrapper for testing.""" + + def __init__(self, patients=None): + patients = patients or [] + self.patients = {p.patient_id: p for p in patients} + + def set_task(self, task): + """Apply task to all patients.""" + samples = [] + for patient in self.patients.values(): + out = task(patient) + if out: + samples.extend(out) + return samples + + +# Helper + +def create_patients(n=10): + """ + Creates a small deterministic synthetic dataset for unit tests. + + Characteristics: + - Fixed visit times + - Predictable event/censor patterns + - Designed for fast, reproducible testing + + NOTE: This is intentionally simple and NOT meant for modeling experiments. 
+ """ + patients = [] + for i in range(n): + visits = [{"time": t, "feature": np.zeros(1)} for t in range(5, 50, 5)] + + patients.append({ + "patient_id": f"p{i}", + "visits": visits, + "outcome_time": 60 if i % 2 == 0 else None, + "censor_time": 55 if i % 2 == 1 else None, + }) + return patients + + +# Test Suite + +class TestDynamicSurvivalTask(unittest.TestCase): + + def test_dynamic_survival_event(self): + """Test sample generation when event occurs.""" + base_time = datetime(2025, 4, 1) + + patient = MockPatient( + pid="P1", + death_time=base_time + timedelta(days=2), + visits_data=[ + {"time": base_time, "diagnosis": ["4019"]}, + {"time": base_time + timedelta(days=1), + "diagnosis": ["4101"]}, + ], + ) + + dataset = MockDataset([patient]) + + task = DynamicSurvivalTask( + dataset, + horizon=5, + observation_window=1, + anchor_interval=1, + ) + + samples = dataset.set_task(task) + + self.assertGreater(len(samples), 0) + + s = samples[0] + + self.assertEqual(s["x"].ndim, 2) + self.assertEqual(s["y"].shape, (5,)) + self.assertEqual(s["mask"].shape, (5,)) + + self.assertEqual(s["x"].dtype, np.float32) + self.assertEqual(s["y"].dtype, np.float32) + self.assertEqual(s["mask"].dtype, np.float32) + + # Event at delta=1 + self.assertEqual(s["y"][1], 1.0) + self.assertEqual(s["mask"][2], 0.0) + + # At most one event (DSA constraint) + self.assertLessEqual(float(np.sum(s["y"])), 1) + # Mask must be binary + self.assertTrue(np.all((s["mask"] == 0) | (s["mask"] == 1))) + + def test_dynamic_survival_censor(self): + """Test behavior for censored patient.""" + base_time = datetime(2025, 4, 1) + + patient = MockPatient( + pid="P2", + death_time=None, + visits_data=[ + {"time": base_time, "diagnosis": ["25000"]}, + {"time": base_time + timedelta(days=1), + "diagnosis": ["4019"]}, + ], + ) + + dataset = MockDataset([patient]) + + task = DynamicSurvivalTask( + dataset, + horizon=5, + observation_window=1, + anchor_interval=1, + ) + + samples = dataset.set_task(task) + + 
self.assertIsInstance(samples, list) + + # Censor case may produce zero samples (valid) + if len(samples) == 0: + return + + s = samples[0] + + # Mask should contain zeros due to censoring + self.assertTrue(np.any(s["mask"] == 0)) + + def test_empty_patient(self): + """Test patient with no visits.""" + patient = MockPatient( + pid="P3", + death_time=None, + visits_data=[], + ) + + dataset = MockDataset([patient]) + + task = DynamicSurvivalTask( + dataset, + horizon=5, + observation_window=1, + anchor_interval=1, + ) + + samples = dataset.set_task(task) + + self.assertEqual(len(samples), 0) + + def test_generate_survival_label_basic(self): + """Test correctness of survival label generation.""" + task = DynamicSurvivalTask(MockDataset(), horizon=5) + + y, mask = task.engine.generate_survival_label( + anchor_time=10, + event_time=12, + ) + + self.assertEqual(y[2], 1) + self.assertEqual(mask[3], 0) + + def test_generate_anchors_basic(self): + """Test anchor generation logic.""" + task = DynamicSurvivalTask(MockDataset(), observation_window=1) + + anchors = task.engine.generate_anchors( + event_times=[0, 1], + outcome_time=3, + ) + + self.assertGreater(len(anchors), 0) + + def test_end_to_end_pipeline_object_patients(self): + """Test full pipeline with multiple PyHealth-style patients.""" + base_time = datetime(2025, 4, 1) + + patients = [ + MockPatient( + pid="P1", + death_time=base_time + timedelta(days=2), + visits_data=[ + {"time": base_time, "diagnosis": ["4019"]}, + {"time": base_time + timedelta(days=1), + "diagnosis": ["4101"]}, + ], + ), + MockPatient( + pid="P2", + death_time=None, + visits_data=[ + {"time": base_time, "diagnosis": ["25000"]}, + {"time": base_time + timedelta(days=1), + "diagnosis": ["4019"]}, + ], + ), + ] + + dataset = MockDataset(patients) + + task = DynamicSurvivalTask( + dataset, + horizon=5, + observation_window=1, + anchor_interval=1, + ) + + samples = dataset.set_task(task) + + self.assertGreater(len(samples), 0) + + for s in samples: 
+ self.assertEqual(s["x"].ndim, 2) + self.assertEqual(s["y"].shape[0], 5) + self.assertEqual(s["mask"].shape[0], 5) + self.assertLessEqual(float(np.sum(s["y"])), 1) + self.assertTrue(np.all((s["mask"] == 0) | (s["mask"] == 1))) + + def test_multiple_patients_processing(self): + """Test engine processes a batch of dict-based patients without errors.""" + task = DynamicSurvivalTask(MockDataset()) + patients = create_patients(3) + + all_samples = [] + for p in patients: + all_samples.extend(task.engine.process_patient(p)) + + self.assertGreater(len(all_samples), 0) + + def test_censoring_mask_fixed(self): + """Test censoring mask is correctly truncated after the censor step.""" + task = DynamicSurvivalTask(MockDataset(), horizon=5) + y, mask = task.engine.generate_survival_label( + anchor_time=10, + event_time=None, + censor_time=12, # delta = 2 + ) + + # Convention: censor_time is the last observed event-free step. + # Steps 0..delta are included in the risk set (mask=1). + # Steps delta+1.. are excluded (mask=0). + # This mirrors the event case where mask[delta+1:] = 0. + self.assertTrue(np.all(mask[:3] == 1)) # steps 0,1,2 included + self.assertTrue(np.all(mask[3:] == 0)) # steps 3,4 excluded + self.assertEqual(float(np.sum(y)), 0) # no event for censored patient + + def test_censoring_mask_single(self): + """Test censoring mask behavior is independent of anchor_strategy.""" + task = DynamicSurvivalTask(MockDataset(), horizon=5, anchor_strategy="single") + y, mask = task.engine.generate_survival_label( + anchor_time=10, + event_time=None, + censor_time=12, # delta = 2, same convention as fixed + ) + + # anchor_strategy does not affect label generation; + # only anchor placement does. + # With delta=2: steps 0,1,2 included; steps 3,4 excluded. 
+ self.assertTrue(np.all(mask[:3] == 1)) + self.assertTrue(np.all(mask[3:] == 0)) + self.assertEqual(float(np.sum(y)), 0) + + def test_single_anchor_strategy(self): + """Single anchor strategy produces exactly one anchor.""" + task = DynamicSurvivalTask( + MockDataset(), anchor_strategy="single", observation_window=1 + ) + + anchors = task.engine.generate_anchors([5, 10], outcome_time=20) + + self.assertEqual(len(anchors), 1) + + def test_empty_events(self): + """Test that a patient with no visits produces no samples.""" + task = DynamicSurvivalTask(MockDataset()) + patient = { + "patient_id": "p", + "visits": [], + } + + samples = task.engine.process_patient(patient) + + self.assertEqual(samples, []) + + def test_output_format(self): + """Test that output samples contain x, y, and mask as numpy arrays.""" + task = DynamicSurvivalTask(MockDataset(), observation_window=5, horizon=5) + + patient = { + "patient_id": "p1", + "visits": [{"time": t, "feature": np.zeros(1)} for t in [5, 10, 15]], + "outcome_time": 20, + } + + samples = task.engine.process_patient(patient) + + self.assertGreater(len(samples), 0, "No samples were generated for the patient") + s = samples[0] + self.assertIn("x", s) + self.assertIn("y", s) + self.assertIn("mask", s) + self.assertIsInstance(s["x"], np.ndarray) + self.assertIsInstance(s["y"], np.ndarray) + self.assertIsInstance(s["mask"], np.ndarray) + + def test_event_before_anchor(self): + """Test that an event occurring before the anchor zeroes the entire mask.""" + task = DynamicSurvivalTask(MockDataset(), horizon=5) + + y, mask = task.engine.generate_survival_label( + anchor_time=10, + event_time=8, + ) + + self.assertTrue(np.all(mask == 0)) + + def test_event_within_horizon(self): + """Test label and mask values when the event falls inside the horizon.""" + task = DynamicSurvivalTask(MockDataset(), horizon=5) + + y, mask = task.engine.generate_survival_label( + anchor_time=10, + event_time=12, + ) + + # delta = 2 + self.assertEqual(y[2], 
1) + self.assertEqual(float(np.sum(y)), 1) + self.assertTrue(np.all(mask[:3] == 1)) + self.assertTrue(np.all(mask[3:] == 0)) + + def test_event_outside_horizon(self): + """Test event beyond horizon produces all-zero y and all-one mask.""" + task = DynamicSurvivalTask(MockDataset(), horizon=5) + + y, mask = task.engine.generate_survival_label( + anchor_time=10, + event_time=20, + ) + + self.assertEqual(float(np.sum(y)), 0) + self.assertTrue(np.all(mask == 1)) + + def test_no_valid_anchors(self): + """Test observation window larger than patient history yields no samples.""" + task = DynamicSurvivalTask(MockDataset(), observation_window=100) + + patient = { + "patient_id": "p1", + "visits": [ + {"time": 1, "feature": np.zeros(1)}, + {"time": 2, "feature": np.zeros(1)}, + ], + "outcome_time": 3, + } + + samples = task.engine.process_patient(patient) + + self.assertEqual(samples, []) + + def test_label_shape_consistency(self): + """Test that y and mask shapes match the configured horizon.""" + task = DynamicSurvivalTask(MockDataset(), horizon=7) + + y, mask = task.engine.generate_survival_label( + anchor_time=10, + event_time=15, + ) + + self.assertEqual(y.shape, (7,)) + self.assertEqual(mask.shape, (7,)) + + def test_full_pipeline_shapes(self): + """Test output array shapes across all samples from a multi-visit patient.""" + task = DynamicSurvivalTask(MockDataset(), horizon=6) + + patient = { + "patient_id": "p1", + "visits": [{"time": t, "feature": np.zeros(1)} for t in range(5, 50, 5)], + "outcome_time": 60, + } + + samples = task.engine.process_patient(patient) + + for s in samples: + self.assertEqual(s["y"].shape[0], 6) + self.assertEqual(s["mask"].shape[0], 6) + self.assertEqual(s["x"].ndim, 2) + + def test_anchor_with_no_observation_window(self): + """Test patient with visits only before the window still returns a list.""" + task = DynamicSurvivalTask(MockDataset(), observation_window=10) + + patient = { + "patient_id": "p1", + "visits": [{"time": 5, "feature": 
np.zeros(1)}], # before window + "outcome_time": 20, + } + + samples = task.engine.process_patient(patient) + + self.assertIsInstance(samples, list) + + def test_anchor_respects_censor_time(self): + """Test that no anchor is placed at or after the censor time.""" + task = DynamicSurvivalTask(MockDataset(), anchor_interval=5) + + anchors = task.engine.generate_anchors( + event_times=[5, 10, 15], + outcome_time=None, + censor_time=20, + ) + + self.assertTrue(all(a < 20 for a in anchors)) + + def test_end_to_end_pipeline_dict_patients(self): + """Test full pipeline with synthetic dict-based patients.""" + task = DynamicSurvivalTask(MockDataset()) + + patients = create_patients(5) + samples = [] + + for p in patients: + samples.extend(task.engine.process_patient(p)) + + self.assertGreater(len(samples), 0) + + for s in samples: + self.assertGreater(s["x"].shape[0], 0) + self.assertLessEqual(float(s["y"].sum()), 1) + self.assertTrue(np.all((s["mask"] == 0) | (s["mask"] == 1))) + + def test_uses_temporary_directory(self): + """Verify task output can be written to and cleaned up from a temp directory.""" + task = DynamicSurvivalTask(MockDataset(), horizon=5, observation_window=5) + + patient = { + "patient_id": "p_tmp", + "visits": [{"time": t, "feature": np.zeros(1)} for t in range(5, 30, 5)], + "outcome_time": 35, + } + + samples = task.engine.process_patient(patient) + self.assertGreater(len(samples), 0) + + tmp_dir = tempfile.mkdtemp() + try: + out_path = tmp_dir + "/samples.json" + with open(out_path, "w") as f: + json.dump([s["visit_id"] for s in samples], f) + + with open(out_path) as f: + contents = json.load(f) + + self.assertEqual(len(contents), len(samples)) + finally: + shutil.rmtree(tmp_dir) + + def test_mock_patient_cohort(self): + """Test pipeline with 4 MockPatients covering mixed event/censor cases.""" + base_time = datetime(2025, 1, 1) + + patients = [] + for i in range(4): + visits_data = [ + {"time": base_time + timedelta(days=d), "diagnosis": [str(1000 
+ i)]} + for d in [0, 5, 10] + ] + # Alternate event / censored patients + death_time = base_time + timedelta(days=15) if i % 2 == 0 else None + patients.append(MockPatient( + pid=f"MP{i}", + visits_data=visits_data, + death_time=death_time, + )) + + dataset = MockDataset(patients) + task = DynamicSurvivalTask( + dataset, horizon=10, observation_window=5, anchor_interval=3 + ) + samples = dataset.set_task(task) + + self.assertGreater(len(samples), 0) + + for s in samples: + self.assertEqual(s["x"].ndim, 2) + self.assertEqual(s["y"].shape, (10,)) + self.assertEqual(s["mask"].shape, (10,)) + self.assertLessEqual(float(np.sum(s["y"])), 1) + self.assertTrue(np.all((s["mask"] == 0) | (s["mask"] == 1))) + + for s in samples: + self.assertGreater(s["x"].shape[0], 0) + self.assertLessEqual(float(s["y"].sum()), 1) + self.assertTrue(np.all((s["mask"] == 0) | (s["mask"] == 1))) + + def test_feature_flags_use_proc_false(self): + """Test that disabling procedure codes still produces valid samples.""" + base_time = datetime(2025, 4, 1) + + patient = MockPatient( + pid="P_flags", + death_time=base_time + timedelta(days=3), + visits_data=[ + {"time": base_time, "diagnosis": ["4019"], "procedure": ["0011"]}, + {"time": base_time + timedelta(days=1), "diagnosis": ["4101"]}, + ], + ) + + dataset = MockDataset([patient]) + + task = DynamicSurvivalTask( + dataset, + horizon=5, + observation_window=1, + anchor_interval=1, + use_proc=False, + ) + + samples = dataset.set_task(task) + + self.assertIsInstance(samples, list) + if len(samples) > 0: + self.assertEqual(samples[0]["x"].ndim, 2) + self.assertEqual(samples[0]["y"].shape, (5,)) + + +if __name__ == "__main__": + unittest.main()