From a99f966307fb12f820628e76eb78d5369dd6e552 Mon Sep 17 00:00:00 2001 From: Rahul D Date: Sun, 19 Apr 2026 12:21:26 -0400 Subject: [PATCH 01/10] updated dataset.rst --- docs/api/datasets.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..7ade42978 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -243,5 +243,6 @@ Available Datasets datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset + datasets/pyhealth.datasets.CaReSoundDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils From cef12cca62820b5a0d3d35fabe07e98521125149 Mon Sep 17 00:00:00 2001 From: Rahul D Date: Sun, 19 Apr 2026 12:39:58 -0400 Subject: [PATCH 02/10] Adding dataset docs --- docs/api/datasets.rst | 2 +- .../datasets/pyhealth.datasets.CaReSoundDataset.rst | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 7ade42978..6527ad3a4 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -236,6 +236,7 @@ Available Datasets datasets/pyhealth.datasets.EHRShotDataset datasets/pyhealth.datasets.Support2Dataset datasets/pyhealth.datasets.BMDHSDataset + datasets/pyhealth.datasets.CaReSoundDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset datasets/pyhealth.datasets.TUABDataset @@ -243,6 +244,5 @@ Available Datasets datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset - datasets/pyhealth.datasets.CaReSoundDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils diff --git a/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst b/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst new file mode 100644 index 000000000..821842b69 
--- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.CaReSoundDataset +================================== + +The CaReSound dataset provides question and answer pairs for medical sounds. For more information see `CaReSound <https://huggingface.co/datasets/tsnngw/CaReSound>`_. This dataset was contributed as part of the CaReAQA: A Cardiac and Respiratory Audio Question Answering Model for Open-Ended Diagnostic Reasoning work (`arXiv:2505.01199 <https://arxiv.org/abs/2505.01199>`_). + +.. autoclass:: pyhealth.datasets.CaReSoundDataset + :members: + :undoc-members: + :show-inheritance: + + \ No newline at end of file From 08775df3888ba27f050b3c4e6201c66dcabdebda Mon Sep 17 00:00:00 2001 From: Rahul D Date: Sun, 19 Apr 2026 14:42:55 -0400 Subject: [PATCH 03/10] Added dataset and task --- docs/api/tasks.rst | 1 + docs/api/tasks/pyhealth.tasks.CaReSoundAQA.rst | 7 + pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/caresound.py | 200 +++++++++++++++++++++ pyhealth/datasets/configs/caresound.yaml | 11 ++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/caresound_tasks.py | 64 +++++++ 7 files changed, 285 insertions(+) create mode 100644 docs/api/tasks/pyhealth.tasks.CaReSoundAQA.rst create mode 100644 pyhealth/datasets/caresound.py create mode 100644 pyhealth/datasets/configs/caresound.yaml create mode 100644 pyhealth/tasks/caresound_tasks.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..1d9aef545 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + CaReSound <tasks/pyhealth.tasks.CaReSoundAQA> diff --git a/docs/api/tasks/pyhealth.tasks.CaReSoundAQA.rst new file mode 100644 index 000000000..f4b4416b0 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.CaReSoundAQA.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.CaReSoundAQA +============================================== + +..
autoclass:: pyhealth.tasks.CaReSoundAQA + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..26e872638 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -48,6 +48,7 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset from .cardiology import CardiologyDataset +from .caresound import CaReSoundDataset from .chestxray14 import ChestXray14Dataset from .clinvar import ClinVarDataset from .cosmic import COSMICDataset diff --git a/pyhealth/datasets/caresound.py b/pyhealth/datasets/caresound.py new file mode 100644 index 000000000..9f9412b78 --- /dev/null +++ b/pyhealth/datasets/caresound.py @@ -0,0 +1,200 @@ +"""CaReSound dataset for PyHealth. + +This module provides the CaReSoundDataset class for loading and processing +the CaReSound benchmark data for Audio Question Answering (AQA) tasks. +""" +import logging +import os +from pathlib import Path +from typing import List, Optional, Dict, Any + +import pandas as pd +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + +class CaReSoundDataset(BaseDataset): + """CaReSound dataset for open-ended diagnostic reasoning. + + This dataset aggregates five medical audio sources: ICBHI, KAUH, + CirCor, SPRSound, and ZCHSound. It pairs respiratory and cardiac + audio with 34,792 GPT-4o generated Question-Answer pairs. + + Args: + root: Root directory containing the audio files (.wav) and/or CSVs. + tables: Optional list of tables to load. Defaults to ["metadata"]. + dataset_name: Optional name of the dataset. Defaults to "caresound". + config_path: Optional path to the configuration file. 
+ """ + + def __init__( + self, + root: str, + tables: List[str] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "caresound.yaml" + + # 1. Prepare standardized CSV (handles all local/API edge cases) + pyhealth_csv = os.path.join(root, "caresound_metadata.csv") + if not os.path.exists(pyhealth_csv): + self.prepare_metadata(root) + + # 2. Resolve local audio paths dynamically + self.audio_path_map = self._resolve_audio_paths(root) + + # 3. Define the default table mapped in the YAML + default_tables = ["metadata"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "caresound", + config_path=config_path, + **kwargs, + ) + + @staticmethod + def prepare_metadata(root: str) -> None: + """Prepares QA metadata from local ZIP/CSV drops or downloads via HF API.""" + output_path = os.path.join(root, "caresound_metadata.csv") + + train_csv = os.path.join(root, "CaReSoundQA_train.csv") + test_csv = os.path.join(root, "CaReSoundQA_test.csv") + full_csv = os.path.join(root, "CaReSoundQA.csv") + + # Scenario A: User manually downloaded both Train and Test CSVs + if os.path.exists(train_csv) and os.path.exists(test_csv): + logger.info("Found local train/test CSVs. Merging...") + df_train, df_test = pd.read_csv(train_csv), pd.read_csv(test_csv) + df_train['hf_split'], df_test['hf_split'] = 'train', 'test' + df_master = pd.concat([df_train, df_test], ignore_index=True) + df_master.to_csv(output_path, index=False) + return + + # Scenario B: User manually downloaded ONLY Train CSV + elif os.path.exists(train_csv): + logger.warning("Found local train CSV, but test is missing. 
Using train only.") + df_master = pd.read_csv(train_csv) + df_master['hf_split'] = 'train' + df_master.to_csv(output_path, index=False) + return + + # Scenario C: User manually downloaded the Master CSV + elif os.path.exists(full_csv): + logger.info(f"Found master CSV: {full_csv}.") + df_master = pd.read_csv(full_csv) + if 'hf_split' not in df_master.columns: + df_master['hf_split'] = 'unknown' + df_master.to_csv(output_path, index=False) + return + + # Scenario D: Fallback to Hugging Face API + try: + from datasets import load_dataset + logger.info("Local metadata not found. Fetching from tsnngw/CaReSound...") + dataset = load_dataset("tsnngw/CaReSound") + + df_train = dataset['train'].to_pandas() + df_train['hf_split'] = 'train' + + # Catch in case the dataset structure changes on HF + if 'test' in dataset: + df_test = dataset['test'].to_pandas() + df_test['hf_split'] = 'test' + df_master = pd.concat([df_train, df_test], ignore_index=True) + else: + df_master = df_train + + df_master.to_csv(output_path, index=False) + logger.info(f"Saved {len(df_master)} QA pairs via API to {output_path}") + + except ImportError: + logger.error("The 'datasets' library is required. Run: pip install datasets") + raise + except Exception as e: + logger.error(f"Failed to fetch metadata: {e}") + raise + + def _resolve_audio_paths(self, root: str) -> Dict[tuple, str]: + """Maps .wav files using robust stem and prefix mapping.""" + audio_map = {} + wav_files = list(Path(root).rglob("*.wav")) + + if not wav_files: + logger.warning(f"No .wav files found in {root}.") + return audio_map + + for path in wav_files: + stem = path.stem + path_str = str(path).lower() + + source = "Unknown" + if "icbhi" in path_str: source = "ICBHI" + elif "circor" in path_str: source = "CirCor" + elif "kauh" in path_str: source = "KAUH" + elif "spr" in path_str: source = "SPRSound" + elif "zch" in path_str: source = "ZCHSound" + + # 1. 
Primary Mapping: Exact filename match + audio_map[(source, stem)] = str(path.absolute()) + + # 2. Fallback Mapping: Base Patient ID match (e.g., '101' from '101_1b1') + base_id = stem.split('_')[0] + if (source, base_id) not in audio_map: + audio_map[(source, base_id)] = str(path.absolute()) + + return audio_map + + def parse_func(self) -> Dict[str, Any]: + """Merges tabular QA metadata with local audio paths.""" + csv_path = os.path.join(self.root, "caresound_metadata.csv") + df = pd.read_csv(csv_path) + + patients = {} + missing_sources = set() + + for _, row in df.iterrows(): + pid = str(row['patient_id']) + source = str(row['dataset']) + + audio_path = self.audio_path_map.get((source, pid)) + + if not audio_path: + missing_sources.add(source) + continue + + if pid not in patients: + patients[pid] = {"patient_id": pid, "visits": {}} + + visit_id = f"{source}_{pid}" + if visit_id not in patients[pid]["visits"]: + patients[pid]["visits"][visit_id] = { + "visit_id": visit_id, + "audio_path": audio_path, + "events": [] + } + + patients[pid]["visits"][visit_id]["events"].append({ + "question": row.get('question', ''), + "answer": row.get('answer', ''), + "hf_split": row.get('hf_split', 'unknown') + }) + + if missing_sources: + logger.warning( + f"Audio files missing for datasets: {', '.join(missing_sources)}. " + "Only available multi-modal samples have been loaded." 
+ ) + + return patients + + @property + def default_task(self): + from pyhealth.tasks import CaReSoundAQA + return CaReSoundAQA() \ No newline at end of file diff --git a/pyhealth/datasets/configs/caresound.yaml b/pyhealth/datasets/configs/caresound.yaml new file mode 100644 index 000000000..a902b1e70 --- /dev/null +++ b/pyhealth/datasets/configs/caresound.yaml @@ -0,0 +1,11 @@ +version: "1.0" +tables: + metadata: + file_path: "caresound_metadata.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "dataset" + - "question" + - "answer" + - "hf_split" \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..d38c313eb 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,6 +1,7 @@ from .base_task import BaseTask from .benchmark_ehrshot import BenchmarkEHRShot from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction +from .caresound_tasks import CaReSoundAQA from .bmd_hs_disease_classification import BMDHSDiseaseClassification from .cardiology_detect import ( cardiology_isAD_fn, diff --git a/pyhealth/tasks/caresound_tasks.py b/pyhealth/tasks/caresound_tasks.py new file mode 100644 index 000000000..4ae48dd86 --- /dev/null +++ b/pyhealth/tasks/caresound_tasks.py @@ -0,0 +1,64 @@ +"""Audio Question Answering tasks for PyHealth. + +This module provides tasks for processing generative text answers +based on medical audio signals using the CaReSound dataset. +""" + +from typing import Any, Dict, List +from .base_task import BaseTask + + +class CaReSoundAQA(BaseTask): + """Task for Audio Question Answering on respiratory and cardiac sounds. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): Required inputs (audio_path, question). + output_schema (Dict[str, str]): Required outputs (answer). 
+ """ + + task_name: str = "CaReSoundAQA" + input_schema: Dict[str, str] = { + "audio_path": "path", + "question": "text", + } + output_schema: Dict[str, str] = {"answer": "text"} + + def _safe_str(self, value: Any, default: str = "") -> str: + """Safely convert value to string, handling None and NaN.""" + if value is None or str(value).lower() == "nan": + return default + return str(value) + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process a patient record to extract audio-QA samples.""" + samples: List[Dict[str, Any]] = [] + + for visit_id, visit in patient.visits.items(): + audio_path = visit.attr_dict.get("audio_path") if hasattr(visit, "attr_dict") else visit.get("audio_path") + + if not audio_path: + continue + + events = visit.attr_dict.get("events", []) if hasattr(visit, "attr_dict") else visit.get("events", []) + + for event in events: + question = self._safe_str(event.get("question")) + answer = self._safe_str(event.get("answer")) + hf_split = self._safe_str(event.get("hf_split"), default="unknown") + + if not question or not answer: + continue + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": visit_id, + "audio_path": audio_path, + "question": question, + "answer": answer, + "original_hf_split": hf_split, + } + ) + + return samples \ No newline at end of file From 58d2d220a941694cb5b204a56f6e405b235936d4 Mon Sep 17 00:00:00 2001 From: Rahul D Date: Sun, 19 Apr 2026 17:32:56 -0400 Subject: [PATCH 04/10] Fixed dataset issues --- pyhealth/datasets/caresound.py | 77 +++++++++++++++--------- pyhealth/datasets/configs/caresound.yaml | 3 +- pyhealth/tasks/caresound_tasks.py | 45 ++++++-------- 3 files changed, 68 insertions(+), 57 deletions(-) diff --git a/pyhealth/datasets/caresound.py b/pyhealth/datasets/caresound.py index 9f9412b78..9578ea6b4 100644 --- a/pyhealth/datasets/caresound.py +++ b/pyhealth/datasets/caresound.py @@ -68,22 +68,20 @@ def prepare_metadata(root: str) -> None: test_csv = os.path.join(root, 
"CaReSoundQA_test.csv") full_csv = os.path.join(root, "CaReSoundQA.csv") + df_master = None + # Scenario A: User manually downloaded both Train and Test CSVs if os.path.exists(train_csv) and os.path.exists(test_csv): logger.info("Found local train/test CSVs. Merging...") df_train, df_test = pd.read_csv(train_csv), pd.read_csv(test_csv) df_train['hf_split'], df_test['hf_split'] = 'train', 'test' df_master = pd.concat([df_train, df_test], ignore_index=True) - df_master.to_csv(output_path, index=False) - return # Scenario B: User manually downloaded ONLY Train CSV elif os.path.exists(train_csv): logger.warning("Found local train CSV, but test is missing. Using train only.") df_master = pd.read_csv(train_csv) df_master['hf_split'] = 'train' - df_master.to_csv(output_path, index=False) - return # Scenario C: User manually downloaded the Master CSV elif os.path.exists(full_csv): @@ -91,35 +89,54 @@ def prepare_metadata(root: str) -> None: df_master = pd.read_csv(full_csv) if 'hf_split' not in df_master.columns: df_master['hf_split'] = 'unknown' - df_master.to_csv(output_path, index=False) - return # Scenario D: Fallback to Hugging Face API - try: - from datasets import load_dataset - logger.info("Local metadata not found. Fetching from tsnngw/CaReSound...") - dataset = load_dataset("tsnngw/CaReSound") - - df_train = dataset['train'].to_pandas() - df_train['hf_split'] = 'train' - - # Catch in case the dataset structure changes on HF - if 'test' in dataset: - df_test = dataset['test'].to_pandas() - df_test['hf_split'] = 'test' - df_master = pd.concat([df_train, df_test], ignore_index=True) - else: - df_master = df_train - - df_master.to_csv(output_path, index=False) - logger.info(f"Saved {len(df_master)} QA pairs via API to {output_path}") + else: + try: + from datasets import load_dataset + logger.info("Local metadata not found. 
Fetching from tsnngw/CaReSound...") + dataset = load_dataset("tsnngw/CaReSound") + + df_train = dataset['train'].to_pandas() + df_train['hf_split'] = 'train' + + # Catch in case the dataset structure changes on HF + if 'test' in dataset: + df_test = dataset['test'].to_pandas() + df_test['hf_split'] = 'test' + df_master = pd.concat([df_train, df_test], ignore_index=True) + else: + df_master = df_train + + except ImportError: + logger.error("The 'datasets' library is required. Run: pip install datasets") + raise + except Exception as e: + logger.error(f"Failed to fetch metadata: {e}") + raise + + # ---> MINIMAL FIX: Inject audio paths right before saving <--- + audio_map = {} + for path in Path(root).rglob("*.wav"): + stem, path_str = path.stem, str(path).lower() + source = "Unknown" + if "icbhi" in path_str: source = "ICBHI" + elif "circor" in path_str: source = "CirCor" + elif "kauh" in path_str: source = "KAUH" + elif "spr" in path_str: source = "SPRSound" + elif "zch" in path_str: source = "ZCHSound" - except ImportError: - logger.error("The 'datasets' library is required. 
Run: pip install datasets") - raise - except Exception as e: - logger.error(f"Failed to fetch metadata: {e}") - raise + audio_map[(source, stem)] = str(path.absolute()) + audio_map[(source, stem.split('_')[0])] = str(path.absolute()) + + df_master['audio_path'] = df_master.apply( + lambda r: audio_map.get((str(r.get('dataset', 'Unknown')), str(r.get('patient_id', ''))), ""), + axis=1 + ) + + # Save the final CSV for the new PyHealth Engine to pick up automatically + df_master.to_csv(output_path, index=False) + logger.info(f"Saved {len(df_master)} QA pairs with mapped audio to {output_path}") def _resolve_audio_paths(self, root: str) -> Dict[tuple, str]: """Maps .wav files using robust stem and prefix mapping.""" diff --git a/pyhealth/datasets/configs/caresound.yaml b/pyhealth/datasets/configs/caresound.yaml index a902b1e70..081b9b2f5 100644 --- a/pyhealth/datasets/configs/caresound.yaml +++ b/pyhealth/datasets/configs/caresound.yaml @@ -8,4 +8,5 @@ tables: - "dataset" - "question" - "answer" - - "hf_split" \ No newline at end of file + - "hf_split" + - "audio_path" \ No newline at end of file diff --git a/pyhealth/tasks/caresound_tasks.py b/pyhealth/tasks/caresound_tasks.py index 4ae48dd86..f0d24f73e 100644 --- a/pyhealth/tasks/caresound_tasks.py +++ b/pyhealth/tasks/caresound_tasks.py @@ -16,10 +16,8 @@ class CaReSoundAQA(BaseTask): input_schema (Dict[str, str]): Required inputs (audio_path, question). output_schema (Dict[str, str]): Required outputs (answer). 
""" - task_name: str = "CaReSoundAQA" input_schema: Dict[str, str] = { - "audio_path": "path", "question": "text", } output_schema: Dict[str, str] = {"answer": "text"} @@ -33,32 +31,27 @@ def _safe_str(self, value: Any, default: str = "") -> str: def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Process a patient record to extract audio-QA samples.""" samples: List[Dict[str, Any]] = [] - - for visit_id, visit in patient.visits.items(): - audio_path = visit.attr_dict.get("audio_path") if hasattr(visit, "attr_dict") else visit.get("audio_path") + events = patient.get_events("metadata") + + for event in events: + # The new engine puts CSV columns into attr_dict + attr = getattr(event, "attr_dict", {}) - if not audio_path: - continue - - events = visit.attr_dict.get("events", []) if hasattr(visit, "attr_dict") else visit.get("events", []) + audio_path = str(attr.get("audio_path", "")) + question = str(attr.get("question", "")) + answer = str(attr.get("answer", "")) + hf_split = str(attr.get("hf_split", "unknown")) - for event in events: - question = self._safe_str(event.get("question")) - answer = self._safe_str(event.get("answer")) - hf_split = self._safe_str(event.get("hf_split"), default="unknown") - - if not question or not answer: - continue + if not audio_path or not question or not answer: + continue - samples.append( - { - "patient_id": patient.patient_id, - "visit_id": visit_id, - "audio_path": audio_path, - "question": question, - "answer": answer, - "original_hf_split": hf_split, - } - ) + samples.append({ + "patient_id": patient.patient_id, + "visit_id": f"v_{patient.patient_id}", + "audio_path": audio_path, + "question": question, + "answer": answer, + "original_hf_split": hf_split, + }) return samples \ No newline at end of file From f78dbe63927c19b3a3d8b2e4abea7400fb0e8607 Mon Sep 17 00:00:00 2001 From: Rahul D Date: Mon, 20 Apr 2026 16:48:14 -0400 Subject: [PATCH 05/10] Added test and test resource --- 
.../caresound/datasets/CaReSoundQA_train.csv | 13 +++ tests/core/test_caresound.py | 87 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 test-resources/caresound/datasets/CaReSoundQA_train.csv create mode 100644 tests/core/test_caresound.py diff --git a/test-resources/caresound/datasets/CaReSoundQA_train.csv b/test-resources/caresound/datasets/CaReSoundQA_train.csv new file mode 100644 index 000000000..974d3c46a --- /dev/null +++ b/test-resources/caresound/datasets/CaReSoundQA_train.csv @@ -0,0 +1,13 @@ +patient_id,question,answer,dataset +65109516,Were any abnormal lung sounds noted during auscultation?,"No, the lungs were normal during auscultation.",SPRSound +ZCH0810,What specific type of defect is indicated in the diagnosis?,A ventricular septal defect is indicated in the diagnosis.,ZCHSound +147,What is the diagnosis based on the auscultation findings?,COPD,ICBHI +159,Are crackles present in the anterior right chest location?,"No, crackles are not present in the anterior right chest location.",ICBHI +85172,Is the murmur heard more prominently at any particular valve area?,"Yes, the murmur is most audible at the pulmonic valve area.",CirCor +BP50,Where was the normal respiratory sound heard?,Posterior Right Lower,KAUH +DP83,Where was the sound located during auscultation?,Anterior Right Upper,KAUH +EP31,Where is the location of the auscultation?,Posterior Lower Middle,KAUH +ZCH1062,Is there any abnormality detected in the cardiac auscultation findings?,"No, the cardiac auscultation findings are normal.",ZCHSound +154,Is there evidence of wheezing in the patient's auscultation?,"No, there is no evidence of wheezing.",ICBHI +ZCH0125,Are there any abnormalities in the heart sounds?,"No, the heart sounds are normal.",ZCHSound +203,What is the diagnosis based on auscultation?,COPD,ICBHI \ No newline at end of file diff --git a/tests/core/test_caresound.py b/tests/core/test_caresound.py new file mode 100644 index 000000000..aa125d7fb --- 
/dev/null +++ b/tests/core/test_caresound.py @@ -0,0 +1,87 @@ +import os +import unittest +from pathlib import Path + +from pyhealth.datasets import CaReSoundDataset +from pyhealth.tasks import CaReSoundAQA + + +class TestCaReSoundDataset(unittest.TestCase): + """Test cases for the CaReSoundDataset.""" + + @classmethod + def setUpClass(cls): + """Set up test resources path pointing to the PyHealthCS598/test-resources folder.""" + # This navigates up from tests/core/test_caresound.py to the project root + cls.test_resources = Path(__file__).parent.parent.parent / "test-resources" / "caresound" / "datasets" + + # Ensure the directory actually exists to prevent confusing errors + if not cls.test_resources.exists(): + raise FileNotFoundError( + f"Test resources not found at {cls.test_resources}. " + "Please ensure your sample audio and CaReSoundQA.csv are placed there." + ) + + def test_dataset_initialization(self): + """Test that the dataset initializes correctly from the test-resources folder.""" + dataset = CaReSoundDataset(root=str(self.test_resources)) + self.assertIsNotNone(dataset) + self.assertEqual(dataset.dataset_name, "caresound") + + def test_stats(self): + """Test that stats() runs without error.""" + dataset = CaReSoundDataset(root=str(self.test_resources)) + import sys, io + captured_output = io.StringIO() + sys.stdout = captured_output + dataset.stats() + sys.stdout = sys.__stdout__ + + # Updated to match the actual PyHealth output format!
+ self.assertIn("Dataset: caresound", captured_output.getvalue()) + + def test_default_task(self): + """Test that the default task is properly assigned to CaReSoundAQA.""" + dataset = CaReSoundDataset(root=str(self.test_resources)) + self.assertIsInstance(dataset.default_task, CaReSoundAQA) + + def test_set_task(self): + """Test applying the CaReSoundAQA task to the dataset.""" + dataset = CaReSoundDataset(root=str(self.test_resources)) + task = CaReSoundAQA() + samples = dataset.set_task(task) + + # Ensure the task actually generated samples + self.assertGreater(len(samples), 0) + + # Verify the schema of the first sample + sample = samples[0] + self.assertIn("patient_id", sample) + self.assertIn("question", sample) + self.assertIn("answer", sample) + self.assertIn("audio_path", sample) + + +class TestCaReSoundAQA(unittest.TestCase): + """Test cases for the CaReSoundAQA task schema and utilities.""" + + def setUp(self): + self.task = CaReSoundAQA() + + def test_task_attributes(self): + """Test task class attributes.""" + self.assertEqual(self.task.task_name, "CaReSoundAQA") + self.assertIn("question", self.task.input_schema) + self.assertEqual(self.task.input_schema["question"], "text") + self.assertIn("answer", self.task.output_schema) + self.assertEqual(self.task.output_schema["answer"], "text") + + def test_safe_str(self): + """Test the string safety utility.""" + self.assertEqual(self.task._safe_str("Hello"), "Hello") + self.assertEqual(self.task._safe_str(None, default="N/A"), "N/A") + self.assertEqual(self.task._safe_str("nan", default="missing"), "missing") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From b2c9fda630c8895fe62e822824e413923b4bc5ec Mon Sep 17 00:00:00 2001 From: Rahul D Date: Tue, 21 Apr 2026 11:44:46 -0400 Subject: [PATCH 06/10] PEP8 formatting --- pyhealth/datasets/caresound.py | 142 ++++++++++++++++++------------ pyhealth/tasks/caresound_tasks.py | 31 ++++--- 2 files changed, 101 insertions(+), 72 
deletions(-) diff --git a/pyhealth/datasets/caresound.py b/pyhealth/datasets/caresound.py index 9578ea6b4..8c7abc6f6 100644 --- a/pyhealth/datasets/caresound.py +++ b/pyhealth/datasets/caresound.py @@ -3,6 +3,7 @@ This module provides the CaReSoundDataset class for loading and processing the CaReSound benchmark data for Audio Question Answering (AQA) tasks. """ + import logging import os from pathlib import Path @@ -13,11 +14,12 @@ logger = logging.getLogger(__name__) + class CaReSoundDataset(BaseDataset): """CaReSound dataset for open-ended diagnostic reasoning. - This dataset aggregates five medical audio sources: ICBHI, KAUH, - CirCor, SPRSound, and ZCHSound. It pairs respiratory and cardiac + This dataset aggregates five medical audio sources: ICBHI, KAUH, + CirCor, SPRSound, and ZCHSound. It pairs respiratory and cardiac audio with 34,792 GPT-4o generated Question-Answer pairs. Args: @@ -63,53 +65,60 @@ def __init__( def prepare_metadata(root: str) -> None: """Prepares QA metadata from local ZIP/CSV drops or downloads via HF API.""" output_path = os.path.join(root, "caresound_metadata.csv") - + train_csv = os.path.join(root, "CaReSoundQA_train.csv") test_csv = os.path.join(root, "CaReSoundQA_test.csv") full_csv = os.path.join(root, "CaReSoundQA.csv") - + df_master = None - + # Scenario A: User manually downloaded both Train and Test CSVs if os.path.exists(train_csv) and os.path.exists(test_csv): logger.info("Found local train/test CSVs. Merging...") df_train, df_test = pd.read_csv(train_csv), pd.read_csv(test_csv) - df_train['hf_split'], df_test['hf_split'] = 'train', 'test' + df_train["hf_split"], df_test["hf_split"] = "train", "test" df_master = pd.concat([df_train, df_test], ignore_index=True) # Scenario B: User manually downloaded ONLY Train CSV elif os.path.exists(train_csv): - logger.warning("Found local train CSV, but test is missing. Using train only.") + logger.warning( + "Found local train CSV, but test is missing. Using train only." 
+ ) df_master = pd.read_csv(train_csv) - df_master['hf_split'] = 'train' + df_master["hf_split"] = "train" # Scenario C: User manually downloaded the Master CSV elif os.path.exists(full_csv): logger.info(f"Found master CSV: {full_csv}.") df_master = pd.read_csv(full_csv) - if 'hf_split' not in df_master.columns: - df_master['hf_split'] = 'unknown' + if "hf_split" not in df_master.columns: + df_master["hf_split"] = "unknown" # Scenario D: Fallback to Hugging Face API else: try: from datasets import load_dataset - logger.info("Local metadata not found. Fetching from tsnngw/CaReSound...") + + logger.info( + "Local metadata not found. Fetching from tsnngw/CaReSound..." + ) dataset = load_dataset("tsnngw/CaReSound") - - df_train = dataset['train'].to_pandas() - df_train['hf_split'] = 'train' - + + df_train = dataset["train"].to_pandas() + df_train["hf_split"] = "train" + # Catch in case the dataset structure changes on HF - if 'test' in dataset: - df_test = dataset['test'].to_pandas() - df_test['hf_split'] = 'test' + if "test" in dataset: + df_test = dataset["test"].to_pandas() + df_test["hf_split"] = "test" df_master = pd.concat([df_train, df_test], ignore_index=True) else: df_master = df_train - + except ImportError: - logger.error("The 'datasets' library is required. Run: pip install datasets") + logger.error( + "The 'datasets' library is required. 
Run: pip install datasets" + ) raise except Exception as e: logger.error(f"Failed to fetch metadata: {e}") @@ -120,29 +129,38 @@ def prepare_metadata(root: str) -> None: for path in Path(root).rglob("*.wav"): stem, path_str = path.stem, str(path).lower() source = "Unknown" - if "icbhi" in path_str: source = "ICBHI" - elif "circor" in path_str: source = "CirCor" - elif "kauh" in path_str: source = "KAUH" - elif "spr" in path_str: source = "SPRSound" - elif "zch" in path_str: source = "ZCHSound" - + if "icbhi" in path_str: + source = "ICBHI" + elif "circor" in path_str: + source = "CirCor" + elif "kauh" in path_str: + source = "KAUH" + elif "spr" in path_str: + source = "SPRSound" + elif "zch" in path_str: + source = "ZCHSound" + audio_map[(source, stem)] = str(path.absolute()) - audio_map[(source, stem.split('_')[0])] = str(path.absolute()) + audio_map[(source, stem.split("_")[0])] = str(path.absolute()) - df_master['audio_path'] = df_master.apply( - lambda r: audio_map.get((str(r.get('dataset', 'Unknown')), str(r.get('patient_id', ''))), ""), - axis=1 + df_master["audio_path"] = df_master.apply( + lambda r: audio_map.get( + (str(r.get("dataset", "Unknown")), str(r.get("patient_id", ""))), "" + ), + axis=1, ) # Save the final CSV for the new PyHealth Engine to pick up automatically df_master.to_csv(output_path, index=False) - logger.info(f"Saved {len(df_master)} QA pairs with mapped audio to {output_path}") + logger.info( + f"Saved {len(df_master)} QA pairs with mapped audio to {output_path}" + ) def _resolve_audio_paths(self, root: str) -> Dict[tuple, str]: """Maps .wav files using robust stem and prefix mapping.""" audio_map = {} wav_files = list(Path(root).rglob("*.wav")) - + if not wav_files: logger.warning(f"No .wav files found in {root}.") return audio_map @@ -150,58 +168,65 @@ def _resolve_audio_paths(self, root: str) -> Dict[tuple, str]: for path in wav_files: stem = path.stem path_str = str(path).lower() - + source = "Unknown" - if "icbhi" in path_str: 
source = "ICBHI" - elif "circor" in path_str: source = "CirCor" - elif "kauh" in path_str: source = "KAUH" - elif "spr" in path_str: source = "SPRSound" - elif "zch" in path_str: source = "ZCHSound" - + if "icbhi" in path_str: + source = "ICBHI" + elif "circor" in path_str: + source = "CirCor" + elif "kauh" in path_str: + source = "KAUH" + elif "spr" in path_str: + source = "SPRSound" + elif "zch" in path_str: + source = "ZCHSound" + # 1. Primary Mapping: Exact filename match audio_map[(source, stem)] = str(path.absolute()) - + # 2. Fallback Mapping: Base Patient ID match (e.g., '101' from '101_1b1') - base_id = stem.split('_')[0] + base_id = stem.split("_")[0] if (source, base_id) not in audio_map: audio_map[(source, base_id)] = str(path.absolute()) - + return audio_map def parse_func(self) -> Dict[str, Any]: """Merges tabular QA metadata with local audio paths.""" csv_path = os.path.join(self.root, "caresound_metadata.csv") df = pd.read_csv(csv_path) - + patients = {} missing_sources = set() - + for _, row in df.iterrows(): - pid = str(row['patient_id']) - source = str(row['dataset']) - + pid = str(row["patient_id"]) + source = str(row["dataset"]) + audio_path = self.audio_path_map.get((source, pid)) - + if not audio_path: missing_sources.add(source) continue - + if pid not in patients: patients[pid] = {"patient_id": pid, "visits": {}} - + visit_id = f"{source}_{pid}" if visit_id not in patients[pid]["visits"]: patients[pid]["visits"][visit_id] = { "visit_id": visit_id, "audio_path": audio_path, - "events": [] + "events": [], } - - patients[pid]["visits"][visit_id]["events"].append({ - "question": row.get('question', ''), - "answer": row.get('answer', ''), - "hf_split": row.get('hf_split', 'unknown') - }) + + patients[pid]["visits"][visit_id]["events"].append( + { + "question": row.get("question", ""), + "answer": row.get("answer", ""), + "hf_split": row.get("hf_split", "unknown"), + } + ) if missing_sources: logger.warning( @@ -214,4 +239,5 @@ def 
parse_func(self) -> Dict[str, Any]: @property def default_task(self): from pyhealth.tasks import CaReSoundAQA - return CaReSoundAQA() \ No newline at end of file + + return CaReSoundAQA() diff --git a/pyhealth/tasks/caresound_tasks.py b/pyhealth/tasks/caresound_tasks.py index f0d24f73e..d6e734d54 100644 --- a/pyhealth/tasks/caresound_tasks.py +++ b/pyhealth/tasks/caresound_tasks.py @@ -1,6 +1,6 @@ """Audio Question Answering tasks for PyHealth. -This module provides tasks for processing generative text answers +This module provides tasks for processing generative text answers based on medical audio signals using the CaReSound dataset. """ @@ -16,6 +16,7 @@ class CaReSoundAQA(BaseTask): input_schema (Dict[str, str]): Required inputs (audio_path, question). output_schema (Dict[str, str]): Required outputs (answer). """ + task_name: str = "CaReSoundAQA" input_schema: Dict[str, str] = { "question": "text", @@ -32,26 +33,28 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Process a patient record to extract audio-QA samples.""" samples: List[Dict[str, Any]] = [] events = patient.get_events("metadata") - + for event in events: # The new engine puts CSV columns into attr_dict attr = getattr(event, "attr_dict", {}) - + audio_path = str(attr.get("audio_path", "")) question = str(attr.get("question", "")) answer = str(attr.get("answer", "")) hf_split = str(attr.get("hf_split", "unknown")) - + if not audio_path or not question or not answer: continue - samples.append({ - "patient_id": patient.patient_id, - "visit_id": f"v_{patient.patient_id}", - "audio_path": audio_path, - "question": question, - "answer": answer, - "original_hf_split": hf_split, - }) - - return samples \ No newline at end of file + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": f"v_{patient.patient_id}", + "audio_path": audio_path, + "question": question, + "answer": answer, + "original_hf_split": hf_split, + } + ) + + return samples From 
c261c5affe4acfce7c4354821fd69f211a866ffd Mon Sep 17 00:00:00 2001 From: Rahul D Date: Tue, 21 Apr 2026 11:55:52 -0400 Subject: [PATCH 07/10] Added example usage --- pyhealth/datasets/caresound.py | 5 +++++ pyhealth/tasks/caresound_tasks.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/pyhealth/datasets/caresound.py b/pyhealth/datasets/caresound.py index 8c7abc6f6..e04021230 100644 --- a/pyhealth/datasets/caresound.py +++ b/pyhealth/datasets/caresound.py @@ -27,6 +27,11 @@ class CaReSoundDataset(BaseDataset): tables: Optional list of tables to load. Defaults to ["metadata"]. dataset_name: Optional name of the dataset. Defaults to "caresound". config_path: Optional path to the configuration file. + + Example: + >>> from pyhealth.datasets import CaReSoundDataset + >>> example_dataset = CaReSoundDataset(root="/Users/rahuld/Downloads/CaReAQA/datasets") + >>> example_dataset.stats() """ def __init__( diff --git a/pyhealth/tasks/caresound_tasks.py b/pyhealth/tasks/caresound_tasks.py index d6e734d54..80666901a 100644 --- a/pyhealth/tasks/caresound_tasks.py +++ b/pyhealth/tasks/caresound_tasks.py @@ -15,6 +15,12 @@ class CaReSoundAQA(BaseTask): task_name (str): The name of the task. input_schema (Dict[str, str]): Required inputs (audio_path, question). output_schema (Dict[str, str]): Required outputs (answer). 
+ + Example: + >>> from pyhealth.datasets import CaReSoundDataset + >>> from pyhealth.tasks import CaReSoundAQA + >>> example_dataset = CaReSoundDataset(root="/Users/rahuld/Downloads/CaReAQA/datasets") + >>> sample_dataset = example_dataset.set_task(CaReSoundAQA()) """ task_name: str = "CaReSoundAQA" From e726b24dad57d0bf9db325631ec1ebfb6ee7f4da Mon Sep 17 00:00:00 2001 From: Rahul D Date: Tue, 21 Apr 2026 12:06:12 -0400 Subject: [PATCH 08/10] Added synthetic data test --- tests/core/test_caresound.py | 92 +++++++++++++++--------------------- 1 file changed, 37 insertions(+), 55 deletions(-) diff --git a/tests/core/test_caresound.py b/tests/core/test_caresound.py index aa125d7fb..23d21e359 100644 --- a/tests/core/test_caresound.py +++ b/tests/core/test_caresound.py @@ -1,87 +1,69 @@ -import os import unittest +import tempfile +import pandas as pd from pathlib import Path - from pyhealth.datasets import CaReSoundDataset from pyhealth.tasks import CaReSoundAQA class TestCaReSoundDataset(unittest.TestCase): - """Test cases for the CaReSoundDataset.""" - - @classmethod - def setUpClass(cls): - """Set up test resources path pointing to the PyHealthCS598/test-resources folder.""" - # This navigates up from tests/datasets/test_caresound.py to the project root - cls.test_resources = Path(__file__).parent.parent.parent / "test-resources" / "caresound" / "datasets" - - # Ensure the directory actually exists to prevent confusing errors - if not cls.test_resources.exists(): - raise FileNotFoundError( - f"Test resources not found at {cls.test_resources}. " - "Please ensure your sample audio and CaReSoundQA.csv are placed there." - ) + """Test cases for the CaReSoundDataset using synthetic data.""" + + def setUp(self): + """Create a temporary directory with synthetic data.""" + self.test_dir = tempfile.TemporaryDirectory() + self.root = Path(self.test_dir.name) + + # 1. 
Create synthetic CSV (2 patients) + self.df = pd.DataFrame( + { + "patient_id": ["101", "102"], + "dataset": ["icbhi", "circor"], + "question": ["Is this normal?", "Any abnormalities?"], + "answer": ["Normal", "Abnormal"], + "hf_split": ["train", "test"], + "metadata/audio_path": ["icbhi_101.wav", "circor_102.wav"], + } + ) + self.df.to_csv(self.root / "caresound_metadata.csv", index=False) + + # 2. Create dummy audio files (just empty files are enough for path matching) + (self.root / "icbhi_101.wav").touch() + (self.root / "circor_102.wav").touch() + + def tearDown(self): + """Cleanup temporary directory.""" + self.test_dir.cleanup() def test_dataset_initialization(self): - """Test that the dataset initializes correctly from the test-resources folder.""" - dataset = CaReSoundDataset(root=str(self.test_resources)) + dataset = CaReSoundDataset(root=str(self.root)) self.assertIsNotNone(dataset) self.assertEqual(dataset.dataset_name, "caresound") - - def test_stats(self): - """Test that stats() runs without error.""" - dataset = CaReSoundDataset(root=str(self.test_resources)) - import sys, io - captured_output = io.StringIO() - sys.stdout = captured_output - dataset.stats() - sys.stdout = sys.__stdout__ - - # Updated to match the actual PyHealth output format! 
- self.assertIn("Dataset: caresound", captured_output.getvalue()) def test_default_task(self): - """Test that the default task is properly assigned to CaReSoundAQA.""" - dataset = CaReSoundDataset(root=str(self.test_resources)) + dataset = CaReSoundDataset(root=str(self.root)) self.assertIsInstance(dataset.default_task, CaReSoundAQA) def test_set_task(self): - """Test applying the CaReSoundAQA task to the dataset.""" - dataset = CaReSoundDataset(root=str(self.test_resources)) + dataset = CaReSoundDataset(root=str(self.root)) task = CaReSoundAQA() samples = dataset.set_task(task) - # Ensure the task actually generated samples - self.assertGreater(len(samples), 0) - - # Verify the schema of the first sample - sample = samples[0] - self.assertIn("patient_id", sample) - self.assertIn("question", sample) - self.assertIn("answer", sample) - self.assertIn("audio_path", sample) + # We expect 2 samples since we created 2 patients + self.assertEqual(len(samples), 2) + self.assertEqual(samples[0]["patient_id"], "101") class TestCaReSoundAQA(unittest.TestCase): - """Test cases for the CaReSoundAQA task schema and utilities.""" + """Test cases for the CaReSoundAQA task schema.""" def setUp(self): self.task = CaReSoundAQA() def test_task_attributes(self): - """Test task class attributes.""" self.assertEqual(self.task.task_name, "CaReSoundAQA") self.assertIn("question", self.task.input_schema) - self.assertEqual(self.task.input_schema["question"], "text") - self.assertIn("answer", self.task.output_schema) - self.assertEqual(self.task.output_schema["answer"], "text") - - def test_safe_str(self): - """Test the string safety utility.""" - self.assertEqual(self.task._safe_str("Hello"), "Hello") - self.assertEqual(self.task._safe_str(None, default="N/A"), "N/A") - self.assertEqual(self.task._safe_str("nan", default="missing"), "missing") if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From de25997a1caa9ecd445d24353b7a19f564438fd8 Mon Sep 17 
00:00:00 2001 From: Rahul D Date: Tue, 21 Apr 2026 12:07:24 -0400 Subject: [PATCH 09/10] removed resources dataset example --- .../caresound/datasets/CaReSoundQA_train.csv | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 test-resources/caresound/datasets/CaReSoundQA_train.csv diff --git a/test-resources/caresound/datasets/CaReSoundQA_train.csv b/test-resources/caresound/datasets/CaReSoundQA_train.csv deleted file mode 100644 index 974d3c46a..000000000 --- a/test-resources/caresound/datasets/CaReSoundQA_train.csv +++ /dev/null @@ -1,13 +0,0 @@ -patient_id,question,answer,dataset -65109516,Were any abnormal lung sounds noted during auscultation?,"No, the lungs were normal during auscultation.",SPRSound -ZCH0810,What specific type of defect is indicated in the diagnosis?,A ventricular septal defect is indicated in the diagnosis.,ZCHSound -147,What is the diagnosis based on the auscultation findings?,COPD,ICBHI -159,Are crackles present in the anterior right chest location?,"No, crackles are not present in the anterior right chest location.",ICBHI -85172,Is the murmur heard more prominently at any particular valve area?,"Yes, the murmur is most audible at the pulmonic valve area.",CirCor -BP50,Where was the normal respiratory sound heard?,Posterior Right Lower,KAUH -DP83,Where was the sound located during auscultation?,Anterior Right Upper,KAUH -EP31,Where is the location of the auscultation?,Posterior Lower Middle,KAUH -ZCH1062,Is there any abnormality detected in the cardiac auscultation findings?,"No, the cardiac auscultation findings are normal.",ZCHSound -154,Is there evidence of wheezing in the patient's auscultation?,"No, there is no evidence of wheezing.",ICBHI -ZCH0125,Are there any abnormalities in the heart sounds?,"No, the heart sounds are normal.",ZCHSound -203,What is the diagnosis based on auscultation?,COPD,ICBHI \ No newline at end of file From 2cc1a0fe7a113f522058d0e3ffc9309e6125db53 Mon Sep 17 00:00:00 2001 From: Rahul D Date: 
Tue, 21 Apr 2026 13:59:33 -0400 Subject: [PATCH 10/10] Added example usage and Ablation attempt --- .../CaReSoundDataset_CaReSoundAQA_CaReAQA.py | 489 ++++++++++++++++++ 1 file changed, 489 insertions(+) create mode 100644 examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py diff --git a/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py b/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py new file mode 100644 index 000000000..889c45858 --- /dev/null +++ b/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py @@ -0,0 +1,489 @@ +from pyhealth.datasets import CaReSoundDataset +from pyhealth.tasks import CaReSoundAQA +import librosa +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import timm +import numpy as np +from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig +from peft import get_peft_model, LoraConfig + +# ============================================================================== +# MODEL SETUP INSTRUCTIONS +# ============================================================================== +# 1. DOWNLOAD: Obtain the 'CaReAQAmodel.pt' weights file from the official +# project source (e.g., Hugging Face or the provided Google Drive link). +# +# 2. DIRECTORY STRUCTURE: Create the following folder path on your Mac: +# /Users/rahuld/Downloads/CaReAQA/CaReAQAModel/ +# +# 3. PLACEMENT: Move the downloaded file into that folder. +# Ensure the file is named EXACTLY 'CaReAQAmodel.pt'. +# +# 4. VERIFICATION: Your final file path must match the variable below: +# local_careqa_path = "/Users/rahuld/Downloads/CaReAQA/CaReAQAModel/CaReAQAmodel.pt" +# ============================================================================== + +# ============================================================================== +# DATASET SETUP INSTRUCTIONS +# ============================================================================== +# 1. 
DOWNLOAD: Obtain the 5 source audio datasets: ICBHI, KAUH, CirCor, +# SPRSound, and ZCHSound from their respective open-access repositories. +# +# 2. DIRECTORY STRUCTURE: Maintain the following folder path on your Mac: +# /Users/rahuld/Downloads/CaReAQA/datasets/ +# +# 3. PLACEMENT: Ensure audio files (.wav) are located within their respective +# folders (e.g., 'ICBHI Respiratory Sound Dataset', 'ZCHSound', etc.). +# The mapping logic will scan these directories to link audio to QA pairs. +# +# 4. VERIFICATION: Your current 'ls' output confirms the following layout: +# /Users/rahuld/Downloads/CaReAQA/datasets/ +# ├── CirCor Pediatric Heart Sound Dataset/ +# ├── ICBHI Respiratory Sound Dataset/ +# ├── KAUH Respiratory Dataset/ +# ├── SPRSound Pediatric Respiratory Dataset/ +# ├── ZCHSound/ +# └── caresound_metadata.csv +# ============================================================================== + +local_careqa_path = "/Users/rahuld/Downloads/CaReAQA/CaReAQAModel/CaReAQAmodel.pt" +path_to_dataset = "/Users/rahuld/Downloads/CaReAQA/datasets" +local_llama_dir = "/Users/rahuld/Downloads/meta-llama/Llama-3.2-3B" + +# set variable to true to use model to generate answer +generateModelAnswer = False + + +# ========================================== +# 1. MODEL CODE START AI GENERATED BASED ON .pt file +# ========================================== + + +# ========================================== +# 1. 
CLIPCAP TRANSFORMER MAPPER (PREFIX PROJECTOR) +# ========================================== +class Mlp(nn.Module): + def __init__(self, in_dim, h_dim, out_d=None, act=nn.GELU(), dropout=0.0): + super().__init__() + out_d = out_d if out_d is not None else in_dim + self.fc1 = nn.Linear(in_dim, h_dim) + self.act = act + self.fc2 = nn.Linear(h_dim, out_d) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.dropout(self.fc2(self.dropout(self.act(self.fc1(x))))) + + +class MultiHeadAttention(nn.Module): + def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim_self // num_heads + self.scale = head_dim**-0.5 + self.to_queries = nn.Linear(dim_self, dim_self, bias=bias) + self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias) + self.project = nn.Linear(dim_self, dim_self) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, y=None, mask=None): + y = y if y is not None else x + b, n, c = x.shape + _, m, d = y.shape + queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads) + kv = self.to_keys_values(y).reshape( + b, m, 2, self.num_heads, c // self.num_heads + ) + keys, values = kv[:, :, 0], kv[:, :, 1] + attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale + if mask is not None: + if mask.dim() == 2: + mask = mask.unsqueeze(1) + attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) + attention = attention.softmax(dim=2) + out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c) + return self.project(out), attention + + +class TransformerLayer(nn.Module): + def __init__( + self, + dim_self, + dim_ref, + num_heads, + mlp_ratio=4.0, + bias=False, + dropout=0.0, + act=nn.GELU(), + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim_self) + self.attn = MultiHeadAttention( + dim_self, dim_ref, num_heads, bias=bias, dropout=dropout + ) + self.norm2 = 
norm_layer(dim_self) + self.mlp = Mlp(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout) + + def forward(self, x, y=None, mask=None): + x_, _ = self.attn(self.norm1(x), y, mask) + x = x + x_ + x = x + self.mlp(self.norm2(x)) + return x + + +class TransformerMapper(nn.Module): + def __init__( + self, dim_clip, dim_embedding, prefix_length, clip_length, num_layers=8 + ): + super().__init__() + self.clip_length = clip_length + self.transformer = nn.ModuleList( + [ + TransformerLayer( + dim_embedding, + dim_embedding, + 8, + 2.0, + act=nn.GELU(), + norm_layer=nn.LayerNorm, + ) + for _ in range(num_layers) + ] + ) + self.linear = nn.Linear(dim_clip, clip_length * dim_embedding) + self.prefix_const = nn.Parameter( + torch.randn(prefix_length, dim_embedding), requires_grad=True + ) + + def forward(self, x): + x = self.linear(x).view(x.shape[0], self.clip_length, -1) + prefix = self.prefix_const.unsqueeze(0).expand( + x.shape[0], *self.prefix_const.shape + ) + prefix = torch.cat((x, prefix), dim=1) + for layer in self.transformer: + prefix = layer(prefix) + return prefix[:, self.clip_length :] + + +# ========================================== +# 2. 
AUDIO ENCODER (EFFICIENTNET) +# ========================================== +class AudioEncoder(nn.Module): + def __init__(self): + super().__init__() + self.cnn1 = nn.Conv2d(1, 3, kernel_size=3, padding=1) + # Using timm to automatically build the EfficientNet architecture matching the state_dict + self.efficientnet = timm.create_model( + "efficientnet_b0", pretrained=False, num_classes=0 + ) + + def forward(self, x): + # Format the 3D spectrogram into a 4D batch for the CNN + if x.dim() == 3: + x = x.unsqueeze(1) + x = self.cnn1(x) + x = self.efficientnet.forward_features(x) + # Global Average Pool to turn feature map into a 1280-dim vector + x = x.mean(dim=[2, 3]) + return x + + +class AudioModelWrapper(nn.Module): + def __init__(self): + super().__init__() + self.encoder = AudioEncoder() + + def extract_feature(self, x, dim=1280): + return self.encoder(x) + + +# ========================================== +# 3. MAIN MODEL WRAPPER +# ========================================== +class AudioQAModel(nn.Module): + def __init__( + self, + llm_type, + opera_checkpoint_path=None, + prefix_length=8, + clip_length=1, + setting="lora", + mapping_type="Transformer", + fine_tune_opera=True, + args=None, + ): + super().__init__() + + # Load the base Llama model + print(f"Loading Base LLM: {llm_type}...") + self.llm = AutoModelForCausalLM.from_pretrained( + llm_type, torch_dtype=torch.float16 + ) + + # Hook up LoRA adapters matching the model.pt state_dict shapes + if setting == "lora": + lora_config = LoraConfig( + r=8, + target_modules=[ + "q_proj", + "v_proj", + "k_proj", + "o_proj", + "up_proj", + "down_proj", + "gate_proj", + ], + bias="none", + task_type="CAUSAL_LM", + ) + self.llm = get_peft_model(self.llm, lora_config) + + # Hook up the rebuilt Audio and Projection modules + self.audio_model = AudioModelWrapper() + + dim_embedding = ( + self.llm.config.hidden_size if hasattr(self.llm, "config") else 3072 + ) + self.prefix_project = TransformerMapper( + dim_clip=1280, + 
dim_embedding=dim_embedding, + prefix_length=prefix_length, + clip_length=clip_length, + num_layers=8, + ) + + +# ========================================== +# 1. MODEL CODE END +# ========================================== + + +# ========================================== +# 1. LOAD MODEL (MODIFIED FOR LOCAL FILE & MAC) +# ========================================== +def load_careqa_model_local( + local_model_path, llm_type="meta-llama/Llama-3.2-3B", prefix_length=8 +): + print("Initializing model architecture...") + model = AudioQAModel( + llm_type=llm_type, + opera_checkpoint_path=None, + prefix_length=prefix_length, + clip_length=1, + setting="lora", + mapping_type="Transformer", + fine_tune_opera=True, + args=None, + ).eval() + + # Automatically detect Apple Silicon (MPS), GPU, or CPU + if False: + device = torch.device("mps") + print("Using Apple Silicon (MPS) for acceleration!") + elif False: + device = torch.device("cuda") + else: + device = torch.device("cpu") + + model = model.to(device) + + print(f"Loading weights from {local_model_path}...") + state_dict = torch.load(local_model_path, map_location="cpu") + + # Extract nested state_dict if it exists + if isinstance(state_dict, dict) and "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + model.load_state_dict(state_dict, strict=False) + print("Model loaded successfully!\n") + return model, device + + +# ========================================== +# 2. 
PREPROCESS AUDIO +# ========================================== +def preprocess_audio(audio_path, device, sr=16000): + print(f"Processing audio file: {audio_path}") + raw_audio, sr = librosa.load(audio_path, sr=sr) + mel_spec = librosa.feature.melspectrogram( + y=raw_audio, sr=sr, n_fft=1024, hop_length=512, n_mels=64 + ) + log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max) + + # Convert to tensor and move to correct device (MPS/CUDA/CPU) + audio_tensor = ( + torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).to(device) + ) + return audio_tensor + + +# ========================================== +# 3. GENERATE ANSWER +# ========================================== +def generate_answer( + model, + tokenizer, + audio_tensor, + question, + device, + prefix_length=8, + audio_feature_dim=1280, +): + # 1. Extract audio features + with torch.no_grad(): + audio_features = model.audio_model.extract_feature( + audio_tensor, dim=audio_feature_dim + ) + projected_prefix = model.prefix_project(audio_features) + + # 2. Tokenize the text prompts + q_prefix = tokenizer.encode("question: ", add_special_tokens=False) + q_tokens = tokenizer.encode(question, add_special_tokens=False) + a_prefix = tokenizer.encode(" answer", add_special_tokens=False) + + input_tokens = q_prefix + q_tokens + a_prefix + + # 3. Create input IDs + input_ids = torch.tensor( + [input_tokens + [tokenizer.eos_token_id] * prefix_length], dtype=torch.long + ).to(device) + attention_mask = torch.ones_like(input_ids) + + # 4. Insert the audio projection (FIXED FOR APPLE SILICON) + # Adding .clone() prevents the Apple MPS silent deadlock! + input_embeds = model.llm.get_input_embeddings()(input_ids).clone() + input_embeds[ + 0, len(q_prefix + q_tokens) : len(q_prefix + q_tokens) + prefix_length + ] = projected_prefix[0] + + # 5. Generate the response + print(">>> Firing up the LLaMA generation engine... 
(this should take < 2 mins)") + with torch.no_grad(): + output_ids = model.llm.generate( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + max_new_tokens=20, + do_sample=False, + use_cache=False, # FIXED: KV Cache can freeze when using custom inputs_embeds on Mac + pad_token_id=tokenizer.eos_token_id, # Silences the warning you got earlier + ) + print(">>> Generation complete!") + + answer = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip() + return answer + + +def get_tokenizer(model_dir: str): + """Loads the LLaMA tokenizer from a local directory.""" + print("Loading local LLaMA Tokenizer...") + return AutoTokenizer.from_pretrained(model_dir) + + +def get_base_llm(model_dir: str): + """Loads the base LLaMA model with 4-bit quantization.""" + from transformers import AutoModelForCausalLM, QuantoConfig + + print("Loading Llama 3.2 3B in 4-bit mode...") + quant_config = QuantoConfig(weights="int8") + llm_model = AutoModelForCausalLM.from_pretrained( + model_dir, + quantization_config=quant_config, + # device_map="mps" # Uncomment to force Apple Silicon GPU allocation if supported + ) + return llm_model + + +def get_careqa_model(careqa_path: str, llama_dir: str): + """Loads the custom AudioQAModel and assigns it to the correct device.""" + # This calls your previously defined load_careqa_model_local method + model, device = load_careqa_model_local( + local_model_path=careqa_path, llm_type=llama_dir + ) + return model, device + + +# ========================================== +# 4. 
GET ANSWER WRAPPER (Add this above __main__) +# ========================================== +def get_answer(question: str, audio_path: str, model, tokenizer, device) -> str: + print("\n" + "-" * 50) + print(f"Question: {question}") + print("-" * 50) + + audio_tensor = preprocess_audio(audio_path, device) + + print("Generating answer...") + answer = generate_answer(model, tokenizer, audio_tensor, question, device) + + print("-" * 50) + print(f"Answer: {answer}") + print("-" * 50) + + return answer + + +def print_stats(example_dataset: CaReSoundDataset): + print("\n" + "=" * 50) + print("DATASET STATS") + print("=" * 50) + example_dataset.stats() + + +def print_sample_i(sample_dataset: CaReSoundAQA, i: int): + print("\n" + "=" * 50) + print(f"Sample at index {i}") + print("=" * 50) + print(sample_dataset[i]) + print(sample_dataset[i]["audio_path"]) + + +def get_raw_audio(sample_dataset: CaReSoundAQA, i, sr=16000): + audio_path = sample_dataset[i]["audio_path"] + + # 2. Verify the file exists before trying to load it + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Could not find audio file at: {audio_path}") + + # 3. Load the raw audio + # sr=16000 ensures it matches the sample rate used in medical audio tasks + audio, sample_rate = librosa.load(audio_path, sr=sr) + + return audio, sample_rate + + +if __name__ == "__main__": + + example_dataset = CaReSoundDataset(root=path_to_dataset) + sample_dataset = example_dataset.set_task(CaReSoundAQA()) + # testdataset() + + # test dataset + print_stats(example_dataset) + print_sample_i(sample_dataset, 10) + get_raw_audio(sample_dataset, 10) + + if generateModelAnswer: + # 2. Extract Data for Inference + test_question = sample_dataset[10]["question"] + target_audio_path = sample_dataset[10]["audio_path"] # Renamed for clarity + ground_truth = sample_dataset[10]["answer"] + + # 3. 
Load Models + tokenizer = get_tokenizer(local_llama_dir) + careqa_model, device = get_careqa_model(local_careqa_path, local_llama_dir) + + # 4. Run Inference (Fixed the variable name mismatch here) + generated_answer = get_answer( + question=test_question, + audio_path=target_audio_path, + model=careqa_model, + tokenizer=tokenizer, + device=device, + ) + + print(f"\nGround Truth was: {ground_truth}")