diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..fd543ad9c 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -236,6 +236,7 @@ Available Datasets datasets/pyhealth.datasets.EHRShotDataset datasets/pyhealth.datasets.Support2Dataset datasets/pyhealth.datasets.BMDHSDataset + datasets/pyhealth.datasets.CaReSoundDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset datasets/pyhealth.datasets.PhysioNetDeIDDataset diff --git a/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst b/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst new file mode 100644 index 000000000..821842b69 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.CaReSoundDataset +================================== + +The CaReSound dataset provides question-answer pairs for cardiac and respiratory audio recordings. For more information, see `CaReSound <https://huggingface.co/datasets/tsnngw/CaReSound>`_. This dataset was released as part of the work *CaReAQA: A Cardiac and Respiratory Audio Question Answering Model for Open-Ended Diagnostic Reasoning* (`arXiv:2505.01199 <https://arxiv.org/abs/2505.01199>`_). + +.. autoclass:: pyhealth.datasets.CaReSoundDataset + :members: + :undoc-members: + :show-inheritance: + + \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..44d3adad0 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + CaReSound diff --git a/docs/api/tasks/pyhealth.tasks.CaReSoundAQA b/docs/api/tasks/pyhealth.tasks.CaReSoundAQA new file mode 100644 index 000000000..f4b4416b0 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.CaReSoundAQA @@ -0,0 +1,7 @@ +pyhealth.tasks.CaReSoundAQA +============================================== + +.. 
autoclass:: pyhealth.tasks.CaReSoundAQA + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py b/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py new file mode 100644 index 000000000..889c45858 --- /dev/null +++ b/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py @@ -0,0 +1,489 @@ +from pyhealth.datasets import CaReSoundDataset +from pyhealth.tasks import CaReSoundAQA +import librosa +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import timm +import numpy as np +from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig +from peft import get_peft_model, LoraConfig + +# ============================================================================== +# MODEL SETUP INSTRUCTIONS +# ============================================================================== +# 1. DOWNLOAD: Obtain the 'CaReAQAmodel.pt' weights file from the official +# project source (e.g., Hugging Face or the provided Google Drive link). +# +# 2. DIRECTORY STRUCTURE: Create the following folder path on your Mac: +# /Users/rahuld/Downloads/CaReAQA/CaReAQAModel/ +# +# 3. PLACEMENT: Move the downloaded file into that folder. +# Ensure the file is named EXACTLY 'CaReAQAmodel.pt'. +# +# 4. VERIFICATION: Your final file path must match the variable below: +# local_careqa_path = "/Users/rahuld/Downloads/CaReAQA/CaReAQAModel/CaReAQAmodel.pt" +# ============================================================================== + +# ============================================================================== +# DATASET SETUP INSTRUCTIONS +# ============================================================================== +# 1. DOWNLOAD: Obtain the 5 source audio datasets: ICBHI, KAUH, CirCor, +# SPRSound, and ZCHSound from their respective open-access repositories. +# +# 2. 
DIRECTORY STRUCTURE: Keep all datasets under one root folder, for example: +# /Users/rahuld/Downloads/CaReAQA/datasets/ +# +# 3. PLACEMENT: Ensure audio files (.wav) are located within their respective +# folders (e.g., 'ICBHI Respiratory Sound Dataset', 'ZCHSound', etc.). +# The mapping logic will scan these directories to link audio to QA pairs. +# +# 4. VERIFICATION: The dataset root folder should have the following layout: +# /Users/rahuld/Downloads/CaReAQA/datasets/ +# ├── CirCor Pediatric Heart Sound Dataset/ +# ├── ICBHI Respiratory Sound Dataset/ +# ├── KAUH Respiratory Dataset/ +# ├── SPRSound Pediatric Respiratory Dataset/ +# ├── ZCHSound/ +# └── caresound_metadata.csv +# ============================================================================== + +local_careqa_path = "/Users/rahuld/Downloads/CaReAQA/CaReAQAModel/CaReAQAmodel.pt" +path_to_dataset = "/Users/rahuld/Downloads/CaReAQA/datasets" +local_llama_dir = "/Users/rahuld/Downloads/meta-llama/Llama-3.2-3B" + +# Set to True to load the CaReAQA model and generate an answer for the sample +generateModelAnswer = False + + +# ========================================== +# 1. MODEL CODE START (AI-generated from the .pt checkpoint) +# ========================================== + + +# ========================================== +# 1. 
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    The attribute names ``fc1`` and ``fc2`` appear in checkpoint state-dict
    keys and must stay as they are.
    """

    def __init__(self, in_dim, h_dim, out_d=None, act=nn.GELU(), dropout=0.0):
        super().__init__()
        if out_d is None:
            # Default to a shape-preserving block.
            out_d = in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        projected = self.fc2(hidden)
        return self.dropout(projected)


class MultiHeadAttention(nn.Module):
    """Multi-head attention of ``x`` over a reference sequence ``y``.

    When ``y`` is omitted the module performs self-attention. ``forward``
    returns the projected output together with the attention weights, whose
    layout is (batch, query position, key position, head).
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim_self // num_heads) ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        # Keys and values come out of a single projection and are split later.
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        # Registered but not applied in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        if y is None:
            y = x
        batch, query_len, channels = x.shape
        _, ref_len, _ = y.shape
        per_head = channels // self.num_heads

        queries = self.to_queries(x).reshape(batch, query_len, self.num_heads, per_head)
        keys_values = self.to_keys_values(y).reshape(
            batch, ref_len, 2, self.num_heads, per_head
        )
        keys, values = keys_values.unbind(dim=2)

        scores = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                # (batch, keys) -> (batch, 1, keys): broadcast over query positions.
                mask = mask.unsqueeze(1)
            # The trailing unsqueeze broadcasts the mask over heads.
            scores = scores.masked_fill(mask.unsqueeze(3), float("-inf"))
        weights = scores.softmax(dim=2)

        mixed = torch.einsum("bnmh,bmhd->bnhd", weights, values)
        return self.project(mixed.reshape(batch, query_len, channels)), weights


class TransformerLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual connection."""

    def __init__(
        self,
        dim_self,
        dim_ref,
        num_heads,
        mlp_ratio=4.0,
        bias=False,
        dropout=0.0,
        act=nn.GELU(),
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        hidden_width = int(dim_self * mlp_ratio)
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(
            dim_self, dim_ref, num_heads, bias=bias, dropout=dropout
        )
        self.norm2 = norm_layer(dim_self)
        self.mlp = Mlp(dim_self, hidden_width, act=act, dropout=dropout)

    def forward(self, x, y=None, mask=None):
        attended, _ = self.attn(self.norm1(x), y, mask)
        x = x + attended
        return x + self.mlp(self.norm2(x))


class TransformerMapper(nn.Module):
    """Map one encoder feature vector to ``prefix_length`` LLM-sized embeddings.

    The input vector is linearly expanded into ``clip_length`` tokens, a
    learned constant prefix is appended, the sequence runs through the
    transformer stack, and only the prefix positions are returned.
    """

    def __init__(
        self, dim_clip, dim_embedding, prefix_length, clip_length, num_layers=8
    ):
        super().__init__()
        self.clip_length = clip_length
        # The layers are built before `linear` and `prefix_const` so that the
        # order of parameter initialisation is unchanged.
        layers = nn.ModuleList()
        for _ in range(num_layers):
            layers.append(
                TransformerLayer(
                    dim_embedding,
                    dim_embedding,
                    8,
                    2.0,
                    act=nn.GELU(),
                    norm_layer=nn.LayerNorm,
                )
            )
        self.transformer = layers
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(
            torch.randn(prefix_length, dim_embedding), requires_grad=True
        )

    def forward(self, x):
        batch = x.shape[0]
        clip_tokens = self.linear(x).view(batch, self.clip_length, -1)
        learned = self.prefix_const.unsqueeze(0).expand(
            batch, *self.prefix_const.shape
        )
        sequence = torch.cat((clip_tokens, learned), dim=1)
        for layer in self.transformer:
            sequence = layer(sequence)
        # Drop the clip tokens; keep only the transformed prefix positions.
        return sequence[:, self.clip_length :]
class AudioEncoder(nn.Module):
    """Spectrogram encoder: a 1->3 channel conv stem followed by EfficientNet-B0.

    The attribute names ``cnn1`` and ``efficientnet`` are part of the
    checkpoint's state-dict keys and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Lifts the single-channel spectrogram to three channels for EfficientNet.
        self.cnn1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
        # timm builds the EfficientNet architecture; num_classes=0 asks for a
        # backbone without a classification head.
        self.efficientnet = timm.create_model(
            "efficientnet_b0", pretrained=False, num_classes=0
        )

    def forward(self, x):
        """Encode a batch of spectrograms into one pooled feature vector each."""
        # A 3-D input has no channel axis yet; insert one for Conv2d.
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.cnn1(x)
        x = self.efficientnet.forward_features(x)
        # Global average pool over the two spatial axes (a 1280-dim vector per
        # item, matching `dim_clip=1280` in AudioQAModel).
        return x.mean(dim=[2, 3])


