47 commits
51945dc
initial dynamic survival task impl
rbradley813 Apr 12, 2026
4370e93
add dynamic survival task with tests and ablation
WeonahChoi Apr 14, 2026
07ecf21
add requirements file
WeonahChoi Apr 14, 2026
d769c15
refine dynamic survival tests and add validation checks
WeonahChoi Apr 14, 2026
126dc0f
use mockdatasets and tasks.engine API in tests
rbradley813 Apr 15, 2026
81ab2cb
merge duplicate test files
rbradley813 Apr 15, 2026
4357916
Add requirements for reproducible environment
WeonahChoi Apr 15, 2026
29cbf58
Update DynamicSurvivalTask for PyHealth main compatibility- Support M…
WeonahChoi Apr 15, 2026
049240c
Add prior ablation experiment to dynamic survival example
WeonahChoi Apr 16, 2026
729441d
Update DynamicSurvivalTask and fix .gitignore; all tests passing and …
WeonahChoi Apr 16, 2026
c9845ab
Fix DSA event handling and schema- Correct event_time extraction from…
WeonahChoi Apr 16, 2026
7d663a4
initial team_ablation_results script
rbradley813 Apr 17, 2026
a921877
implement centralized per-horizon bias initialization
rbradley813 Apr 17, 2026
4d33abe
update c-index calculation to use censored calculation in scikit-surv…
rbradley813 Apr 17, 2026
aec7e08
Clean up ablation study by removing prior-based experiments and focus…
WeonahChoi Apr 17, 2026
6a8af5b
fix off-by-one issue in anchor masking
rbradley813 Apr 18, 2026
33456a0
Put ablation tests under main() function and use synthetic patients
skylerl2 Apr 18, 2026
0d996ea
Fixed censoring mask error
skylerl2 Apr 18, 2026
0d15602
Updated usage example
skylerl2 Apr 18, 2026
8f90063
Updated functions with docstrings and return types
skylerl2 Apr 18, 2026
0f0480a
Updated path and imported mock classes to fix missing dataset in abla…
skylerl2 Apr 19, 2026
54d104b
Fixed synthetic_dataset import
skylerl2 Apr 19, 2026
84af2c7
Updated docstrings for class and function
skylerl2 Apr 19, 2026
3e46814
Updated docstrings
skylerl2 Apr 19, 2026
8827866
Commented out MIMIC import since synthetic data used
skylerl2 Apr 19, 2026
c40f142
add censoring task for both anchor types
rbradley813 Apr 19, 2026
59206ea
update c-index calc
rbradley813 Apr 19, 2026
83e9476
single anchor strategy uses earliest prediction point
rbradley813 Apr 19, 2026
974421c
Fixed censoring mask single and single anchor strategy, and added thr…
skylerl2 Apr 19, 2026
5a18dca
cleanup
rbradley813 Apr 20, 2026
f4f3932
add required code file headers
rbradley813 Apr 20, 2026
65ea2e7
add / update docstrings for public methods
rbradley813 Apr 20, 2026
9172b1c
replace pytest with unittest and cleanup
rbradley813 Apr 20, 2026
0d75a69
replace bare asserts with unittest asserts
rbradley813 Apr 20, 2026
a7b0218
implement module import safeguard
rbradley813 Apr 20, 2026
7eaf297
PEP8 cleanup
rbradley813 Apr 20, 2026
5eec7c8
cleanup
rbradley813 Apr 20, 2026
cfdb3a7
remove personal entry from .gitignore
rbradley813 Apr 20, 2026
58132d6
docstring and comment cleanup
rbradley813 Apr 20, 2026
c1cbf67
cap patient data when testing
rbradley813 Apr 20, 2026
b2a2bae
move mock data generation into separate class
rbradley813 Apr 20, 2026
734e9be
cleanup: update return type
rbradley813 Apr 20, 2026
22e0fbd
cleanup: .gitignore
rbradley813 Apr 20, 2026
9d9e2e0
remove dev flag: vocab cap
rbradley813 Apr 20, 2026
c2e2c25
cleanup: dynamic survival task
rbradley813 Apr 20, 2026
2403bd6
Fix type consistency, clean example, add optional C-index, and update…
WeonahChoi Apr 20, 2026
12e6013
fix variability in results: force vocab order so hashing is consistent
rbradley813 Apr 20, 2026
1 change: 1 addition & 0 deletions docs/api/tasks.rst
@@ -230,3 +230,4 @@ Available Tasks
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Cancer Mutation Burden (TCGA) <tasks/pyhealth.tasks.CancerMutationBurden>
Dynamic Survival Analysis <tasks/pyhealth.tasks.dynamic_survival>
68 changes: 68 additions & 0 deletions docs/api/tasks/pyhealth.tasks.dynamic_survival.rst
@@ -0,0 +1,68 @@
DynamicSurvivalTask
===================

This module implements a dynamic survival analysis task for early event prediction.

The task follows the anchor-based discrete-time survival formulation proposed in:

Yèche et al. (2024), *Dynamic Survival Analysis for Early Event Prediction*.

Key Features
------------
- Multiple anchors per patient
- Discrete-time hazard prediction
- Support for censoring
- Configurable observation windows and anchor strategies
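
A discrete-time hazard vector can be turned into a survival curve with the standard identity S(k) = (1 - h_1)(1 - h_2)...(1 - h_k). The sketch below illustrates that identity only; the helper names are hypothetical and are not part of ``DynamicSurvivalTask``.

```python
from typing import List


def survival_from_hazards(hazards: List[float]) -> List[float]:
    """Return S(k), the probability of no event through step k, for each k."""
    survival = []
    running = 1.0
    for h in hazards:
        # Surviving step k requires surviving every earlier step as well.
        running *= 1.0 - h
        survival.append(running)
    return survival


def cumulative_event_probability(hazards: List[float]) -> List[float]:
    """Return F(k) = 1 - S(k), the probability of an event by step k."""
    return [1.0 - s for s in survival_from_hazards(hazards)]


# Hypothetical per-step hazards for a horizon of three steps.
hazards = [0.1, 0.2, 0.5]
survival = survival_from_hazards(hazards)
event_by = cumulative_event_probability(hazards)
```

The cumulative event probability is often the quantity thresholded when raising an early-warning alarm, which is why both forms are shown.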

Output Format
-------------
Each processed sample contains:

- **patient_id**: unique patient identifier
- **visit_id**: unique anchor-based visit ID
- **x**: input features (temporal sequence)
- **y**: hazard label vector (0/1)
- **mask**: indicates the valid risk set:

  - 1 = patient is at risk at this timestep
  - 0 = timestep excluded (post-event or post-censoring)
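
The relationship between ``y`` and ``mask`` can be sketched as follows. This is a minimal illustration under an assumed convention (steps are indexed from 0 relative to the anchor, and the event or censoring step itself still counts as at risk); the function name and arguments are hypothetical, and the exact boundary handling in ``DynamicSurvivalTask`` may differ.

```python
from typing import List, Optional, Tuple


def build_hazard_targets(
    horizon: int,
    event_step: Optional[int] = None,
    censor_step: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Build (y, mask) for one anchor.

    event_step: index of the step in which the event occurs, or None if no
        event is observed within the horizon (assumed convention).
    censor_step: index of the last observed step for a censored patient, or
        None if the patient is followed for the whole horizon.
    """
    y = [0] * horizon
    mask = [0] * horizon

    if event_step is not None and 0 <= event_step < horizon:
        # Event observed: label the event step, stay at risk up to it.
        y[event_step] = 1
        last_at_risk = event_step
    elif censor_step is not None:
        # Censored: no positive label, at risk only while observed.
        last_at_risk = min(censor_step, horizon - 1)
    else:
        # Fully observed without an event.
        last_at_risk = horizon - 1

    for t in range(last_at_risk + 1):
        mask[t] = 1
    return y, mask


y_event, mask_event = build_hazard_targets(5, event_step=2)
y_cens, mask_cens = build_hazard_targets(5, censor_step=1)
```

Under this convention a censored anchor contributes only zeros to the loss, and only for the steps where the patient was still being observed.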

Usage Example
-------------

.. code-block:: python

    from pyhealth.tasks.dynamic_survival import DynamicSurvivalTask

    # Minimal dataset wrapper (MockDataset or a real PyHealth dataset)
    class MockDataset:
        def __init__(self):
            self.patients = {}

    dataset = MockDataset()

    task = DynamicSurvivalTask(
        dataset=dataset,
        observation_window=24,
        horizon=24,
        anchor_strategy="fixed",
    )

    # Apply to a patient object
    samples = task(patient)

Example Output
--------------

Each sample:

- x: shape (T, d)
- y: shape (horizon,)
- mask: shape (horizon,)
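
Because ``y`` and ``mask`` share the shape ``(horizon,)``, a loss over a sample only needs to average over the steps where ``mask`` is 1. The sketch below shows a masked binary cross-entropy in plain Python for a single sample; ``masked_bce`` is a hypothetical helper written for illustration, with the same ``1e-8`` stabilizer used in the ablation example.

```python
import math
from typing import Sequence


def masked_bce(
    pred: Sequence[float],
    y: Sequence[float],
    mask: Sequence[float],
    eps: float = 1e-8,
) -> float:
    """Average binary cross-entropy over the steps where mask is nonzero."""
    total = 0.0
    count = 0.0
    for p, t, m in zip(pred, y, mask):
        if m:
            # Standard per-step BCE; eps keeps log() away from zero.
            total += -(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps))
            count += m
    # Return 0.0 when no step is at risk instead of dividing by zero.
    return total / count if count else 0.0


# The third step is masked out, so its prediction does not affect the loss.
loss = masked_bce(pred=[0.5, 0.5, 0.9], y=[0, 1, 1], mask=[1, 1, 0])
```

Normalizing by the number of at-risk steps rather than by ``horizon`` keeps samples with short risk sets from being artificially down-weighted.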

API Reference
-------------

.. autoclass:: pyhealth.tasks.dynamic_survival.DynamicSurvivalTask
    :members:
    :undoc-members:
    :show-inheritance:
164 changes: 164 additions & 0 deletions examples/dynamic_survival_ablation.py
@@ -0,0 +1,164 @@
# Authors: Skyler Lehto (lehto2@illinois.edu),
# Ryan Bradley (ryancb3@illinois.edu),
# Weonah Choi (weonahc2@illinois.edu)
# Paper: Dynamic Survival Analysis for Early Event Prediction (Yèche et al., 2024)
# Link: https://arxiv.org/abs/2403.12818
# Description: Ablation study for observation window size on synthetic patients.

"""
Ablation Study: Effect of Observation Window Length

We vary observation window sizes (12, 24, 48 hours) and
measure performance using masked BCE and MSE.

This demonstrates how task configuration (NOT model complexity)
impacts predictive performance.
"""

import sys
import os
from datetime import datetime, timedelta

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import torch
import torch.nn as nn

from pyhealth.tasks.dynamic_survival import DynamicSurvivalTask
from examples.synthetic_dataset import generate_synthetic_dataset
from examples.mock_ehr import MockEvent, MockVisit, MockPatient, MockDataset


# Convert synthetic dict → MockPatient

def convert_to_mock_patients(patients_dict):
    base_time = datetime(2025, 1, 1)

    mock_patients = []

    for p in patients_dict:
        visits_data = []

        for v in p["visits"]:
            visits_data.append({
                "time": base_time + timedelta(days=v["time"]),
                "diagnosis": ["0000"],  # dummy code for vocab
            })

        death_time = None
        if p.get("outcome_time") is not None:
            death_time = base_time + timedelta(days=p["outcome_time"])

        mock_patients.append(
            MockPatient(
                pid=p["patient_id"],
                visits_data=visits_data,
                death_time=death_time,
            )
        )

    return mock_patients


# Model

class SimpleModel(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=8, horizon=24):
        super().__init__()
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, horizon)

    def forward(self, x):
        # h has shape (1, batch, hidden_dim); drop the layer dimension.
        _, h = self.rnn(x)
        return torch.sigmoid(self.fc(h.squeeze(0)))


# Utils

def prepare_batch(samples):
    X, Y, M = [], [], []

    for s in samples:
        X.append(s["x"])
        Y.append(s["y"])
        M.append(s["mask"])

    if len(X) == 0:
        raise ValueError("No valid samples generated.")

    max_len = max(len(x) for x in X)

    # Zero-pad every sequence at the end to the longest length in the batch.
    X_pad = []
    for x in X:
        pad = np.zeros((max_len - len(x), x.shape[1]))
        X_pad.append(np.vstack([x, pad]))

    return (
        torch.tensor(np.array(X_pad), dtype=torch.float32),
        torch.tensor(np.array(Y), dtype=torch.float32),
        torch.tensor(np.array(M), dtype=torch.float32),
    )


def train_and_eval(samples):
    X, Y, M = prepare_batch(samples)

    # Size the model from the batch rather than relying on hard-coded defaults.
    model = SimpleModel(input_dim=X.shape[-1], horizon=Y.shape[-1])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Guard against an all-zero mask, which would otherwise divide by zero.
    denom = M.sum().clamp(min=1.0)

    for _ in range(5):
        pred = model(X)
        loss = -(Y * torch.log(pred + 1e-8) +
                 (1 - Y) * torch.log(1 - pred + 1e-8))
        loss = (loss * M).sum() / denom

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Metrics are computed on the same samples used for training.
    with torch.no_grad():
        pred = model(X)

        bce = -(Y * torch.log(pred + 1e-8) +
                (1 - Y) * torch.log(1 - pred + 1e-8))
        bce = (bce * M).sum() / denom

        mse = ((pred - Y) ** 2 * M).sum() / denom

    return {"bce": bce.item(), "mse": mse.item()}


# Main Experiment

def main():
    patients_raw = generate_synthetic_dataset(50)
    patients = convert_to_mock_patients(patients_raw)
    dataset = MockDataset(patients)

    windows = [12, 24, 48]
    results = {}

    print("\n=== Ablation Results ===")

    for w in windows:
        task = DynamicSurvivalTask(
            dataset=dataset,
            observation_window=w,
            horizon=24,
        )

        samples = dataset.set_task(task)

        if len(samples) == 0:
            print(f"Skipping window={w}, no samples")
            continue

        score = train_and_eval(samples)
        results[w] = score

        print(f"Window={w} | BCE={score['bce']:.4f} | MSE={score['mse']:.4f}")


if __name__ == "__main__":
    main()