Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ data/physionet.org/

# VSCode settings
.vscode/
.codex

# Model weight files (large binaries, distributed separately)
weightfiles/
4 changes: 1 addition & 3 deletions pyhealth/models/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,13 @@ def __init__(
):
vocab_size = len(processor.code_vocab)

# For NestedSequenceProcessor and DeepNestedSequenceProcessor, use padding_idx=0
# so padded positions embed to a fixed all-zero vector and receive no gradient updates.
if isinstance(
processor, (NestedSequenceProcessor, DeepNestedSequenceProcessor)
):
self.embedding_layers[field_name] = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=None,
padding_idx=0,
)
else:
self.embedding_layers[field_name] = nn.Embedding(
Expand Down
68 changes: 59 additions & 9 deletions pyhealth/tasks/drug_recommendation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
from typing import Any, Dict, List
from typing import Any, Dict, Iterable, List, Optional

import polars as pl

from pyhealth.data import Patient, Visit
from pyhealth.medcode import CrossMap
from .base_task import BaseTask


# Lazily-initialized, process-wide NDC -> ATC CrossMap shared by all tasks.
_NDC_TO_ATC3_MAPPER = None
# Memoizes per-NDC ATC-3 lookups for the *default* mapper only; prescription
# tables repeat the same NDCs heavily and CrossMap.map is not free.
_NDC_TO_ATC3_CACHE: Dict[str, List[str]] = {}

# Placeholder tokens (lower-cased) that MIMIC exports use for "no NDC recorded".
_MISSING_NDC_TOKENS = frozenset({"nan", "none", "<na>"})


def _get_ndc_to_atc3_mapper():
    """Return the shared NDC->ATC CrossMap, loading it on first use."""
    global _NDC_TO_ATC3_MAPPER
    if _NDC_TO_ATC3_MAPPER is None:
        _NDC_TO_ATC3_MAPPER = CrossMap.load("NDC", "ATC")
    return _NDC_TO_ATC3_MAPPER


def _is_missing_ndc(code: Any) -> bool:
    """Return True if ``code`` is a placeholder for a missing NDC value."""
    if code is None:
        return True
    code = str(code).strip()
    return code in ("", "0") or code.lower() in _MISSING_NDC_TOKENS


def _map_ndc_list_to_atc3(
    ndc_codes: Iterable[Any],
    mapper: Optional[Any] = None,
) -> List[str]:
    """Maps MIMIC prescription NDCs to stable, deduplicated ATC-3 labels.

    Args:
        ndc_codes: raw NDC values from a prescriptions table. Missing-value
            placeholders (None, "", "0", "nan", "none", "<na>") are skipped.
        mapper: optional CrossMap-like object exposing
            ``map(code, target_kwargs=...)``. Defaults to the shared
            NDC->ATC mapper.

    Returns:
        ATC-3 codes in first-seen order with duplicates removed.
    """
    # Only memoize lookups that go through the shared default mapper.
    # Caching results from a caller-supplied mapper would poison the global
    # cache for later calls that use a different mapper (cache is keyed by
    # NDC alone) -- this was a latent bug in the original implementation.
    use_cache = mapper is None
    if mapper is None:
        mapper = _get_ndc_to_atc3_mapper()

    drugs: List[str] = []
    seen = set()

    for ndc in ndc_codes:
        if _is_missing_ndc(ndc):
            continue
        ndc = str(ndc).strip()
        if use_cache:
            if ndc not in _NDC_TO_ATC3_CACHE:
                _NDC_TO_ATC3_CACHE[ndc] = mapper.map(
                    ndc, target_kwargs={"level": 3}
                )
            mapped_codes = _NDC_TO_ATC3_CACHE[ndc]
        else:
            mapped_codes = mapper.map(ndc, target_kwargs={"level": 3})
        for code in mapped_codes:
            if code is None:
                continue
            code = str(code).strip()
            # Keep first occurrence only so label order is deterministic.
            if code and code not in seen:
                drugs.append(code)
                seen.add(code)

    return drugs


class DrugRecommendationMIMIC3(BaseTask):
"""Task for drug recommendation using MIMIC-III dataset.

Expand Down Expand Up @@ -35,6 +81,10 @@ class DrugRecommendationMIMIC3(BaseTask):
}
output_schema: Dict[str, str] = {"drugs": "multilabel"}

    def __init__(self, *args, **kwargs):
        """Initialize the task and tag its sample cache.

        ``cache_version`` is set so samples cached before the switch from
        4-character NDC prefix truncation to real NDC->ATC-3 mapping are
        not reused. NOTE(review): assumes the caching layer keys on
        ``cache_version`` -- confirm against the BaseTask cache logic.
        """
        super().__init__(*args, **kwargs)
        self.cache_version = "ndc_to_atc3_v1"

def __call__(self, patient: Any) -> List[Dict[str, Any]]:
"""Process a patient to create drug recommendation samples.

Expand Down Expand Up @@ -92,8 +142,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
prescriptions.select(pl.col("prescriptions/ndc")).to_series().to_list()
)