class AudioModelWrapper(nn.Module):
    """Thin wrapper exposing the encoder through ``extract_feature``."""

    def __init__(self):
        super().__init__()
        self.encoder = AudioEncoder()

    def extract_feature(self, x, dim=1280):
        """Return encoder features for ``x``.

        ``dim`` is accepted for call-site compatibility and is not used.
        """
        return self.encoder(x)


class AudioQAModel(nn.Module):
    """Audio question-answering model: audio encoder + prefix mapper + LoRA-wrapped LLM.

    Args:
        llm_type: Model name or local directory passed to
            ``AutoModelForCausalLM.from_pretrained``.
        opera_checkpoint_path: Accepted for signature compatibility; unused here.
        prefix_length: Number of prefix embeddings the mapper produces.
        clip_length: Number of tokens the encoder feature is expanded into.
        setting: ``"lora"`` wraps the LLM with LoRA adapters; any other value
            leaves the base LLM unwrapped.
        mapping_type: Accepted for signature compatibility; unused here.
        fine_tune_opera: Accepted for signature compatibility; unused here.
        args: Accepted for signature compatibility; unused here.
    """

    def __init__(
        self,
        llm_type,
        opera_checkpoint_path=None,
        prefix_length=8,
        clip_length=1,
        setting="lora",
        mapping_type="Transformer",
        fine_tune_opera=True,
        args=None,
    ):
        super().__init__()

        # Load the base Llama model
        print(f"Loading Base LLM: {llm_type}...")
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_type, torch_dtype=torch.float16
        )

        # LoRA adapters are configured to match the checkpoint's state_dict shapes.
        if setting == "lora":
            lora_config = LoraConfig(
                r=8,
                target_modules=[
                    "q_proj",
                    "v_proj",
                    "k_proj",
                    "o_proj",
                    "up_proj",
                    "down_proj",
                    "gate_proj",
                ],
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.llm = get_peft_model(self.llm, lora_config)

        # Audio encoder and the projector that maps its features into LLM space.
        self.audio_model = AudioModelWrapper()

        # 3072 is the fallback width used when the LLM object exposes no config.
        dim_embedding = (
            self.llm.config.hidden_size if hasattr(self.llm, "config") else 3072
        )
        self.prefix_project = TransformerMapper(
            dim_clip=1280,
            dim_embedding=dim_embedding,
            prefix_length=prefix_length,
            clip_length=clip_length,
            num_layers=8,
        )


def _resolve_device(device=None):
    """Turn a device request (None, "auto", a device string, or torch.device) into a torch.device."""
    if device is None:
        # Historical behaviour of this script: always run on CPU.
        return torch.device("cpu")
    if isinstance(device, torch.device):
        return device
    if device == "auto":
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")
    return torch.device(device)


def load_careqa_model_local(
    local_model_path,
    llm_type="meta-llama/Llama-3.2-3B",
    prefix_length=8,
    device=None,
):
    """Build the AudioQAModel and load CaReAQA weights from a local checkpoint.

    Args:
        local_model_path: Path to the ``.pt`` checkpoint file.
        llm_type: Name or local directory of the base LLM.
        prefix_length: Number of audio prefix embeddings.
        device: ``None`` (default) keeps the model on CPU, exactly as this
            script always did. Pass ``"auto"`` to pick MPS, then CUDA, then
            CPU, or pass an explicit device string / ``torch.device``.

    Returns:
        A ``(model, device)`` tuple; the model is in eval mode on ``device``.
    """
    print("Initializing model architecture...")
    model = AudioQAModel(
        llm_type=llm_type,
        opera_checkpoint_path=None,
        prefix_length=prefix_length,
        clip_length=1,
        setting="lora",
        mapping_type="Transformer",
        fine_tune_opera=True,
        args=None,
    ).eval()

    # The previous code advertised auto-detection but hard-coded `if False`
    # branches, so it always chose CPU. The choice is now explicit and opt-in.
    target_device = _resolve_device(device)
    print(f"Using device: {target_device}")
    model = model.to(target_device)

    print(f"Loading weights from {local_model_path}...")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # obtained from a source you trust.
    state_dict = torch.load(local_model_path, map_location="cpu")

    # Some checkpoints nest the weights under a "state_dict" key.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # strict=False tolerates key mismatches, so report them instead of
    # silently announcing success with partially initialised weights.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.missing_keys)} model keys were not found "
            f"in the checkpoint and {len(load_result.unexpected_keys)} checkpoint "
            "keys were not used."
        )
    print("Model loaded successfully!\n")
    return model, target_device
