diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 8d9a59d21..fd543ad9c 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -236,6 +236,7 @@ Available Datasets
datasets/pyhealth.datasets.EHRShotDataset
datasets/pyhealth.datasets.Support2Dataset
datasets/pyhealth.datasets.BMDHSDataset
+ datasets/pyhealth.datasets.CaReSoundDataset
datasets/pyhealth.datasets.COVID19CXRDataset
datasets/pyhealth.datasets.ChestXray14Dataset
datasets/pyhealth.datasets.PhysioNetDeIDDataset
diff --git a/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst b/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst
new file mode 100644
index 000000000..821842b69
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.CaReSoundDataset.rst
@@ -0,0 +1,11 @@
+pyhealth.datasets.CaReSoundDataset
+==================================
+
+The CaReSound dataset provides question and answer pairs for medical sounds. For more information see `CaReSound <https://huggingface.co/datasets/tsnngw/CaReSound>`_. This dataset was contributed as part of the CaReAQA: A Cardiac and Respiratory Audio Question Answering Model for Open-Ended Diagnostic Reasoning work (`arXiv:2505.01199 <https://arxiv.org/abs/2505.01199>`_).
+
+.. autoclass:: pyhealth.datasets.CaReSoundDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
\ No newline at end of file
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 23a4e06e5..44d3adad0 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -230,3 +230,4 @@ Available Tasks
Mutation Pathogenicity (COSMIC)
Cancer Survival Prediction (TCGA)
Cancer Mutation Burden (TCGA)
+ CaReSound
diff --git a/docs/api/tasks/pyhealth.tasks.CaReSoundAQA.rst b/docs/api/tasks/pyhealth.tasks.CaReSoundAQA.rst
new file mode 100644
index 000000000..f4b4416b0
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.CaReSoundAQA.rst
@@ -0,0 +1,7 @@
+pyhealth.tasks.CaReSoundAQA
+==============================================
+
+.. autoclass:: pyhealth.tasks.CaReSoundAQA
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py b/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py
new file mode 100644
index 000000000..889c45858
--- /dev/null
+++ b/examples/CaReSoundDataset_CaReSoundAQA_CaReAQA.py
@@ -0,0 +1,489 @@
+from pyhealth.datasets import CaReSoundDataset
+from pyhealth.tasks import CaReSoundAQA
+import librosa
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import timm
+import numpy as np
+from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
+from peft import get_peft_model, LoraConfig
+
+# ==============================================================================
+# MODEL SETUP INSTRUCTIONS
+# ==============================================================================
+# 1. DOWNLOAD: Obtain the 'CaReAQAmodel.pt' weights file from the official
+# project source (e.g., Hugging Face or the provided Google Drive link).
+#
+# 2. DIRECTORY STRUCTURE: Create the following folder path on your Mac:
+# /Users/rahuld/Downloads/CaReAQA/CaReAQAModel/
+#
+# 3. PLACEMENT: Move the downloaded file into that folder.
+# Ensure the file is named EXACTLY 'CaReAQAmodel.pt'.
+#
+# 4. VERIFICATION: Your final file path must match the variable below:
+# local_careqa_path = "/Users/rahuld/Downloads/CaReAQA/CaReAQAModel/CaReAQAmodel.pt"
+# ==============================================================================
+
+# ==============================================================================
+# DATASET SETUP INSTRUCTIONS
+# ==============================================================================
+# 1. DOWNLOAD: Obtain the 5 source audio datasets: ICBHI, KAUH, CirCor,
+# SPRSound, and ZCHSound from their respective open-access repositories.
+#
+# 2. DIRECTORY STRUCTURE: Maintain the following folder path on your Mac:
+# /Users/rahuld/Downloads/CaReAQA/datasets/
+#
+# 3. PLACEMENT: Ensure audio files (.wav) are located within their respective
+# folders (e.g., 'ICBHI Respiratory Sound Dataset', 'ZCHSound', etc.).
+# The mapping logic will scan these directories to link audio to QA pairs.
+#
+# 4. VERIFICATION: Your current 'ls' output confirms the following layout:
+# /Users/rahuld/Downloads/CaReAQA/datasets/
+# ├── CirCor Pediatric Heart Sound Dataset/
+# ├── ICBHI Respiratory Sound Dataset/
+# ├── KAUH Respiratory Dataset/
+# ├── SPRSound Pediatric Respiratory Dataset/
+# ├── ZCHSound/
+# └── caresound_metadata.csv
+# ==============================================================================
+
+local_careqa_path = "/Users/rahuld/Downloads/CaReAQA/CaReAQAModel/CaReAQAmodel.pt"
+path_to_dataset = "/Users/rahuld/Downloads/CaReAQA/datasets"
+local_llama_dir = "/Users/rahuld/Downloads/meta-llama/Llama-3.2-3B"
+
+# set variable to true to use model to generate answer
+generateModelAnswer = False
+
+
+# ==========================================
+# 1. MODEL CODE START AI GENERATED BASED ON .pt file
+# ==========================================
+
+
+# ==========================================
+# 1. CLIPCAP TRANSFORMER MAPPER (PREFIX PROJECTOR)
+# ==========================================
+class Mlp(nn.Module):
+ def __init__(self, in_dim, h_dim, out_d=None, act=nn.GELU(), dropout=0.0):
+ super().__init__()
+ out_d = out_d if out_d is not None else in_dim
+ self.fc1 = nn.Linear(in_dim, h_dim)
+ self.act = act
+ self.fc2 = nn.Linear(h_dim, out_d)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ return self.dropout(self.fc2(self.dropout(self.act(self.fc1(x)))))
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.0):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim_self // num_heads
+ self.scale = head_dim**-0.5
+ self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
+ self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
+ self.project = nn.Linear(dim_self, dim_self)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, y=None, mask=None):
+ y = y if y is not None else x
+ b, n, c = x.shape
+ _, m, d = y.shape
+ queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
+ kv = self.to_keys_values(y).reshape(
+ b, m, 2, self.num_heads, c // self.num_heads
+ )
+ keys, values = kv[:, :, 0], kv[:, :, 1]
+ attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
+ if mask is not None:
+ if mask.dim() == 2:
+ mask = mask.unsqueeze(1)
+ attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
+ attention = attention.softmax(dim=2)
+ out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
+ return self.project(out), attention
+
+
+class TransformerLayer(nn.Module):
+ def __init__(
+ self,
+ dim_self,
+ dim_ref,
+ num_heads,
+ mlp_ratio=4.0,
+ bias=False,
+ dropout=0.0,
+ act=nn.GELU(),
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim_self)
+ self.attn = MultiHeadAttention(
+ dim_self, dim_ref, num_heads, bias=bias, dropout=dropout
+ )
+ self.norm2 = norm_layer(dim_self)
+ self.mlp = Mlp(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
+
+ def forward(self, x, y=None, mask=None):
+ x_, _ = self.attn(self.norm1(x), y, mask)
+ x = x + x_
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class TransformerMapper(nn.Module):
+ def __init__(
+ self, dim_clip, dim_embedding, prefix_length, clip_length, num_layers=8
+ ):
+ super().__init__()
+ self.clip_length = clip_length
+ self.transformer = nn.ModuleList(
+ [
+ TransformerLayer(
+ dim_embedding,
+ dim_embedding,
+ 8,
+ 2.0,
+ act=nn.GELU(),
+ norm_layer=nn.LayerNorm,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
+ self.prefix_const = nn.Parameter(
+ torch.randn(prefix_length, dim_embedding), requires_grad=True
+ )
+
+ def forward(self, x):
+ x = self.linear(x).view(x.shape[0], self.clip_length, -1)
+ prefix = self.prefix_const.unsqueeze(0).expand(
+ x.shape[0], *self.prefix_const.shape
+ )
+ prefix = torch.cat((x, prefix), dim=1)
+ for layer in self.transformer:
+ prefix = layer(prefix)
+ return prefix[:, self.clip_length :]
+
+
+# ==========================================
+# 2. AUDIO ENCODER (EFFICIENTNET)
+# ==========================================
+class AudioEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.cnn1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
+ # Using timm to automatically build the EfficientNet architecture matching the state_dict
+ self.efficientnet = timm.create_model(
+ "efficientnet_b0", pretrained=False, num_classes=0
+ )
+
+ def forward(self, x):
+ # Format the 3D spectrogram into a 4D batch for the CNN
+ if x.dim() == 3:
+ x = x.unsqueeze(1)
+ x = self.cnn1(x)
+ x = self.efficientnet.forward_features(x)
+ # Global Average Pool to turn feature map into a 1280-dim vector
+ x = x.mean(dim=[2, 3])
+ return x
+
+
+class AudioModelWrapper(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.encoder = AudioEncoder()
+
+ def extract_feature(self, x, dim=1280):
+ return self.encoder(x)
+
+
+# ==========================================
+# 3. MAIN MODEL WRAPPER
+# ==========================================
+class AudioQAModel(nn.Module):
+ def __init__(
+ self,
+ llm_type,
+ opera_checkpoint_path=None,
+ prefix_length=8,
+ clip_length=1,
+ setting="lora",
+ mapping_type="Transformer",
+ fine_tune_opera=True,
+ args=None,
+ ):
+ super().__init__()
+
+ # Load the base Llama model
+ print(f"Loading Base LLM: {llm_type}...")
+ self.llm = AutoModelForCausalLM.from_pretrained(
+ llm_type, torch_dtype=torch.float16
+ )
+
+ # Hook up LoRA adapters matching the model.pt state_dict shapes
+ if setting == "lora":
+ lora_config = LoraConfig(
+ r=8,
+ target_modules=[
+ "q_proj",
+ "v_proj",
+ "k_proj",
+ "o_proj",
+ "up_proj",
+ "down_proj",
+ "gate_proj",
+ ],
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
+ self.llm = get_peft_model(self.llm, lora_config)
+
+ # Hook up the rebuilt Audio and Projection modules
+ self.audio_model = AudioModelWrapper()
+
+ dim_embedding = (
+ self.llm.config.hidden_size if hasattr(self.llm, "config") else 3072
+ )
+ self.prefix_project = TransformerMapper(
+ dim_clip=1280,
+ dim_embedding=dim_embedding,
+ prefix_length=prefix_length,
+ clip_length=clip_length,
+ num_layers=8,
+ )
+
+
+# ==========================================
+# 1. MODEL CODE END
+# ==========================================
+
+
+# ==========================================
+# 1. LOAD MODEL (MODIFIED FOR LOCAL FILE & MAC)
+# ==========================================
+def load_careqa_model_local(
+    local_model_path, llm_type="meta-llama/Llama-3.2-3B", prefix_length=8
+):
+    print("Initializing model architecture...")
+    model = AudioQAModel(
+        llm_type=llm_type,
+        opera_checkpoint_path=None,
+        prefix_length=prefix_length,
+        clip_length=1,
+        setting="lora",
+        mapping_type="Transformer",
+        fine_tune_opera=True,
+        args=None,
+    ).eval()
+
+    # Automatically detect Apple Silicon (MPS), GPU, or CPU
+    if torch.backends.mps.is_available():
+        device = torch.device("mps")
+        print("Using Apple Silicon (MPS) for acceleration!")
+    elif torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    model = model.to(device)
+
+    print(f"Loading weights from {local_model_path}...")
+    state_dict = torch.load(local_model_path, map_location="cpu")
+
+    # Extract nested state_dict if it exists
+    if isinstance(state_dict, dict) and "state_dict" in state_dict:
+        state_dict = state_dict["state_dict"]
+
+    model.load_state_dict(state_dict, strict=False)
+    print("Model loaded successfully!\n")
+    return model, device
+
+
+# ==========================================
+# 2. PREPROCESS AUDIO
+# ==========================================
+def preprocess_audio(audio_path, device, sr=16000):
+ print(f"Processing audio file: {audio_path}")
+ raw_audio, sr = librosa.load(audio_path, sr=sr)
+ mel_spec = librosa.feature.melspectrogram(
+ y=raw_audio, sr=sr, n_fft=1024, hop_length=512, n_mels=64
+ )
+ log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+
+ # Convert to tensor and move to correct device (MPS/CUDA/CPU)
+ audio_tensor = (
+ torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).to(device)
+ )
+ return audio_tensor
+
+
+# ==========================================
+# 3. GENERATE ANSWER
+# ==========================================
+def generate_answer(
+ model,
+ tokenizer,
+ audio_tensor,
+ question,
+ device,
+ prefix_length=8,
+ audio_feature_dim=1280,
+):
+ # 1. Extract audio features
+ with torch.no_grad():
+ audio_features = model.audio_model.extract_feature(
+ audio_tensor, dim=audio_feature_dim
+ )
+ projected_prefix = model.prefix_project(audio_features)
+
+ # 2. Tokenize the text prompts
+ q_prefix = tokenizer.encode("question: ", add_special_tokens=False)
+ q_tokens = tokenizer.encode(question, add_special_tokens=False)
+ a_prefix = tokenizer.encode(" answer", add_special_tokens=False)
+
+ input_tokens = q_prefix + q_tokens + a_prefix
+
+ # 3. Create input IDs
+ input_ids = torch.tensor(
+ [input_tokens + [tokenizer.eos_token_id] * prefix_length], dtype=torch.long
+ ).to(device)
+ attention_mask = torch.ones_like(input_ids)
+
+ # 4. Insert the audio projection (FIXED FOR APPLE SILICON)
+ # Adding .clone() prevents the Apple MPS silent deadlock!
+ input_embeds = model.llm.get_input_embeddings()(input_ids).clone()
+ input_embeds[
+ 0, len(q_prefix + q_tokens) : len(q_prefix + q_tokens) + prefix_length
+ ] = projected_prefix[0]
+
+ # 5. Generate the response
+ print(">>> Firing up the LLaMA generation engine... (this should take < 2 mins)")
+ with torch.no_grad():
+ output_ids = model.llm.generate(
+ inputs_embeds=input_embeds,
+ attention_mask=attention_mask,
+ max_new_tokens=20,
+ do_sample=False,
+ use_cache=False, # FIXED: KV Cache can freeze when using custom inputs_embeds on Mac
+ pad_token_id=tokenizer.eos_token_id, # Silences the warning you got earlier
+ )
+ print(">>> Generation complete!")
+
+ answer = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
+ return answer
+
+
+def get_tokenizer(model_dir: str):
+ """Loads the LLaMA tokenizer from a local directory."""
+ print("Loading local LLaMA Tokenizer...")
+ return AutoTokenizer.from_pretrained(model_dir)
+
+
+def get_base_llm(model_dir: str):
+ """Loads the base LLaMA model with 4-bit quantization."""
+ from transformers import AutoModelForCausalLM, QuantoConfig
+
+ print("Loading Llama 3.2 3B in 4-bit mode...")
+ quant_config = QuantoConfig(weights="int8")
+ llm_model = AutoModelForCausalLM.from_pretrained(
+ model_dir,
+ quantization_config=quant_config,
+ # device_map="mps" # Uncomment to force Apple Silicon GPU allocation if supported
+ )
+ return llm_model
+
+
+def get_careqa_model(careqa_path: str, llama_dir: str):
+ """Loads the custom AudioQAModel and assigns it to the correct device."""
+ # This calls your previously defined load_careqa_model_local method
+ model, device = load_careqa_model_local(
+ local_model_path=careqa_path, llm_type=llama_dir
+ )
+ return model, device
+
+
+# ==========================================
+# 4. GET ANSWER WRAPPER (Add this above __main__)
+# ==========================================
+def get_answer(question: str, audio_path: str, model, tokenizer, device) -> str:
+ print("\n" + "-" * 50)
+ print(f"Question: {question}")
+ print("-" * 50)
+
+ audio_tensor = preprocess_audio(audio_path, device)
+
+ print("Generating answer...")
+ answer = generate_answer(model, tokenizer, audio_tensor, question, device)
+
+ print("-" * 50)
+ print(f"Answer: {answer}")
+ print("-" * 50)
+
+ return answer
+
+
+def print_stats(example_dataset: CaReSoundDataset):
+ print("\n" + "=" * 50)
+ print("DATASET STATS")
+ print("=" * 50)
+ example_dataset.stats()
+
+
+def print_sample_i(sample_dataset: CaReSoundAQA, i: int):
+ print("\n" + "=" * 50)
+ print(f"Sample at index {i}")
+ print("=" * 50)
+ print(sample_dataset[i])
+ print(sample_dataset[i]["audio_path"])
+
+
+def get_raw_audio(sample_dataset: CaReSoundAQA, i, sr=16000):
+ audio_path = sample_dataset[i]["audio_path"]
+
+ # 2. Verify the file exists before trying to load it
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Could not find audio file at: {audio_path}")
+
+ # 3. Load the raw audio
+ # sr=16000 ensures it matches the sample rate used in medical audio tasks
+ audio, sample_rate = librosa.load(audio_path, sr=sr)
+
+ return audio, sample_rate
+
+
+if __name__ == "__main__":
+
+ example_dataset = CaReSoundDataset(root=path_to_dataset)
+ sample_dataset = example_dataset.set_task(CaReSoundAQA())
+ # testdataset()
+
+ # test dataset
+ print_stats(example_dataset)
+ print_sample_i(sample_dataset, 10)
+ get_raw_audio(sample_dataset, 10)
+
+ if generateModelAnswer:
+ # 2. Extract Data for Inference
+ test_question = sample_dataset[10]["question"]
+ target_audio_path = sample_dataset[10]["audio_path"] # Renamed for clarity
+ ground_truth = sample_dataset[10]["answer"]
+
+ # 3. Load Models
+ tokenizer = get_tokenizer(local_llama_dir)
+ careqa_model, device = get_careqa_model(local_careqa_path, local_llama_dir)
+
+ # 4. Run Inference (Fixed the variable name mismatch here)
+ generated_answer = get_answer(
+ question=test_question,
+ audio_path=target_audio_path,
+ model=careqa_model,
+ tokenizer=tokenizer,
+ device=device,
+ )
+
+ print(f"\nGround Truth was: {ground_truth}")
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 50b1b3887..d2e51ecf6 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -48,6 +48,7 @@ def __init__(self, *args, **kwargs):
from .base_dataset import BaseDataset
from .cardiology import CardiologyDataset
+from .caresound import CaReSoundDataset
from .chestxray14 import ChestXray14Dataset
from .clinvar import ClinVarDataset
from .cosmic import COSMICDataset
diff --git a/pyhealth/datasets/caresound.py b/pyhealth/datasets/caresound.py
new file mode 100644
index 000000000..e04021230
--- /dev/null
+++ b/pyhealth/datasets/caresound.py
@@ -0,0 +1,248 @@
+"""CaReSound dataset for PyHealth.
+
+This module provides the CaReSoundDataset class for loading and processing
+the CaReSound benchmark data for Audio Question Answering (AQA) tasks.
+"""
+
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional, Dict, Any
+
+import pandas as pd
+from .base_dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+class CaReSoundDataset(BaseDataset):
+ """CaReSound dataset for open-ended diagnostic reasoning.
+
+ This dataset aggregates five medical audio sources: ICBHI, KAUH,
+ CirCor, SPRSound, and ZCHSound. It pairs respiratory and cardiac
+ audio with 34,792 GPT-4o generated Question-Answer pairs.
+
+ Args:
+ root: Root directory containing the audio files (.wav) and/or CSVs.
+ tables: Optional list of tables to load. Defaults to ["metadata"].
+ dataset_name: Optional name of the dataset. Defaults to "caresound".
+ config_path: Optional path to the configuration file.
+
+ Example:
+ >>> from pyhealth.datasets import CaReSoundDataset
+ >>> example_dataset = CaReSoundDataset(root="/Users/rahuld/Downloads/CaReAQA/datasets")
+ >>> example_dataset.stats()
+ """
+
+ def __init__(
+ self,
+ root: str,
+ tables: List[str] = None,
+ dataset_name: Optional[str] = None,
+ config_path: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ if config_path is None:
+ logger.info("No config path provided, using default config")
+ config_path = Path(__file__).parent / "configs" / "caresound.yaml"
+
+ # 1. Prepare standardized CSV (handles all local/API edge cases)
+ pyhealth_csv = os.path.join(root, "caresound_metadata.csv")
+ if not os.path.exists(pyhealth_csv):
+ self.prepare_metadata(root)
+
+ # 2. Resolve local audio paths dynamically
+ self.audio_path_map = self._resolve_audio_paths(root)
+
+ # 3. Define the default table mapped in the YAML
+ default_tables = ["metadata"]
+ tables = default_tables + (tables or [])
+
+ super().__init__(
+ root=root,
+ tables=tables,
+ dataset_name=dataset_name or "caresound",
+ config_path=config_path,
+ **kwargs,
+ )
+
+ @staticmethod
+ def prepare_metadata(root: str) -> None:
+ """Prepares QA metadata from local ZIP/CSV drops or downloads via HF API."""
+ output_path = os.path.join(root, "caresound_metadata.csv")
+
+ train_csv = os.path.join(root, "CaReSoundQA_train.csv")
+ test_csv = os.path.join(root, "CaReSoundQA_test.csv")
+ full_csv = os.path.join(root, "CaReSoundQA.csv")
+
+ df_master = None
+
+ # Scenario A: User manually downloaded both Train and Test CSVs
+ if os.path.exists(train_csv) and os.path.exists(test_csv):
+ logger.info("Found local train/test CSVs. Merging...")
+ df_train, df_test = pd.read_csv(train_csv), pd.read_csv(test_csv)
+ df_train["hf_split"], df_test["hf_split"] = "train", "test"
+ df_master = pd.concat([df_train, df_test], ignore_index=True)
+
+ # Scenario B: User manually downloaded ONLY Train CSV
+ elif os.path.exists(train_csv):
+ logger.warning(
+ "Found local train CSV, but test is missing. Using train only."
+ )
+ df_master = pd.read_csv(train_csv)
+ df_master["hf_split"] = "train"
+
+ # Scenario C: User manually downloaded the Master CSV
+ elif os.path.exists(full_csv):
+ logger.info(f"Found master CSV: {full_csv}.")
+ df_master = pd.read_csv(full_csv)
+ if "hf_split" not in df_master.columns:
+ df_master["hf_split"] = "unknown"
+
+ # Scenario D: Fallback to Hugging Face API
+ else:
+ try:
+ from datasets import load_dataset
+
+ logger.info(
+ "Local metadata not found. Fetching from tsnngw/CaReSound..."
+ )
+ dataset = load_dataset("tsnngw/CaReSound")
+
+ df_train = dataset["train"].to_pandas()
+ df_train["hf_split"] = "train"
+
+ # Catch in case the dataset structure changes on HF
+ if "test" in dataset:
+ df_test = dataset["test"].to_pandas()
+ df_test["hf_split"] = "test"
+ df_master = pd.concat([df_train, df_test], ignore_index=True)
+ else:
+ df_master = df_train
+
+ except ImportError:
+ logger.error(
+ "The 'datasets' library is required. Run: pip install datasets"
+ )
+ raise
+ except Exception as e:
+ logger.error(f"Failed to fetch metadata: {e}")
+ raise
+
+ # ---> MINIMAL FIX: Inject audio paths right before saving <---
+ audio_map = {}
+ for path in Path(root).rglob("*.wav"):
+ stem, path_str = path.stem, str(path).lower()
+ source = "Unknown"
+ if "icbhi" in path_str:
+ source = "ICBHI"
+ elif "circor" in path_str:
+ source = "CirCor"
+ elif "kauh" in path_str:
+ source = "KAUH"
+ elif "spr" in path_str:
+ source = "SPRSound"
+ elif "zch" in path_str:
+ source = "ZCHSound"
+
+ audio_map[(source, stem)] = str(path.absolute())
+ audio_map[(source, stem.split("_")[0])] = str(path.absolute())
+
+ df_master["audio_path"] = df_master.apply(
+ lambda r: audio_map.get(
+ (str(r.get("dataset", "Unknown")), str(r.get("patient_id", ""))), ""
+ ),
+ axis=1,
+ )
+
+ # Save the final CSV for the new PyHealth Engine to pick up automatically
+ df_master.to_csv(output_path, index=False)
+ logger.info(
+ f"Saved {len(df_master)} QA pairs with mapped audio to {output_path}"
+ )
+
+ def _resolve_audio_paths(self, root: str) -> Dict[tuple, str]:
+ """Maps .wav files using robust stem and prefix mapping."""
+ audio_map = {}
+ wav_files = list(Path(root).rglob("*.wav"))
+
+ if not wav_files:
+ logger.warning(f"No .wav files found in {root}.")
+ return audio_map
+
+ for path in wav_files:
+ stem = path.stem
+ path_str = str(path).lower()
+
+ source = "Unknown"
+ if "icbhi" in path_str:
+ source = "ICBHI"
+ elif "circor" in path_str:
+ source = "CirCor"
+ elif "kauh" in path_str:
+ source = "KAUH"
+ elif "spr" in path_str:
+ source = "SPRSound"
+ elif "zch" in path_str:
+ source = "ZCHSound"
+
+ # 1. Primary Mapping: Exact filename match
+ audio_map[(source, stem)] = str(path.absolute())
+
+ # 2. Fallback Mapping: Base Patient ID match (e.g., '101' from '101_1b1')
+ base_id = stem.split("_")[0]
+ if (source, base_id) not in audio_map:
+ audio_map[(source, base_id)] = str(path.absolute())
+
+ return audio_map
+
+ def parse_func(self) -> Dict[str, Any]:
+ """Merges tabular QA metadata with local audio paths."""
+ csv_path = os.path.join(self.root, "caresound_metadata.csv")
+ df = pd.read_csv(csv_path)
+
+ patients = {}
+ missing_sources = set()
+
+ for _, row in df.iterrows():
+ pid = str(row["patient_id"])
+ source = str(row["dataset"])
+
+ audio_path = self.audio_path_map.get((source, pid))
+
+ if not audio_path:
+ missing_sources.add(source)
+ continue
+
+ if pid not in patients:
+ patients[pid] = {"patient_id": pid, "visits": {}}
+
+ visit_id = f"{source}_{pid}"
+ if visit_id not in patients[pid]["visits"]:
+ patients[pid]["visits"][visit_id] = {
+ "visit_id": visit_id,
+ "audio_path": audio_path,
+ "events": [],
+ }
+
+ patients[pid]["visits"][visit_id]["events"].append(
+ {
+ "question": row.get("question", ""),
+ "answer": row.get("answer", ""),
+ "hf_split": row.get("hf_split", "unknown"),
+ }
+ )
+
+ if missing_sources:
+ logger.warning(
+ f"Audio files missing for datasets: {', '.join(missing_sources)}. "
+ "Only available multi-modal samples have been loaded."
+ )
+
+ return patients
+
+ @property
+ def default_task(self):
+ from pyhealth.tasks import CaReSoundAQA
+
+ return CaReSoundAQA()
diff --git a/pyhealth/datasets/configs/caresound.yaml b/pyhealth/datasets/configs/caresound.yaml
new file mode 100644
index 000000000..081b9b2f5
--- /dev/null
+++ b/pyhealth/datasets/configs/caresound.yaml
@@ -0,0 +1,12 @@
+version: "1.0"
+tables:
+ metadata:
+ file_path: "caresound_metadata.csv"
+ patient_id: "patient_id"
+ timestamp: null
+ attributes:
+ - "dataset"
+ - "question"
+ - "answer"
+ - "hf_split"
+ - "audio_path"
\ No newline at end of file
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index a32618f9c..9ba8173a2 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -1,6 +1,7 @@
from .base_task import BaseTask
from .benchmark_ehrshot import BenchmarkEHRShot
from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction
+from .caresound_tasks import CaReSoundAQA
from .bmd_hs_disease_classification import BMDHSDiseaseClassification
from .cardiology_detect import (
cardiology_isAD_fn,
diff --git a/pyhealth/tasks/caresound_tasks.py b/pyhealth/tasks/caresound_tasks.py
new file mode 100644
index 000000000..80666901a
--- /dev/null
+++ b/pyhealth/tasks/caresound_tasks.py
@@ -0,0 +1,66 @@
+"""Audio Question Answering tasks for PyHealth.
+
+This module provides tasks for processing generative text answers
+based on medical audio signals using the CaReSound dataset.
+"""
+
+from typing import Any, Dict, List
+from .base_task import BaseTask
+
+
+class CaReSoundAQA(BaseTask):
+ """Task for Audio Question Answering on respiratory and cardiac sounds.
+
+ Attributes:
+ task_name (str): The name of the task.
+ input_schema (Dict[str, str]): Required inputs (audio_path, question).
+ output_schema (Dict[str, str]): Required outputs (answer).
+
+ Example:
+ >>> from pyhealth.datasets import CaReSoundDataset
+ >>> from pyhealth.tasks import CaReSoundAQA
+ >>> example_dataset = CaReSoundDataset(root="/Users/rahuld/Downloads/CaReAQA/datasets")
+ >>> sample_dataset = example_dataset.set_task(CaReSoundAQA())
+ """
+
+ task_name: str = "CaReSoundAQA"
+ input_schema: Dict[str, str] = {
+ "question": "text",
+ }
+ output_schema: Dict[str, str] = {"answer": "text"}
+
+ def _safe_str(self, value: Any, default: str = "") -> str:
+ """Safely convert value to string, handling None and NaN."""
+ if value is None or str(value).lower() == "nan":
+ return default
+ return str(value)
+
+    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
+        """Process a patient record to extract audio-QA samples."""
+        samples: List[Dict[str, Any]] = []
+        events = patient.get_events("metadata")
+
+        for event in events:
+            # The new engine puts CSV columns into attr_dict
+            attr = getattr(event, "attr_dict", {})
+
+            # _safe_str guards against None/NaN cells becoming "nan"/"None"
+            audio_path = self._safe_str(attr.get("audio_path"))
+            question = self._safe_str(attr.get("question"))
+            answer = self._safe_str(attr.get("answer"))
+            hf_split = self._safe_str(attr.get("hf_split"), "unknown")
+
+            if not audio_path or not question or not answer:
+                continue
+
+            samples.append(
+                {
+                    "patient_id": patient.patient_id,
+                    "visit_id": f"v_{patient.patient_id}",
+                    "audio_path": audio_path,
+                    "question": question,
+                    "answer": answer,
+                    "original_hf_split": hf_split,
+                }
+            )
+
+        return samples
diff --git a/tests/core/test_caresound.py b/tests/core/test_caresound.py
new file mode 100644
index 000000000..23d21e359
--- /dev/null
+++ b/tests/core/test_caresound.py
@@ -0,0 +1,69 @@
+import unittest
+import tempfile
+import pandas as pd
+from pathlib import Path
+from pyhealth.datasets import CaReSoundDataset
+from pyhealth.tasks import CaReSoundAQA
+
+
+class TestCaReSoundDataset(unittest.TestCase):
+ """Test cases for the CaReSoundDataset using synthetic data."""
+
+ def setUp(self):
+ """Create a temporary directory with synthetic data."""
+ self.test_dir = tempfile.TemporaryDirectory()
+ self.root = Path(self.test_dir.name)
+
+ # 1. Create synthetic CSV (2 patients)
+ self.df = pd.DataFrame(
+ {
+ "patient_id": ["101", "102"],
+ "dataset": ["icbhi", "circor"],
+ "question": ["Is this normal?", "Any abnormalities?"],
+ "answer": ["Normal", "Abnormal"],
+ "hf_split": ["train", "test"],
+ "metadata/audio_path": ["icbhi_101.wav", "circor_102.wav"],
+ }
+ )
+ self.df.to_csv(self.root / "caresound_metadata.csv", index=False)
+
+ # 2. Create dummy audio files (just empty files are enough for path matching)
+ (self.root / "icbhi_101.wav").touch()
+ (self.root / "circor_102.wav").touch()
+
+ def tearDown(self):
+ """Cleanup temporary directory."""
+ self.test_dir.cleanup()
+
+ def test_dataset_initialization(self):
+ dataset = CaReSoundDataset(root=str(self.root))
+ self.assertIsNotNone(dataset)
+ self.assertEqual(dataset.dataset_name, "caresound")
+
+ def test_default_task(self):
+ dataset = CaReSoundDataset(root=str(self.root))
+ self.assertIsInstance(dataset.default_task, CaReSoundAQA)
+
+ def test_set_task(self):
+ dataset = CaReSoundDataset(root=str(self.root))
+ task = CaReSoundAQA()
+ samples = dataset.set_task(task)
+
+ # We expect 2 samples since we created 2 patients
+ self.assertEqual(len(samples), 2)
+ self.assertEqual(samples[0]["patient_id"], "101")
+
+
+class TestCaReSoundAQA(unittest.TestCase):
+ """Test cases for the CaReSoundAQA task schema."""
+
+ def setUp(self):
+ self.task = CaReSoundAQA()
+
+ def test_task_attributes(self):
+ self.assertEqual(self.task.task_name, "CaReSoundAQA")
+ self.assertIn("question", self.task.input_schema)
+
+
+if __name__ == "__main__":
+ unittest.main()