# ATC 3 level (first 4 characters)
drugs = [drug[:4] for drug in drugs if drug]
drugs = _map_ndc_list_to_atc3(drugs)

# Exclude visits without condition, procedure, or drug code
if len(conditions) * len(procedures) * len(drugs) == 0:
Expand Down Expand Up @@ -173,6 +222,10 @@ class DrugRecommendationMIMIC4(BaseTask):
}
output_schema: Dict[str, str] = {"drugs": "multilabel"}

    def __init__(self, *args, **kwargs):
        """Initialize the task and tag its sample cache.

        ``cache_version`` is set so samples cached before the switch from
        4-character NDC prefix truncation to real NDC->ATC-3 mapping are
        not reused. NOTE(review): assumes the caching layer keys on
        ``cache_version`` -- confirm against the BaseTask cache logic.
        """
        super().__init__(*args, **kwargs)
        self.cache_version = "ndc_to_atc3_v1"

def __call__(self, patient: Any) -> List[Dict[str, Any]]:
"""Process a patient to create drug recommendation samples.

Expand Down Expand Up @@ -240,8 +293,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
prescriptions.select(pl.col("prescriptions/ndc")).to_series().to_list()
)

# ATC 3 level (first 4 characters)
drugs = [drug[:4] for drug in drugs if drug]
drugs = _map_ndc_list_to_atc3(drugs)

# Exclude visits without condition, procedure, or drug code
if len(conditions) * len(procedures) * len(drugs) == 0:
Expand Down Expand Up @@ -332,8 +384,7 @@ def drug_recommendation_mimic3_fn(patient: Patient):
conditions = visit.get_code_list(table="DIAGNOSES_ICD")
procedures = visit.get_code_list(table="PROCEDURES_ICD")
drugs = visit.get_code_list(table="PRESCRIPTIONS")
# ATC 3 level
drugs = [drug[:4] for drug in drugs]
drugs = _map_ndc_list_to_atc3(drugs)
# exclude: visits without condition, procedure, or drug code
if len(conditions) * len(procedures) * len(drugs) == 0:
continue
Expand Down Expand Up @@ -413,8 +464,7 @@ def drug_recommendation_mimic4_fn(patient: Patient):
conditions = visit.get_code_list(table="diagnoses_icd")
procedures = visit.get_code_list(table="procedures_icd")
drugs = visit.get_code_list(table="prescriptions")
# ATC 3 level
drugs = [drug[:4] for drug in drugs]
drugs = _map_ndc_list_to_atc3(drugs)
# exclude: visits without condition, procedure, or drug code
if len(conditions) * len(procedures) * len(drugs) == 0:
continue
Expand Down
174 changes: 174 additions & 0 deletions tests/core/test_drug_recommendation_atc3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import csv
import gzip
import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import pyhealth.tasks.drug_recommendation as drug_rec
from pyhealth.datasets import MIMIC3Dataset, MIMIC4Dataset
from pyhealth.tasks import DrugRecommendationMIMIC3, DrugRecommendationMIMIC4


class FakeNDCToATC3Map:
    """Deterministic stand-in for the NDC->ATC CrossMap.

    Records every ``map`` call (NDC plus target_kwargs) so tests can assert
    on lookup arguments, and answers from a fixed canned table.
    """

    def __init__(self):
        self.calls = []
        canned = (
            ("11111111111", ["A10B"]),
            ("22222222222", ["C03C", "C03C"]),
            ("33333333333", ["N02B"]),
        )
        self.mapping = {ndc: codes for ndc, codes in canned}

    def map(self, ndc, target_kwargs=None):
        # Keep a full trace of lookups for later assertions.
        self.calls.append((ndc, target_kwargs))
        try:
            return self.mapping[ndc]
        except KeyError:
            # Unknown NDCs map to nothing, mirroring a failed crosswalk.
            return []


class TestDrugRecommendationATC3(unittest.TestCase):
    """End-to-end tests that drug-recommendation tasks emit ATC-3 labels.

    The bundled MIMIC-III / MIMIC-IV demo datasets are copied into a temp
    directory, their prescription ``ndc`` columns rewritten to known values,
    and ``CrossMap.load`` patched with a deterministic fake so the expected
    ATC-3 labels can be asserted exactly.
    """

    @classmethod
    def setUpClass(cls):
        # Demo resources live at <repo>/test-resources/core.
        cls.resources_root = Path(__file__).parents[2] / "test-resources" / "core"

    def setUp(self):
        # Reset the module-level mapper singleton and NDC cache so every
        # test starts from a cold state.
        drug_rec._NDC_TO_ATC3_MAPPER = None
        drug_rec._NDC_TO_ATC3_CACHE.clear()
        self.mapper = FakeNDCToATC3Map()
        # Patch CrossMap.load so no real NDC->ATC mapping files are needed.
        patcher = patch(
            "pyhealth.tasks.drug_recommendation.CrossMap.load",
            return_value=self.mapper,
        )
        self.addCleanup(patcher.stop)
        self.crossmap_load = patcher.start()
        self.temp_dirs = []

    def tearDown(self):
        # Clear module-level state again so later test modules are unaffected,
        # then remove all temp dataset copies.
        drug_rec._NDC_TO_ATC3_MAPPER = None
        drug_rec._NDC_TO_ATC3_CACHE.clear()
        for temp_dir in self.temp_dirs:
            temp_dir.cleanup()

    def _copy_demo(self, demo_name):
        """Copy a demo dataset into a fresh temp dir; return (path, tempdir)."""
        temp_dir = tempfile.TemporaryDirectory()
        self.temp_dirs.append(temp_dir)
        source = self.resources_root / demo_name
        target = Path(temp_dir.name) / demo_name
        shutil.copytree(source, target)
        return target, temp_dir

    def _rewrite_prescription_ndcs(self, path, replacements):
        """Rewrite the ``ndc`` column of a prescriptions CSV (plain or .gz).

        ``replacements`` maps hadm_id -> ordered list of NDC strings.
        Existing rows for a hadm_id are overwritten in file order; if a
        hadm_id has more NDCs than rows, extra rows are appended by cloning
        the first seen row for that admission. Rows beyond the replacement
        list get the sentinel NDC "99999999999".
        """
        opener = gzip.open if path.suffix == ".gz" else open
        with opener(path, "rt", newline="") as f:
            reader = csv.DictReader(f)
            rows = list(reader)
            fieldnames = reader.fieldnames

        if fieldnames is None:
            raise ValueError(f"No CSV header found in {path}")

        counts = {hadm_id: 0 for hadm_id in replacements}
        # First original row per hadm_id, used as a template for extras.
        templates = {}
        rewritten_rows = []
        for row in rows:
            hadm_id = str(row["hadm_id"])
            if hadm_id in replacements:
                templates.setdefault(hadm_id, row.copy())
                index = counts[hadm_id]
                ndcs = replacements[hadm_id]
                row["ndc"] = ndcs[index] if index < len(ndcs) else "99999999999"
                counts[hadm_id] += 1
            rewritten_rows.append(row)

        for hadm_id, ndcs in replacements.items():
            if hadm_id not in templates:
                raise ValueError(f"No prescription rows found for hadm_id={hadm_id}")
            while counts[hadm_id] < len(ndcs):
                row = templates[hadm_id].copy()
                if "row_id" in row:
                    # Keep synthetic row ids unique and far from real ones.
                    row["row_id"] = str(10_000_000 + len(rewritten_rows))
                row["ndc"] = ndcs[counts[hadm_id]]
                rewritten_rows.append(row)
                counts[hadm_id] += 1

        with opener(path, "wt", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rewritten_rows)

    def _assert_atc3_samples(self, samples, first_hadm_id, second_hadm_id):
        """Assert both visits carry mapped ATC-3 labels, not NDC prefixes."""
        by_visit = {str(sample["visit_id"]): sample for sample in samples}
        self.assertIn(first_hadm_id, by_visit)
        self.assertIn(second_hadm_id, by_visit)

        first_sample = by_visit[first_hadm_id]
        second_sample = by_visit[second_hadm_id]
        # Deduplicated ATC-3 codes, in first-seen order.
        self.assertEqual(first_sample["drugs"], ["A10B", "C03C"])
        self.assertEqual(second_sample["drugs"], ["N02B"])
        # The legacy behavior (4-char NDC prefixes) must not leak through,
        # nor should missing/unmapped NDC sentinels.
        self.assertNotIn("1111", first_sample["drugs"])
        self.assertNotIn("2222", first_sample["drugs"])
        self.assertNotIn("3333", second_sample["drugs"])
        self.assertNotIn("0", first_sample["drugs"])
        self.assertNotIn("9999", first_sample["drugs"])

    def test_mimic3_demo_drug_recommendation_maps_ndc_to_atc3(self):
        demo_path, cache_dir = self._copy_demo("mimic3demo")
        # Visit 142582: dup NDC, a missing "0", and an unmapped NDC;
        # visit 122098: one mappable NDC plus two missing placeholders.
        self._rewrite_prescription_ndcs(
            demo_path / "PRESCRIPTIONS.csv.gz",
            {
                "142582": [
                    "11111111111",
                    "22222222222",
                    "11111111111",
                    "0",
                    "99999999999",
                ],
                "122098": ["33333333333", "", "<NA>"],
            },
        )
        dataset = MIMIC3Dataset(
            root=str(demo_path),
            tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
            cache_dir=cache_dir.name,
        )

        samples = DrugRecommendationMIMIC3()(dataset.get_patient("10059"))

        self.crossmap_load.assert_called_once_with("NDC", "ATC")
        self._assert_atc3_samples(samples, "142582", "122098")
        # Every mapper lookup must request ATC level 3.
        self.assertTrue(
            all(kwargs == {"level": 3} for _, kwargs in self.mapper.calls)
        )

    def test_mimic4_demo_drug_recommendation_maps_ndc_to_atc3(self):
        demo_path, cache_dir = self._copy_demo("mimic4demo")
        # Same scenario as the MIMIC-III test, against the MIMIC-IV layout.
        self._rewrite_prescription_ndcs(
            demo_path / "hosp" / "prescriptions.csv",
            {
                "20001": [
                    "11111111111",
                    "22222222222",
                    "11111111111",
                    "0",
                    "99999999999",
                ],
                "20002": ["33333333333", "", "<NA>"],
            },
        )
        dataset = MIMIC4Dataset(
            ehr_root=str(demo_path),
            ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
            cache_dir=cache_dir.name,
            num_workers=1,
        )

        samples = DrugRecommendationMIMIC4()(dataset.get_patient("10001"))

        self.crossmap_load.assert_called_once_with("NDC", "ATC")
        self._assert_atc3_samples(samples, "20001", "20002")
        # Every mapper lookup must request ATC level 3.
        self.assertTrue(
            all(kwargs == {"level": 3} for _, kwargs in self.mapper.calls)
        )


if __name__ == "__main__":
    unittest.main()
58 changes: 58 additions & 0 deletions tests/core/test_embedding_model_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest

import torch

from pyhealth.datasets import create_sample_dataset
from pyhealth.models import EmbeddingModel


class TestEmbeddingModelPadding(unittest.TestCase):
    """Checks that nested-sequence embedding tables reserve index 0 as padding."""

    def setUp(self):
        # Two minimal patients exercising nested and deeply nested codes.
        records = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "conditions": [["cond-1", "cond-2"], ["cond-3"]],
                "deep_codes": [[["deep-1"], ["deep-2", "deep-3"]]],
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-1",
                "conditions": [["cond-4"]],
                "deep_codes": [[["deep-4"]]],
                "label": 0,
            },
        ]
        self.dataset = create_sample_dataset(
            samples=records,
            input_schema={
                "conditions": "nested_sequence",
                "deep_codes": "deep_nested_sequence",
            },
            output_schema={"label": "binary"},
            dataset_name="embedding-padding-test",
        )

    def test_nested_sequence_embeddings_use_zero_padding(self):
        emb_model = EmbeddingModel(self.dataset, embedding_dim=8)

        for field_name in ("conditions", "deep_codes"):
            layer = emb_model.embedding_layers[field_name]
            # Index 0 must be the padding row and initialize to all zeros.
            self.assertEqual(layer.padding_idx, 0)
            self.assertTrue(torch.equal(layer.weight[0], torch.zeros(8)))

    def test_nested_sequence_padding_row_does_not_receive_gradients(self):
        emb_model = EmbeddingModel(self.dataset, embedding_dim=8)
        layer = emb_model.embedding_layers["conditions"]
        real_idx = self.dataset.input_processors["conditions"].code_vocab["cond-1"]

        # Embed one padded position (index 0) alongside a real token.
        embedded = layer(torch.tensor([[[0, real_idx]]]))
        embedded.sum().backward()

        # The padding row stays frozen; the real token row gets gradient.
        self.assertTrue(torch.equal(layer.weight.grad[0], torch.zeros(8)))
        self.assertGreater(layer.weight.grad[real_idx].abs().sum().item(), 0)


if __name__ == "__main__":
    unittest.main()
Loading