def preprocess_audio(audio_path, device, sr=16000):
    """Load an audio file and return its log-mel spectrogram as a tensor.

    Args:
        audio_path: Path of the audio file to load.
        device: Torch device the resulting tensor is moved to.
        sr: Sample rate the audio is resampled to on load.

    Returns:
        A float32 tensor of shape ``(1, 64, frames)`` located on ``device``.
    """
    print(f"Processing audio file: {audio_path}")
    raw_audio, sample_rate = librosa.load(audio_path, sr=sr)
    mel_spec = librosa.feature.melspectrogram(
        y=raw_audio, sr=sample_rate, n_fft=1024, hop_length=512, n_mels=64
    )
    # Decibel scale relative to the loudest bin of this clip.
    log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

    # Add a leading batch axis and move to the requested device.
    return torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).to(device)


def generate_answer(
    model,
    tokenizer,
    audio_tensor,
    question,
    device,
    prefix_length=8,
    audio_feature_dim=1280,
    max_new_tokens=20,
):
    """Generate a text answer for ``question`` conditioned on ``audio_tensor``.

    Args:
        model: Object exposing ``audio_model.extract_feature``,
            ``prefix_project`` and ``llm`` (with ``get_input_embeddings`` and
            ``generate``).
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id``.
        audio_tensor: Batched spectrogram as produced by ``preprocess_audio``.
        question: Question text.
        device: Device the token ids are created on.
        prefix_length: Number of audio prefix embeddings to insert.
        audio_feature_dim: Forwarded to ``extract_feature`` as ``dim``.
        max_new_tokens: Upper bound on generated tokens (previously fixed at 20).

    Returns:
        The decoded answer with surrounding whitespace stripped.
    """
    # Tokenize the text prompts.
    q_prefix = tokenizer.encode("question: ", add_special_tokens=False)
    q_tokens = tokenizer.encode(question, add_special_tokens=False)
    a_prefix = tokenizer.encode(" answer", add_special_tokens=False)

    # EOS ids act as placeholders that reserve `prefix_length` positions.
    input_tokens = q_prefix + q_tokens + a_prefix
    input_ids = torch.tensor(
        [input_tokens + [tokenizer.eos_token_id] * prefix_length], dtype=torch.long
    ).to(device)
    attention_mask = torch.ones_like(input_ids)

    # The audio prefix is written immediately after the question tokens.
    # NOTE(review): the placeholders are appended AFTER " answer", yet the
    # prefix is written starting right after the question. It therefore
    # overwrites the " answer" tokens and leaves trailing EOS embeddings in
    # place. Behaviour is kept as-is; confirm against the prompt layout used
    # in training before changing it.
    insert_at = len(q_prefix) + len(q_tokens)

    with torch.no_grad():
        # Extract audio features and project them into LLM embedding space.
        audio_features = model.audio_model.extract_feature(
            audio_tensor, dim=audio_feature_dim
        )
        projected_prefix = model.prefix_project(audio_features)

        # .clone() is kept from the original, where it was added as a
        # workaround for a hang observed on Apple MPS.
        input_embeds = model.llm.get_input_embeddings()(input_ids).clone()
        input_embeds[0, insert_at : insert_at + prefix_length] = projected_prefix[0]

        print(">>> Firing up the LLaMA generation engine... (this should take < 2 mins)")
        output_ids = model.llm.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # The KV cache was observed to freeze with custom inputs_embeds on Mac.
            use_cache=False,
            # Setting pad_token_id explicitly avoids the missing-pad-token warning.
            pad_token_id=tokenizer.eos_token_id,
        )
    print(">>> Generation complete!")

    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()


