diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..4a3de79f2 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Discharge Note Summarization (MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.DischargeNoteSummarization.rst b/docs/api/tasks/pyhealth.tasks.DischargeNoteSummarization.rst new file mode 100644 index 000000000..916a12c9f --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DischargeNoteSummarization.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.DischargeNoteSummarization +======================================= + +.. autoclass:: pyhealth.tasks.discharge_note_summarization.DischargeNoteSummarization + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/discharge__summary_samples.ipynb b/examples/discharge__summary_samples.ipynb new file mode 100644 index 000000000..c11f1bc53 --- /dev/null +++ b/examples/discharge__summary_samples.ipynb @@ -0,0 +1,3152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9QuP-XLwAF3w" + }, + "source": [ + "# Generate cleaned Discharge Summary Samples using DischargeNoteSummarization Task\n", + "\n", + "This notebook demonstrates the usage of MIMIC-IV Note dataset and DischargeNoteSummarizationTask to generate discharge summary samples for LLM training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4Hj9Zi4v2Nis" + }, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eWj28Ms7AEO9" + }, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MUAyQZQnFbbv" + }, + "outputs": [], + "source": [ + "import os\n", + "from pyhealth.datasets import MIMIC4Dataset\n", + "from pyhealth.tasks import BaseTask\n", + "from pyhealth.data import Patient\n", + "from typing import List, Dict, Any\n", + "from pyhealth.processors import TextProcessor\n", + "import argparse\n", + "import random\n", + "import pandas as pd\n", + "from pathlib import Path\n", + "\n", + "pd.options.mode.chained_assignment = None\n", + "import re\n", + "import pickle\n", + "import nltk\n", + "from collections import Counter\n", + "from tqdm import tqdm\n", + "import string" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "738h8qAMA5Fs" + }, + "source": [ + "# Initialize the MIMI4Dataset using the note data downloaded from Physionet website.\n", + "\n", + "Name of dataset used is discharge.csv.gz from Physionet : https://physionet.org/content/ann-pt-summ/1.0.1/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lRIRtrhQNKS2", + "outputId": "886f4286-e412-4e3a-9f91-bd24c329d397" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage Starting MIMIC4Dataset init: 882.8 MB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Memory usage Starting MIMIC4Dataset init: 882.8 MB\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing mimic4 dataset from None|/content/drive/MyDrive/llm_data/|None (dev mode: False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"INFO:pyhealth.datasets.base_dataset:Initializing mimic4 dataset from None|/content/drive/MyDrive/llm_data/|None (dev mode: False)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No cache_dir provided. Using default cache dir: /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:No cache_dir provided. Using default cache dir: /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing MIMIC4NoteDataset with tables: ['discharge'] (dev mode: False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Initializing MIMIC4NoteDataset with tables: ['discharge'] (dev mode: False)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using default note config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_note.yaml\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Using default note config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_note.yaml\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage Before initializing mimic4_note: 882.9 MB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/pyhealth/datasets/mimic4.py:121: UserWarning: Events from discharge table only have date timestamp (no specific time). 
This may affect temporal ordering of events.\n", + " warnings.warn(\n", + "INFO:pyhealth.datasets.mimic4:Memory usage Before initializing mimic4_note: 882.9 MB\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing mimic4_note dataset from /content/drive/MyDrive/llm_data/ (dev mode: False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Initializing mimic4_note dataset from /content/drive/MyDrive/llm_data/ (dev mode: False)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using provided cache_dir: /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/cf4117bc-6d03-5673-a78c-162795de42ea\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Using provided cache_dir: /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/cf4117bc-6d03-5673-a78c-162795de42ea\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage After initializing mimic4_note: 883.0 MB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Memory usage After initializing mimic4_note: 883.0 MB\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage After Note dataset initialization: 883.0 MB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Memory usage After Note dataset initialization: 883.0 MB\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage Completed MIMIC4Dataset init: 883.0 MB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Memory usage Completed MIMIC4Dataset init: 883.0 MB\n" + ] + } + ], + "source": [ + "full_note_dataset = MIMIC4Dataset(\n", + " note_root='/content/drive/llm_data/',\n", + " 
note_tables=[\"discharge\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Usgcv8CU4cj2" + }, + "outputs": [], + "source": [ + "full_note_dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hhcWPLJmBgOQ" + }, + "source": [ + "# Print an event using a patient id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aacWVtZdnsoA" + }, + "outputs": [], + "source": [ + "print(full_note_dataset.get_patient('10000032').get_events())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xyqAklNtB_t0" + }, + "source": [ + "# Define the DischargeNoteSummarization Task\n", + "\n", + "Create DischargeNoteSummarization class , initialize the input and output schema.\n", + "Extract specific sections \"Brief Hospital Course\" and \"Discharge Instructions\". Clean the samples to remove extra spaces and new lines to create a paragraph for each sample text.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "31loQfr6nSRf" + }, + "outputs": [], + "source": [ + "\n", + "from typing import Dict, List, Any, Tuple, Union\n", + "\n", + "class DischargeNoteSummarization(BaseTask):\n", + " task_name: str = \"DischargeNoteSummarization\"\n", + "\n", + " input_schema: Dict[str , str] = {\n", + " \"subject_id\" : \"text\",\n", + " \"hadm_id\": \"text\",\n", + " \"text\": \"text\"\n", + " }\n", + "\n", + " output_schema: Dict[str, str] = {\n", + " \"brief_hospital_course\": \"text\",\n", + " \"summary\": \"text\"\n", + " }\n", + "\n", + "\n", + " def __call__(self, patient: Patient) -> List[Dict[str, Any]]:\n", + " samples = []\n", + " subject_id = patient.patient_id\n", + " for dis in patient.get_events(\"discharge\"):\n", + "\n", + " textNote = dis.attr_dict['text']\n", + " hadm_id = dis.attr_dict['hadm_id']\n", + "\n", + " ## Extract the brief_hospital_course\n", + "\n", + " start = textNote.find(\"Brief Hospital 
Course:\")\n", + " if start < 0:\n", + " #brief_hospital_course = None\n", + " continue\n", + " end = textNote.find(\"Medications on Admission:\")\n", + " if end == -1:\n", + " end = textNote.find(\"Discharge Medications:\")\n", + " if end == -1:\n", + " end = textNote.find(\"Discharge Disposition:\")\n", + " if end == 0 or start >= end:\n", + " continue\n", + " brief_hospital_course = textNote[start: end].replace('\\n', ' ')\n", + " brief_hospital_course = ' '.join(brief_hospital_course.split())\n", + " # Quality check\n", + " num_words = len(textNote.split(' '))\n", + " \n", + " #extract the summary\n", + " start = textNote.find(\"Discharge Instructions:\")\n", + " end = textNote.find(\"Followup Instructions:\")\n", + " if start < 0 or end < 0:\n", + " continue\n", + " summary = textNote[start: end].replace('\\n', ' ')\n", + " summary = ' '.join(summary.split())\n", + " if len(summary) == 0 or len(summary) < 350:\n", + " continue\n", + " summary = summary.strip()\n", + "\n", + "\n", + "\n", + " samples.append({\n", + " \"text\":textNote,\n", + " \"brief_hospital_course\": brief_hospital_course,\n", + " \"summary\" : summary,\n", + " \"subject_id\" : subject_id,\n", + " \"hadm_id\": hadm_id\n", + " })\n", + "\n", + " return samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8DqgPHBSrXdm" + }, + "outputs": [], + "source": [ + "! 
rm -r /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/tasks/PatientNoteProcessingTask_46bb372d-34eb-5a38-bd99-ca6f30f0f026/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7iZFJlJKDEyW" + }, + "source": [ + "# Run the Discharge Note Summarization Task\n", + "\n", + "Run the DischargeNoteSummarization Task with 4 workers and note dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 905 + }, + "id": "5Lh3QhOUqCWe", + "outputId": "5ddfc211-28af-4c22-dbda-e04f368d7b2e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task PatientNoteProcessingTask for mimic4 base dataset...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Setting task PatientNoteProcessingTask for mimic4 base dataset...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Task cache paths: task_df=/root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/tasks/PatientNoteProcessingTask_46bb372d-34eb-5a38-bd99-ca6f30f0f026/task_df.ld, samples=/root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/tasks/PatientNoteProcessingTask_46bb372d-34eb-5a38-bd99-ca6f30f0f026/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Task cache paths: task_df=/root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/tasks/PatientNoteProcessingTask_46bb372d-34eb-5a38-bd99-ca6f30f0f026/task_df.ld, samples=/root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/tasks/PatientNoteProcessingTask_46bb372d-34eb-5a38-bd99-ca6f30f0f026/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Applying task transformations on data with 4 
workers...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Applying task transformations on data with 4 workers...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Incomplete parquet cache at /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/global_event_df.parquet (directory exists but contains no parquet files). Removing and rebuilding.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:pyhealth.datasets.base_dataset:Incomplete parquet cache at /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/global_event_df.parquet (directory exists but contains no parquet files). Removing and rebuilding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No cached event dataframe found. Creating: /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/global_event_df.parquet\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:No cached event dataframe found. 
Creating: /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/global_event_df.parquet\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Combining data from note dataset\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Combining data from note dataset\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: discharge from /content/drive/MyDrive/llm_data/note/discharge.csv.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Scanning table: discharge from /content/drive/MyDrive/llm_data/note/discharge.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating combined dataframe\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.mimic4:Creating combined dataframe\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Caching event dataframe to /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/global_event_df.parquet...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Caching event dataframe to /root/.cache/pyhealth/98de9a11-0af5-5cd9-81f2-2da31802c232/global_event_df.parquet...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Detected Jupyter notebook environment, setting num_workers to 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Detected Jupyter notebook environment, setting num_workers to 1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single worker mode, processing sequentially\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Single worker mode, processing sequentially\n" + ] + }, + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "Worker 0 started processing 145914 patients. (Polars threads: 2)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Worker 0 started processing 145914 patients. (Polars threads: 2)\n", + " 0%| | 0/145914 [00:00\n" + ] + } + ], + "source": [ + "\n", + "mimic_df = pd.DataFrame(processed_dataset)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iVdEXb_gDdGw" + }, + "source": [ + "# Print the dataframe head" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cu-s0F4rkRHm", + "outputId": "5f78d170-f7f5-4136-b1bd-9b10190332e5" + }, + "outputs": [], + "source": [ + "print(mimic_df.head())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HLrS_fpLkwz8", + "outputId": "791bc340-3f3c-4bef-d733-f499627e8f38" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "740" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(mimic_df.iloc[1]['summary'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IUA4wtrjDtMJ" + }, + "source": [ + "# Perform further processing on the dataframe\n", + "\n", + "Run more data cleaning tasks on the mimic_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "4002f3c605c648409776ec137ccff2bf", + "008bd292f7a9429eb3c7570fbe9fd093", + "46d831dc293444229a86184c910243a0", + "68ae05d25d064e0b8bd39fadb955311e", + "cdd7dd43d43d4bed8d5b401b5e242696", + "20fae152868747c2a175d48936cde9ad", + "a31b0d30e44d469c828cb18393803334", + "66a3ab0b3a484264af941300ac647d6f", + "b9632797c62b454fb347a8fecbb226f7", + 
"e13239bfb1214cf398f9d1d2b303f105", + "9da0e08960a2462cb9f1ba55656b8116" + ] + }, + "id": "qK-wvZ4D2VLX", + "outputId": "c51d6460-b4d8-4ef9-ab6c-21f828f78b48" + }, + "outputs": [], + "source": [ + "import swifter\n", + "re_service = re.compile(r'^Service: (.*)$', re.IGNORECASE|re.MULTILINE) # Either after Serive:\n", + "re_service_extra = re.compile(r'^Date of Birth:.*Sex:\\s{0,10}\\w\\s{0,10}___: (.*)$', re.IGNORECASE|re.MULTILINE) # Fallback if deidentified\n", + "\n", + "mimic_df['service'] = mimic_df['text'].swifter.apply(lambda s: re_service.search(s).group(1) if re_service.search(s) is not None else None)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "81495e8a1be5469a87d0bb300e82688d", + "db208ab02eb6476aa0d1d3bb4349bd54", + "27653f3de94446a69e8d55de32bb8e6f", + "ee6409daec7b495dae8e08b2f89ffcea", + "280591f53ff74b4cb18851435e5916e2", + "483af853bd7047f4b51ea31f9267276d", + "199dcfce4cf24a8b9153fe2c3c5b4914", + "5dcf36ab5c474cc792718ed10f3d5571", + "52c4ae1316c24e9599e86af27d14b394", + "4e91f191ce6242c9a2fe3f1b1b7aec24", + "3d4731460a9549258137e20e10e2ed75" + ] + }, + "id": "ktnXJRz0a5B4", + "outputId": "16bf4c3b-4aa5-4c89-b837-f7f35dbcffb1" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "81495e8a1be5469a87d0bb300e82688d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Pandas Apply: 0%| | 0/183 [00:00= 3]\n", + "print(f\" Removed {old_len - len(mimic_df)} summaries with less than 3 sentences.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 66, + "referenced_widgets": [ + "5f77f53997744569aa6d09f758026afd", + "ad18bfabf116422a878bb74608daeb2d", + "e27e54ac631644678726ee2edb2b6236", + "43f577596fdf47bb9eb46a372e2bfc84", + "e2071364bdeb49dabbac2da5d7beb505", 
def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
    """Build summarization samples from a patient's discharge notes.

    For every ``discharge`` event of ``patient`` two sections are cut out
    of the note text and collapsed into single-line paragraphs:

    * the text from ``"Brief Hospital Course:"`` up to the first end
      marker found after it (``"Medications on Admission:"``, then
      ``"Discharge Medications:"``, then ``"Discharge Disposition:"``);
    * the text from ``"Discharge Instructions:"`` up to the next
      ``"Followup Instructions:"``.

    A note is skipped when it has no text, when either section (or its
    end marker) cannot be located, or when the cleaned summary is shorter
    than ``MIN_SUMMARY_LENGTH`` characters.

    Args:
        patient: Object exposing ``patient_id`` and
            ``get_events("discharge")``; each returned event must carry
            ``attr_dict`` with the keys ``"text"`` and ``"hadm_id"``.

    Returns:
        One dictionary per retained note with the keys ``"text"`` (the
        unmodified note), ``"brief_hospital_course"``, ``"summary"``,
        ``"subject_id"`` and ``"hadm_id"``. Both extracted sections keep
        their leading section header.
    """
    samples = []
    subject_id = patient.patient_id

    # Candidate headers that terminate the hospital-course section, in
    # priority order (the first one found after the section start wins).
    course_end_markers = (
        "Medications on Admission:",
        "Discharge Medications:",
        "Discharge Disposition:",
    )

    for event in patient.get_events("discharge"):
        note_text = event.attr_dict.get("text")
        if not note_text:
            # Nothing to extract from a missing or empty note.
            continue
        hadm_id = event.attr_dict["hadm_id"]

        course_start = note_text.find("Brief Hospital Course:")
        if course_start < 0:
            continue

        # Search only *after* the section start: an earlier mention of a
        # marker (e.g. in the history section) must not discard the note.
        course_end = -1
        for marker in course_end_markers:
            course_end = note_text.find(marker, course_start)
            if course_end != -1:
                break
        if course_end == -1:
            continue

        # split()/join collapses newlines and repeated whitespace alike.
        brief_hospital_course = " ".join(
            note_text[course_start:course_end].split()
        )

        summary_start = note_text.find("Discharge Instructions:")
        if summary_start < 0:
            continue
        summary_end = note_text.find("Followup Instructions:", summary_start)
        if summary_end < 0:
            continue

        summary = " ".join(note_text[summary_start:summary_end].split())
        if len(summary) < MIN_SUMMARY_LENGTH:
            continue

        samples.append({
            "text": note_text,
            "brief_hospital_course": brief_hospital_course,
            "summary": summary,
            "subject_id": subject_id,
            "hadm_id": hadm_id,
        })

    return samples
"""Unit tests for the DischargeNoteSummarization task.

Tests cover:
    - Class attributes (task_name, input_schema, output_schema)
    - __call__: happy-path extraction of brief_hospital_course and summary
    - __call__: filtering conditions that cause a note to be skipped
    - Output dictionary structure

``Patient`` and ``Event`` objects are replaced by mocks, so no dataset is
needed; the ``pyhealth`` package itself must still be importable.

Run with:
    python -m pytest tests/core/test_discharge_note_summarization.py -v
    # or
    python -m unittest tests/core/test_discharge_note_summarization.py -v
"""

import logging
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

from pyhealth.data import Event, Patient
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import DischargeNoteSummarization


# Realistic note containing every section the task looks for, with a
# "Discharge Instructions" section far longer than the minimum length.
_LONG_NOTE = (
    "Brief Hospital Course:\n"
    "The patient is an elderly individual with a significant past medical history of chronic obstructive "
    "pulmonary disease, congestive heart failure with a reduced ejection fraction of thirty-five percent, "
    "and Type 2 diabetes mellitus. The patient presented to the emergency department complaining of "
    "progressive shortness of breath, productive cough with yellow sputum, and bilateral lower extremity "
    "edema increasing over the last five days. Upon arrival, the patient was tachycardic and hypoxic, "
    "requiring supplemental oxygen via nasal cannula to maintain saturations above ninety-two percent. "
    "A chest X-ray revealed bilateral pulmonary infiltrates and pleural effusions, consistent with a "
    "multifocal pneumonia overlaying a congestive heart failure exacerbation. Laboratory results were "
    "significant for an elevated pro-BNP and a leukocytosis with an elevated white blood cell count. "
    "During the first forty-eight hours of admission, the patient was started on intravenous antibiotics "
    "for community-acquired pneumonia. Diuresis was initiated with intravenous medications, resulting in "
    "a significant net negative fluid balance over three days. The patient’s respiratory status "
    "improved significantly; oxygen was successfully weaned to room air by hospital day four. "
    "Endocrinology was consulted for blood glucose management, and the insulin regimen was "
    "adjusted to a sliding scale with a long-acting basal dose. By the day of discharge, the "
    "patient was stable, ambulating without distress, and lung sounds were markedly clearer on "
    "auscultation. Weight had returned to the documented baseline. "

    "Medications on Admission: "
    "Metformin, Lisinopril, Furosemide, and an Albuterol inhaler. "

    "Discharge Instructions: "
    "You were treated in the hospital for a combination of pneumonia and a flare-up of your heart "
    "failure. It is vital that you finish the entire course of oral antibiotics as prescribed, "
    "even if you feel better. Please monitor your weight every morning before breakfast. If you "
    "notice a weight gain of more than three pounds in a single day or five pounds in a week, "
    "contact your primary care doctor immediately as this indicates fluid buildup. Continue to "
    "use your salt-restricted diet and limit your total fluid intake to one and a half liters "
    "daily to prevent further strain on your heart. Rest is encouraged for the next week; however, "
    "try to perform light walking around the house to prevent blood clots. Avoid any heavy lifting "
    "or strenuous exercise until cleared by your cardiologist. You should continue your home "
    "medications as updated in the attached list. Seek immediate emergency care if you experience "
    "chest pain, severe shortness of breath while sitting still, or if you begin coughing up blood. "
    "We have adjusted your diuretic medication slightly to help manage your fluid levels more "
    "effectively during your recovery. Ensure you have picked up your new prescriptions from the "
    "pharmacy before the end of the day. It is also recommended that you receive your flu and "
    "pneumonia vaccinations once you have fully recovered from this current illness. Please bring "
    "your updated medication list to all upcoming appointments to ensure your medical record is accurate. "

    "Followup Instructions: "
    "Follow up with Cardiology next week. Follow up with your Primary Care Provider within seven days "
    "for a transition of care visit."
)


class TestDischargeNoteSummarizationTask(unittest.TestCase):
    """Exercise DischargeNoteSummarization on mocked patients."""

    @classmethod
    def setUpClass(cls):
        cls.cache_dir = tempfile.TemporaryDirectory()
        cls.task = DischargeNoteSummarization()
        # Minimum summary length these tests assume the task enforces.
        cls.MIN_SUMMARY_LENGTH = 350

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary directory created in setUpClass.
        cls.cache_dir.cleanup()

    def create_mock_patient(self, note_text, patient_id="p1", hadm_id="h1", subject_id="20000003"):
        """Helper to create a mock Patient with a single discharge event."""
        patient = MagicMock(spec=Patient)
        patient.patient_id = patient_id

        # Create a mock Event for the discharge note
        event = MagicMock(spec=Event)
        event.attr_dict = {
            "text": note_text,
            "hadm_id": hadm_id,
            "subject_id": subject_id,
        }

        # Only the "discharge" event type yields the note.
        patient.get_events.side_effect = (
            lambda event_type: [event] if event_type == "discharge" else []
        )
        return patient

    def test_task_metadata(self):
        """Task name and schemas expose the expected keys."""
        self.assertEqual(self.task.task_name, "DischargeNoteSummarization")
        self.assertIn("text", self.task.input_schema)
        self.assertIn("summary", self.task.output_schema)
        self.assertIn("brief_hospital_course", self.task.output_schema)

    def test_filtering_short_summary(self):
        """A note whose summary exceeds the minimum length is kept."""
        patient = self.create_mock_patient(_LONG_NOTE)
        samples = self.task(patient)

        self.assertEqual(
            len(samples), 1,
            "This summary should not be filtered out as its length more than 350.",
        )
        sample = samples[0]
        self.assertEqual(
            set(sample),
            {"text", "brief_hospital_course", "summary", "subject_id", "hadm_id"},
        )
        self.assertEqual(sample["text"], _LONG_NOTE)
        self.assertEqual(sample["subject_id"], "p1")
        self.assertEqual(sample["hadm_id"], "h1")
        self.assertTrue(sample["brief_hospital_course"].startswith("Brief Hospital Course:"))
        self.assertNotIn("Medications on Admission:", sample["brief_hospital_course"])
        self.assertTrue(sample["summary"].startswith("Discharge Instructions:"))
        self.assertNotIn("Followup Instructions:", sample["summary"])
        self.assertNotIn("\n", sample["brief_hospital_course"])
        self.assertGreaterEqual(len(sample["summary"]), self.MIN_SUMMARY_LENGTH)

    def test_short_summary_is_filtered(self):
        """All sections present, but the summary is below the minimum length."""
        note = (
            "Brief Hospital Course:\nStable.\n"
            "Medications on Admission:\nNone.\n"
            "Discharge Instructions:\nThis summary is too short.\n"
            "Followup Instructions:\nNone."
        )
        samples = self.task(self.create_mock_patient(note))

        self.assertEqual(len(samples), 0, "Should filter out samples with short summaries.")

    def test_edge_cases(self):
        """Notes without a 'Brief Hospital Course:' section are skipped."""
        short_summary = "This summary is too short."  # ~26 chars
        note = (
            # Deliberately no "Brief Hospital Course:" section.
            "Medications on Admission:\nNone.\n"
            "Discharge Instructions:\n" + short_summary + "\n"
            "Followup Instructions:\nNone."
        )
        patient = self.create_mock_patient(note)
        samples = self.task(patient)

        self.assertEqual(len(samples), 0, "Should skip notes without a Brief Hospital Course section.")

    def test_edge_cases_1(self):
        """Notes with only a 'Brief Hospital Course:' header are skipped."""
        note = (
            "Brief Hospital Course:\nStable.\n"
            # Deliberately no end marker and no discharge/follow-up sections.
            "This is a sample generated short summary that does not contain all sections."
        )

        patient = self.create_mock_patient(note)
        samples = self.task(patient)

        self.assertEqual(len(samples), 0, "Should skip notes that lack the remaining sections.")


if __name__ == "__main__":
    unittest.main()