def get_tokenizer(model_dir: str):
    """Load the LLaMA tokenizer from a local directory."""
    print("Loading local LLaMA Tokenizer...")
    return AutoTokenizer.from_pretrained(model_dir)


def get_base_llm(model_dir: str):
    """Load the base LLaMA model with Quanto int8 weight quantization."""
    from transformers import AutoModelForCausalLM, QuantoConfig

    # The message used to say "4-bit", which contradicted the int8 config below.
    print("Loading Llama 3.2 3B with int8 weight quantization...")
    quant_config = QuantoConfig(weights="int8")
    llm_model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        quantization_config=quant_config,
        # device_map="mps"  # Uncomment to force Apple Silicon GPU allocation if supported
    )
    return llm_model


def get_careqa_model(careqa_path: str, llama_dir: str):
    """Load the custom AudioQAModel and return it with the device it lives on."""
    model, device = load_careqa_model_local(
        local_model_path=careqa_path, llm_type=llama_dir
    )
    return model, device
def get_answer(question: str, audio_path: str, model, tokenizer, device) -> str:
    """Run preprocessing and generation for one question/audio pair and print the result.

    Args:
        question: Question text.
        audio_path: Path of the audio clip the question refers to.
        model: Model object passed through to ``generate_answer``.
        tokenizer: Tokenizer passed through to ``generate_answer``.
        device: Device used for the audio tensor and the token ids.

    Returns:
        The generated answer string.
    """
    print("\n" + "-" * 50)
    print(f"Question: {question}")
    print("-" * 50)

    audio_tensor = preprocess_audio(audio_path, device)

    print("Generating answer...")
    answer = generate_answer(model, tokenizer, audio_tensor, question, device)

    print("-" * 50)
    print(f"Answer: {answer}")
    print("-" * 50)

    return answer


def print_stats(example_dataset: "CaReSoundDataset"):
    """Print a banner followed by the dataset's own ``stats()`` output."""
    print("\n" + "=" * 50)
    print("DATASET STATS")
    print("=" * 50)
    example_dataset.stats()


def print_sample_i(sample_dataset, i: int):
    """Print the sample at index ``i`` and its audio path.

    ``sample_dataset`` is the object returned by ``set_task`` (any indexable
    collection of sample dicts works). It was previously annotated with the
    task class ``CaReSoundAQA``, which is not what callers pass.
    """
    # Index once; each lookup may re-materialise the sample.
    sample = sample_dataset[i]
    print("\n" + "=" * 50)
    print(f"Sample at index {i}")
    print("=" * 50)
    print(sample)
    print(sample["audio_path"])


def get_raw_audio(sample_dataset, i, sr=16000):
    """Load the raw waveform of the sample at index ``i``.

    Args:
        sample_dataset: Indexable collection of sample dicts that each carry
            an ``"audio_path"`` entry.
        i: Index of the sample.
        sr: Sample rate the audio is resampled to on load.

    Returns:
        A ``(audio, sample_rate)`` tuple as returned by ``librosa.load``.

    Raises:
        FileNotFoundError: If the sample's audio file does not exist.
    """
    audio_path = sample_dataset[i]["audio_path"]

    # Fail with a clear message before handing the path to librosa.
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Could not find audio file at: {audio_path}")

    audio, sample_rate = librosa.load(audio_path, sr=sr)
    return audio, sample_rate


if __name__ == "__main__":
    example_dataset = CaReSoundDataset(root=path_to_dataset)
    sample_dataset = example_dataset.set_task(CaReSoundAQA())

    # One sample is used for every demo step below.
    sample_index = 10

    # Dataset smoke test
    print_stats(example_dataset)
    print_sample_i(sample_dataset, sample_index)
    get_raw_audio(sample_dataset, sample_index)

    if generateModelAnswer:
        # Extract the fields needed for inference with a single lookup.
        sample = sample_dataset[sample_index]
        test_question = sample["question"]
        target_audio_path = sample["audio_path"]
        ground_truth = sample["answer"]

        # Load models
        tokenizer = get_tokenizer(local_llama_dir)
        careqa_model, device = get_careqa_model(local_careqa_path, local_llama_dir)

        # Run inference
        generated_answer = get_answer(
            question=test_question,
            audio_path=target_audio_path,
            model=careqa_model,
            tokenizer=tokenizer,
            device=device,
        )

        print(f"\nGround Truth was: {ground_truth}")
+ config_path: Optional path to the configuration file. + + Example: + >>> from pyhealth.datasets import CaReSoundDataset + >>> example_dataset = CaReSoundDataset(root="/Users/rahuld/Downloads/CaReAQA/datasets") + >>> example_dataset.stats() + """ + + def __init__( + self, + root: str, + tables: List[str] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "caresound.yaml" + + # 1. Prepare standardized CSV (handles all local/API edge cases) + pyhealth_csv = os.path.join(root, "caresound_metadata.csv") + if not os.path.exists(pyhealth_csv): + self.prepare_metadata(root) + + # 2. Resolve local audio paths dynamically + self.audio_path_map = self._resolve_audio_paths(root) + + # 3. Define the default table mapped in the YAML + default_tables = ["metadata"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "caresound", + config_path=config_path, + **kwargs, + ) + + @staticmethod + def prepare_metadata(root: str) -> None: + """Prepares QA metadata from local ZIP/CSV drops or downloads via HF API.""" + output_path = os.path.join(root, "caresound_metadata.csv") + + train_csv = os.path.join(root, "CaReSoundQA_train.csv") + test_csv = os.path.join(root, "CaReSoundQA_test.csv") + full_csv = os.path.join(root, "CaReSoundQA.csv") + + df_master = None + + # Scenario A: User manually downloaded both Train and Test CSVs + if os.path.exists(train_csv) and os.path.exists(test_csv): + logger.info("Found local train/test CSVs. 
Merging...") + df_train, df_test = pd.read_csv(train_csv), pd.read_csv(test_csv) + df_train["hf_split"], df_test["hf_split"] = "train", "test" + df_master = pd.concat([df_train, df_test], ignore_index=True) + + # Scenario B: User manually downloaded ONLY Train CSV + elif os.path.exists(train_csv): + logger.warning( + "Found local train CSV, but test is missing. Using train only." + ) + df_master = pd.read_csv(train_csv) + df_master["hf_split"] = "train" + + # Scenario C: User manually downloaded the Master CSV + elif os.path.exists(full_csv): + logger.info(f"Found master CSV: {full_csv}.") + df_master = pd.read_csv(full_csv) + if "hf_split" not in df_master.columns: + df_master["hf_split"] = "unknown" + + # Scenario D: Fallback to Hugging Face API + else: + try: + from datasets import load_dataset + + logger.info( + "Local metadata not found. Fetching from tsnngw/CaReSound..." + ) + dataset = load_dataset("tsnngw/CaReSound") + + df_train = dataset["train"].to_pandas() + df_train["hf_split"] = "train" + + # Catch in case the dataset structure changes on HF + if "test" in dataset: + df_test = dataset["test"].to_pandas() + df_test["hf_split"] = "test" + df_master = pd.concat([df_train, df_test], ignore_index=True) + else: + df_master = df_train + + except ImportError: + logger.error( + "The 'datasets' library is required. 
Run: pip install datasets" + ) + raise + except Exception as e: + logger.error(f"Failed to fetch metadata: {e}") + raise + + # ---> MINIMAL FIX: Inject audio paths right before saving <--- + audio_map = {} + for path in Path(root).rglob("*.wav"): + stem, path_str = path.stem, str(path).lower() + source = "Unknown" + if "icbhi" in path_str: + source = "ICBHI" + elif "circor" in path_str: + source = "CirCor" + elif "kauh" in path_str: + source = "KAUH" + elif "spr" in path_str: + source = "SPRSound" + elif "zch" in path_str: + source = "ZCHSound" + + audio_map[(source, stem)] = str(path.absolute()) + audio_map[(source, stem.split("_")[0])] = str(path.absolute()) + + df_master["audio_path"] = df_master.apply( + lambda r: audio_map.get( + (str(r.get("dataset", "Unknown")), str(r.get("patient_id", ""))), "" + ), + axis=1, + ) + + # Save the final CSV for the new PyHealth Engine to pick up automatically + df_master.to_csv(output_path, index=False) + logger.info( + f"Saved {len(df_master)} QA pairs with mapped audio to {output_path}" + ) + + def _resolve_audio_paths(self, root: str) -> Dict[tuple, str]: + """Maps .wav files using robust stem and prefix mapping.""" + audio_map = {} + wav_files = list(Path(root).rglob("*.wav")) + + if not wav_files: + logger.warning(f"No .wav files found in {root}.") + return audio_map + + for path in wav_files: + stem = path.stem + path_str = str(path).lower() + + source = "Unknown" + if "icbhi" in path_str: + source = "ICBHI" + elif "circor" in path_str: + source = "CirCor" + elif "kauh" in path_str: + source = "KAUH" + elif "spr" in path_str: + source = "SPRSound" + elif "zch" in path_str: + source = "ZCHSound" + + # 1. Primary Mapping: Exact filename match + audio_map[(source, stem)] = str(path.absolute()) + + # 2. 
Fallback Mapping: Base Patient ID match (e.g., '101' from '101_1b1') + base_id = stem.split("_")[0] + if (source, base_id) not in audio_map: + audio_map[(source, base_id)] = str(path.absolute()) + + return audio_map + + def parse_func(self) -> Dict[str, Any]: + """Merges tabular QA metadata with local audio paths.""" + csv_path = os.path.join(self.root, "caresound_metadata.csv") + df = pd.read_csv(csv_path) + + patients = {} + missing_sources = set() + + for _, row in df.iterrows(): + pid = str(row["patient_id"]) + source = str(row["dataset"]) + + audio_path = self.audio_path_map.get((source, pid)) + + if not audio_path: + missing_sources.add(source) + continue + + if pid not in patients: + patients[pid] = {"patient_id": pid, "visits": {}} + + visit_id = f"{source}_{pid}" + if visit_id not in patients[pid]["visits"]: + patients[pid]["visits"][visit_id] = { + "visit_id": visit_id, + "audio_path": audio_path, + "events": [], + } + + patients[pid]["visits"][visit_id]["events"].append( + { + "question": row.get("question", ""), + "answer": row.get("answer", ""), + "hf_split": row.get("hf_split", "unknown"), + } + ) + + if missing_sources: + logger.warning( + f"Audio files missing for datasets: {', '.join(missing_sources)}. " + "Only available multi-modal samples have been loaded." 
class CaReSoundAQA(BaseTask):
    """Task for Audio Question Answering on respiratory and cardiac sounds.

    Every ``metadata`` event of a patient that has a non-empty audio path,
    question and answer becomes one sample.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): Declared inputs (``question``). Each
            sample also carries ``audio_path``, which is not declared in the
            schema.
        output_schema (Dict[str, str]): Declared outputs (``answer``).

    Example:
        >>> from pyhealth.datasets import CaReSoundDataset
        >>> from pyhealth.tasks import CaReSoundAQA
        >>> example_dataset = CaReSoundDataset(root="/Users/rahuld/Downloads/CaReAQA/datasets")
        >>> sample_dataset = example_dataset.set_task(CaReSoundAQA())
    """

    task_name: str = "CaReSoundAQA"
    input_schema: Dict[str, str] = {
        "question": "text",
    }
    output_schema: Dict[str, str] = {"answer": "text"}

    def _safe_str(self, value: Any, default: str = "") -> str:
        """Convert ``value`` to ``str``, mapping ``None`` and NaN to ``default``."""
        if value is None or str(value).lower() == "nan":
            return default
        return str(value)

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Extract audio-QA samples from one patient record.

        Args:
            patient: Patient object exposing ``get_events("metadata")`` and
                ``patient_id``.

        Returns:
            One dict per usable event with the keys ``patient_id``,
            ``visit_id``, ``audio_path``, ``question``, ``answer`` and
            ``original_hf_split``. Events missing an audio path, question or
            answer are skipped.
        """
        samples: List[Dict[str, Any]] = []

        for event in patient.get_events("metadata"):
            # The CSV columns of an event are read from its attr_dict.
            attr = getattr(event, "attr_dict", {})

            # A plain str() would turn a missing value into "None" or "nan",
            # both truthy, so the emptiness check below would never fire.
            audio_path = self._safe_str(attr.get("audio_path"))
            question = self._safe_str(attr.get("question"))
            answer = self._safe_str(attr.get("answer"))
            hf_split = self._safe_str(attr.get("hf_split"), "unknown")

            if not (audio_path and question and answer):
                continue

            samples.append(
                {
                    "patient_id": patient.patient_id,
                    "visit_id": f"v_{patient.patient_id}",
                    "audio_path": audio_path,
                    "question": question,
                    "answer": answer,
                    "original_hf_split": hf_split,
                }
            )

        return samples
class TestCaReSoundAQA(unittest.TestCase):
    """Schema-level checks for the CaReSoundAQA task."""

    def setUp(self):
        """Create a fresh task instance for each test method."""
        self.task = CaReSoundAQA()

    def test_task_attributes(self):
        """The task reports its name and lists ``question`` among its inputs."""
        task = self.task
        declared_inputs = task.input_schema
        self.assertEqual(task.task_name, "CaReSoundAQA")
        self.assertIn("question", declared_inputs)