diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..e75b9dfef 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -233,6 +233,7 @@ Available Datasets datasets/pyhealth.datasets.DREAMTDataset datasets/pyhealth.datasets.SHHSDataset datasets/pyhealth.datasets.SleepEDFDataset + datasets/pyhealth.datasets.DSADataset datasets/pyhealth.datasets.EHRShotDataset datasets/pyhealth.datasets.Support2Dataset datasets/pyhealth.datasets.BMDHSDataset diff --git a/docs/api/datasets/pyhealth.datasets.DSADataset.rst b/docs/api/datasets/pyhealth.datasets.DSADataset.rst new file mode 100644 index 000000000..8b69dc25a --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.DSADataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.DSADataset +=================================== + +The Daily and Sports Activities (DSA) dataset. + +Each of the 19 activities is performed by eight subjects (4 female, 4 male, between the ages 20 and 30) for 5 minutes. + +.. autoclass:: pyhealth.datasets.DSADataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..f476ab546 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -203,4 +203,5 @@ API Reference models/pyhealth.models.VisionEmbeddingModel models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT + models/pyhealth.models.AdaptiveTransferModel models/pyhealth.models.unified_multimodal_embedding_docs diff --git a/docs/api/models/pyhealth.models.AdaptiveTransferModel.rst b/docs/api/models/pyhealth.models.AdaptiveTransferModel.rst new file mode 100644 index 000000000..f5ba02a62 --- /dev/null +++ b/docs/api/models/pyhealth.models.AdaptiveTransferModel.rst @@ -0,0 +1,11 @@ +pyhealth.models.AdaptiveTransferModel +=================================== + +Adaptive transfer model for multi-source time-series classification. + +This model is inspired by "Daily Physical Activity Monitoring: Adaptive Learning from Multi-Source Motion Sensor Data". + +.. autoclass:: pyhealth.models.AdaptiveTransferModel + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..b3a1c16d6 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -206,6 +206,7 @@ Available Tasks :maxdepth: 3 Base Task + DSA Activity Classification In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding Cardiology Detection diff --git a/docs/api/tasks/pyhealth.tasks.DSAActivityClassification.rst b/docs/api/tasks/pyhealth.tasks.DSAActivityClassification.rst new file mode 100644 index 000000000..0c4990061 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DSAActivityClassification.rst @@ -0,0 +1,11 @@ +pyhealth.tasks.DSAActivityClassification +=================================== + +Multi-class activity-recognition task for the Daily and Sports Activities (DSA) dataset. + +Each of the 19 activities is performed by eight subjects (4 female, 4 male, between the ages 20 and 30) for 5 minutes. + +.. autoclass:: pyhealth.tasks.DSAActivityClassification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/dsa_adaptive_transfer_replication.ipynb b/examples/dsa_adaptive_transfer_replication.ipynb new file mode 100644 index 000000000..26e0a566b --- /dev/null +++ b/examples/dsa_adaptive_transfer_replication.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Adaptive Transfer Learning for Daily Physical Activity Monitoring\n", + "\n", + "This notebook reproduces the core transfer learning framework and planned ablations/extensions from the paper:\n", + "*Daily Physical Activity Monitoring: Adaptive Learning from Multi-Source Motion Sensor Data*.\n", + "\n", + "The framework consists of:\n", + "1. **Domain Similarity Computation:** IPD between target and source domains.\n", + "2. **Adaptive Pre-training:** Pre-train on source domains with similarity-weighted learning rates.\n", + "3. **Fine-tuning:** Fine-tune on the target domain.\n", + "\n", + "We include the original paper experiment reproduction (binary classification task), followed by ablations on paired vs. unpaired similarity, adaptive vs. fixed learning rates, noise robustness, and an extension exploring different distance metrics (using multiclass classification).\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "!pip install pyts==0.13.0 seaborn==0.13.2\n", + "!mkdir -p DSA\n", + "!curl -o DSA/daily+and+sports+activities.zip https://archive.ics.uci.edu/static/public/256/daily+and+sports+activities.zip\n", + "!unzip -q DSA/daily+and+sports+activities.zip -d DSA\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import os\n", + "import random\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from pyts.metrics import dtw\n", + "from pyhealth.datasets import DSADataset, get_dataloader, split_by_patient, create_sample_dataset\n", + "from pyhealth.models.adaptive_transfer import AdaptiveTransferModel\n", + "from pyhealth.tasks import DSAActivityClassification\n", + "from pyhealth.trainer import Trainer\n", + "from pyhealth.metrics import multiclass_metrics_fn\n", + "\n", + "# Configuration\n", + "SEED = 598\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", + "DSA_ROOT = \"DSA/data\"\n", + "TARGET_UNIT = \"LL\" # Target domain: Left Leg\n", + "ALL_UNITS = (\"T\", \"RA\", \"LA\", \"RL\", \"LL\")\n", + "SOURCE_UNITS = [u for u in ALL_UNITS if u != TARGET_UNIT]\n", + "\n", + "BATCH_SIZE = 64\n", + "EPOCHS_PRETRAIN = 10\n", + "EPOCHS_FINETUNE = 20\n", + "BASE_LR = 1e-3\n", + "\n", + "plt.rcParams.update({\"figure.figsize\": (8, 4), \"axes.grid\": True, \"grid.alpha\": 0.3})\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Data Loading\n", + "We load the dataset for each sensor unit and create aligned train/val/test splits to maintain the paired structure across domains.\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "base_dsa = DSADataset(root=DSA_ROOT, num_workers=1)\n", + "\n", + "template_task = DSAActivityClassification(dataset_root=DSA_ROOT, selected_units=(ALL_UNITS[0],))\n", + "template_full = base_dsa.set_task(template_task, num_workers=1)\n", + "train_ref, val_ref, test_ref = split_by_patient(template_full, [0.5, 0.25, 0.25], seed=SEED)\n", + "\n", + "patient_splits = {\n", + " \"train\": set(train_ref.patient_to_index),\n", + " \"val\": set(val_ref.patient_to_index),\n", + " \"test\": set(test_ref.patient_to_index),\n", + "}\n", + "\n", + "bundles = {}\n", + "for unit in ALL_UNITS:\n", + " full = base_dsa.set_task(\n", + " DSAActivityClassification(dataset_root=DSA_ROOT, selected_units=(unit,)),\n", + " num_workers=1,\n", + " )\n", + " bundles[unit] = {\n", + " \"train\": full.subset([idx for pid in patient_splits[\"train\"] for idx in full.patient_to_index[pid]]),\n", + " \"val\": full.subset([idx for pid in patient_splits[\"val\"] for idx in full.patient_to_index[pid]]),\n", + " \"test\": full.subset([idx for pid in patient_splits[\"test\"] for idx in full.patient_to_index[pid]]),\n", + " }\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Core Framework Helper Functions\n", + "Define the DTW distance function, IPD computation, and a generic training routine.\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "def dtw_distance_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", + " x_np, y_np = x.detach().cpu().numpy(), y.detach().cpu().numpy()\n", + " if x_np.ndim == 1: x_np, y_np = x_np[None, :], y_np[None, :]\n", + " vals = [dtw(np.ravel(a), np.ravel(b)) for a, b in zip(x_np, y_np)]\n", + " return torch.tensor(vals, dtype=x.dtype, device=x.device)\n", + "\n", + "def compute_mean_ipd(src_ds, tgt_ds, distance_fn=dtw_distance_fn, shuffle_target=False):\n", + " model = AdaptiveTransferModel(\n", + " dataset=bundles[TARGET_UNIT][\"train\"], feature_key=\"signal\", backbone=\"lstm\",\n", + " distance_fn=distance_fn, use_kde_smoothing=True\n", + " ).to(DEVICE)\n", + " model.eval()\n", + " \n", + " src_loader = get_dataloader(src_ds, batch_size=BATCH_SIZE, shuffle=False)\n", + " tgt_loader = get_dataloader(tgt_ds, batch_size=BATCH_SIZE, shuffle=shuffle_target)\n", + " \n", + " vals = [model.compute_ipd(s, t) for s, t in zip(src_loader, tgt_loader)]\n", + " return float(np.mean(vals))\n", + "\n", + "def train_and_evaluate(source_order, ipd_scores, data_bundles, use_adaptive_lr=True):\n", + " model = AdaptiveTransferModel(\n", + " dataset=data_bundles[TARGET_UNIT][\"train\"],\n", + " feature_key=\"signal\",\n", + " backbone=\"lstm\",\n", + " use_similarity_weighting=use_adaptive_lr\n", + " ).to(DEVICE)\n", + " \n", + " def train_step(train_ds, val_ds, epochs, lr):\n", + " (Trainer(model=model,\n", + " device=DEVICE,\n", + " metrics=[\"accuracy\"], enable_logging=False)\n", + " .train(train_dataloader=get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True),\n", + " val_dataloader=get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False),\n", + " epochs=epochs,\n", + " optimizer_params={\"lr\": lr},\n", + " monitor=\"accuracy\"\n", + " ))\n", + "\n", + " for src in source_order:\n", + " lr = model.get_adaptive_lr(BASE_LR, 1.0 / (ipd_scores[src] + 1e-8)) if use_adaptive_lr else BASE_LR\n", + " train_step(data_bundles[src][\"train\"], data_bundles[src][\"val\"], EPOCHS_PRETRAIN, lr)\n", + " \n", + " train_step(data_bundles[TARGET_UNIT][\"train\"], data_bundles[TARGET_UNIT][\"val\"], EPOCHS_FINETUNE, BASE_LR)\n", + " \n", + " trainer = Trainer(model=model, device=DEVICE, metrics=[\"accuracy\"], enable_logging=False)\n", + " return trainer.evaluate(get_dataloader(data_bundles[TARGET_UNIT][\"test\"], batch_size=BATCH_SIZE, shuffle=False))[\"accuracy\"]\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Original Paper Experiment Reproduction (Binary Task)\n", + "**What is happening in this experiment:**\n", + "The original paper evaluates the framework on a binary classification task (one correct activity vs. rest) to calculate the Ratio of Correct Classification (RCC). We create a binary version of the dataset for this experiment (Activity 0 vs. Rest). Since the dataset is highly imbalanced, we upsample the positive samples during training and validation to match the negative samples, and downsample the negative samples during testing to ensure balanced evaluation.\n", + "\n", + "We compare three settings:\n", + "1. **No Transfer:** Training only on the target domain.\n", + "2. **Direct Transfer:** Sequential pre-training on source domains without similarity weighting, followed by fine-tuning.\n", + "3. **Adaptive IPD Transfer (Proposed):** Sequential pre-training ordered by IPD with similarity-weighted learning rates, followed by fine-tuning.\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Create binary datasets (Activity 0 vs Rest)\n", + "def make_binary_bundle(bundle, pos_class=0):\n", + " bin_bundle = {}\n", + " import random\n", + " for split in [\"train\", \"val\", \"test\"]:\n", + " samples = []\n", + " for s in bundle[split]:\n", + " s_new = dict(s)\n", + " y = int(s[\"label\"].item()) if hasattr(s[\"label\"], \"item\") else int(s[\"label\"])\n", + " s_new[\"label\"] = 1 if y == pos_class else 0\n", + " samples.append(s_new)\n", + " pos_samples = [s for s in samples if s[\"label\"] == 1]\n", + " neg_samples = [s for s in samples if s[\"label\"] == 0]\n", + " if len(pos_samples) > 0 and len(neg_samples) > 0:\n", + " if split in [\"train\", \"val\"]:\n", + " num_to_add = len(neg_samples) - len(pos_samples)\n", + " if num_to_add > 0:\n", + " pos_samples += random.choices(pos_samples, k=num_to_add)\n", + " elif split == \"test\":\n", + " if len(neg_samples) > len(pos_samples):\n", + " neg_samples = random.sample(neg_samples, len(pos_samples))\n", + " samples = pos_samples + neg_samples\n", + " random.shuffle(samples)\n", + " bin_bundle[split] = create_sample_dataset(samples, {\"signal\": \"tensor\"}, {\"label\": \"binary\"}, f\"bin_{split}\")\n", + " return bin_bundle\n", + "\n", + "binary_bundles = {unit: make_binary_bundle(bundles[unit], pos_class=0) for unit in ALL_UNITS}\n", + "\n", + "# Compute Paired IPD for Adaptive Transfer (using binary validation sets)\n", + "paired_ipd_bin = {src: compute_mean_ipd(binary_bundles[src][\"val\"], binary_bundles[TARGET_UNIT][\"val\"], shuffle_target=False) for src in SOURCE_UNITS}\n", + "paired_order_bin = sorted(SOURCE_UNITS, key=paired_ipd_bin.get)\n", + "\n", + "# 1. No Transfer\n", + "no_transfer_model_bin = AdaptiveTransferModel(dataset=binary_bundles[TARGET_UNIT][\"train\"],\n", + " feature_key=\"signal\",\n", + " backbone=\"lstm\").to(DEVICE)\n", + "(Trainer(model=no_transfer_model_bin,\n", + " device=DEVICE,\n", + " metrics=[\"accuracy\"],\n", + " enable_logging=False)\n", + ".train(train_dataloader=get_dataloader(binary_bundles[TARGET_UNIT][\"train\"], batch_size=BATCH_SIZE, shuffle=True),\n", + " val_dataloader=get_dataloader(binary_bundles[TARGET_UNIT][\"val\"], batch_size=BATCH_SIZE, shuffle=False),\n", + " epochs=EPOCHS_PRETRAIN * len(SOURCE_UNITS) + EPOCHS_FINETUNE,\n", + " optimizer_params={\"lr\": BASE_LR},\n", + " monitor=\"accuracy\"\n", + "))\n", + "trainer_no_transfer_bin = Trainer(model=no_transfer_model_bin,\n", + " device=DEVICE,\n", + " metrics=[\"accuracy\"],\n", + " enable_logging=False)\n", + "acc_no_transfer_bin = trainer_no_transfer_bin.evaluate(get_dataloader(binary_bundles[TARGET_UNIT][\"test\"],\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=False))[\"accuracy\"]\n", + "\n", + "# 2. Direct Transfer (Fixed LR, default order)\n", + "acc_direct_transfer_bin = train_and_evaluate(SOURCE_UNITS, paired_ipd_bin, binary_bundles, use_adaptive_lr=False)\n", + "\n", + "# 3. Adaptive IPD Transfer (Proposed)\n", + "acc_adaptive_transfer_bin = train_and_evaluate(paired_order_bin, paired_ipd_bin, binary_bundles, use_adaptive_lr=True)\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Plotting Original Paper Experiment\n", + "plt.figure(figsize=(7, 4))\n", + "sns.barplot(x=[\"No Transfer\", \"Direct Transfer\", \"Adaptive IPD (Proposed)\"], \n", + " y=[acc_no_transfer_bin, acc_direct_transfer_bin, acc_adaptive_transfer_bin], \n", + " palette=\"Set2\")\n", + "plt.ylabel(\"Test Accuracy (RCC)\")\n", + "plt.title(\"Original Paper Experiment (Binary Task): Transfer Baselines\")\n", + "plt.ylim(0, 1.0)\n", + "for i, v in enumerate([acc_no_transfer_bin, acc_direct_transfer_bin, acc_adaptive_transfer_bin]):\n", + " plt.text(i, v + 0.02, f\"{v:.3f}\", ha='center')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Ablation 1: Paired vs. Unpaired Similarity (H1)\n", + "*Note: For the following ablations and extensions, we use the **multiclass** classification task (all 19 activities) to provide a more challenging evaluation setting.*\n", + "\n", + "**What is happening in this experiment:**\n", + "We test the hypothesis that using paired similarity (Inter-domain Pairwise Distance, IPD) is better than using unpaired similarity. IPD leverages the fact that the multi-sensor data is synchronized (i.e., collected simultaneously from different body parts during the same activity). We compare the standard paired IPD against an unpaired similarity measure, which we simulate by shuffling the target batch to break the temporal synchronization.\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Compute Paired IPD (Multiclass)\n", + "paired_ipd = {src: compute_mean_ipd(bundles[src][\"val\"], bundles[TARGET_UNIT][\"val\"], shuffle_target=False) for src in SOURCE_UNITS}\n", + "paired_order = sorted(SOURCE_UNITS, key=paired_ipd.get)\n", + "\n", + "# Compute Unpaired IPD\n", + "unpaired_ipd = {src: compute_mean_ipd(bundles[src][\"val\"], bundles[TARGET_UNIT][\"val\"], shuffle_target=True) for src in SOURCE_UNITS}\n", + "unpaired_order = sorted(SOURCE_UNITS, key=unpaired_ipd.get)\n", + "\n", + "acc_paired = train_and_evaluate(paired_order, paired_ipd, bundles, use_adaptive_lr=True)\n", + "acc_unpaired = train_and_evaluate(unpaired_order, unpaired_ipd, bundles, use_adaptive_lr=True)\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Plotting Ablation 1\n", + "plt.figure(figsize=(6, 4))\n", + "sns.barplot(x=[\"Paired IPD (Proposed)\", \"Unpaired Similarity\"], y=[acc_paired, acc_unpaired], palette=\"viridis\")\n", + "plt.ylabel(\"Test Accuracy\")\n", + "plt.title(\"Ablation 1: Paired vs. Unpaired Similarity (Multiclass)\")\n", + "plt.ylim(0, 1.0)\n", + "for i, v in enumerate([acc_paired, acc_unpaired]):\n", + " plt.text(i, v + 0.02, f\"{v:.3f}\", ha='center')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Ablation 2: Adaptive LR vs. Fixed LR (H2)\n", + "**What is happening in this experiment:**\n", + "We test the hypothesis that using adaptive learning rates (LR) based on domain similarity (IPD) is better than using a fixed learning rate for all source domains during pre-training. We compare the proposed adaptive method (where the learning rate is scaled inversely to the IPD score, so similar domains get larger learning rates) against a baseline that uses a fixed learning rate schedule across all source domains, keeping the pre-training order the same.\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "acc_fixed_lr_ipd_order = train_and_evaluate(paired_order, paired_ipd, bundles, use_adaptive_lr=False)\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Plotting Ablation 2\n", + "plt.figure(figsize=(6, 4))\n", + "sns.barplot(x=[\"Adaptive LR (Proposed)\", \"Fixed LR\"], y=[acc_paired, acc_fixed_lr_ipd_order], palette=\"magma\")\n", + "plt.ylabel(\"Test Accuracy\")\n", + "plt.title(\"Ablation 2: Adaptive vs. Fixed Learning Rate (Multiclass)\")\n", + "plt.ylim(0, 1.0)\n", + "for i, v in enumerate([acc_paired, acc_fixed_lr_ipd_order]):\n", + " plt.text(i, v + 0.02, f\"{v:.3f}\", ha='center')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Ablation 3: Robustness to Noise (H3)\n", + "**What is happening in this experiment:**\n", + "We test the hypothesis that the IPD-based adaptive transfer framework is more robust to noisy source domains. We artificially inject Gaussian noise into one of the source domains (e.g., the 'T' sensor) and evaluate how much the downstream performance degrades. The adaptive IPD method should naturally assign a lower similarity (higher IPD) and thus a lower learning rate to the noisy domain, mitigating its negative impact on the pre-trained model compared to the fixed LR baseline.\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "def evaluate_with_noise(model, test_ds, noise_stds):\n", + " model.eval()\n", + " test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False)\n", + " results = []\n", + " \n", + " for std in noise_stds:\n", + " ys, ps = [], []\n", + " for data in test_loader:\n", + " data = {k: (v.to(DEVICE) if isinstance(v, torch.Tensor) else v) for k, v in data.items()}\n", + " if std > 0:\n", + " data[\"signal\"] = data[\"signal\"] + torch.randn_like(data[\"signal\"]) * std\n", + " with torch.no_grad():\n", + " out = model(**data)\n", + " ys.append(out[\"y_true\"].cpu().numpy())\n", + " ps.append(out[\"y_prob\"].cpu().numpy())\n", + " acc = multiclass_metrics_fn(np.concatenate(ys), np.concatenate(ps), metrics=[\"accuracy\"])[\"accuracy\"]\n", + " results.append(acc)\n", + " return results\n", + "\n", + "# Train a No-Transfer baseline for comparison\n", + "no_transfer_model = AdaptiveTransferModel(dataset=bundles[TARGET_UNIT][\"train\"], feature_key=\"signal\", backbone=\"lstm\").to(DEVICE)\n", + "(Trainer(model=no_transfer_model,\n", + " device=DEVICE,\n", + " metrics=[\"accuracy\"],\n", + " enable_logging=False)\n", + ".train(train_dataloader=get_dataloader(bundles[TARGET_UNIT][\"train\"], batch_size=BATCH_SIZE, shuffle=True),\n", + " val_dataloader=get_dataloader(bundles[TARGET_UNIT][\"val\"], batch_size=BATCH_SIZE, shuffle=False),\n", + " epochs=EPOCHS_PRETRAIN * len(SOURCE_UNITS) + EPOCHS_FINETUNE,\n", + " optimizer_params={\"lr\": BASE_LR},\n", + " monitor=\"accuracy\"\n", + "))\n", + "\n", + "# Re-train adaptive model to evaluate\n", + "adaptive_model = AdaptiveTransferModel(dataset=bundles[TARGET_UNIT][\"train\"], feature_key=\"signal\", backbone=\"lstm\", use_similarity_weighting=True).to(DEVICE)\n", + "for src in paired_order:\n", + " (Trainer(model=adaptive_model,\n", + " device=DEVICE,\n", + " metrics=[\"accuracy\"],\n", + " enable_logging=False)\n", + " .train(train_dataloader=get_dataloader(bundles[src][\"train\"], batch_size=BATCH_SIZE, shuffle=True),\n", + " val_dataloader=get_dataloader(bundles[src][\"val\"], batch_size=BATCH_SIZE, shuffle=False),\n", + " epochs=EPOCHS_PRETRAIN,\n", + " optimizer_params={\"lr\": adaptive_model.get_adaptive_lr(BASE_LR, 1.0 / (paired_ipd[src] + 1e-8))},\n", + " monitor=\"accuracy\"\n", + " ))\n", + "(Trainer(model=adaptive_model,\n", + " device=DEVICE,\n", + " metrics=[\"accuracy\"],\n", + " enable_logging=False)\n", + ".train(train_dataloader=get_dataloader(bundles[TARGET_UNIT][\"train\"], batch_size=BATCH_SIZE, shuffle=True),\n", + " val_dataloader=get_dataloader(bundles[TARGET_UNIT][\"val\"], batch_size=BATCH_SIZE, shuffle=False),\n", + " epochs=EPOCHS_FINETUNE,\n", + " optimizer_params={\"lr\": BASE_LR}, monitor=\"accuracy\"\n", + "))\n", + "\n", + "noise_levels = [0.0, 0.05, 0.1, 0.2, 0.3]\n", + "acc_noise_no_transfer = evaluate_with_noise(no_transfer_model, bundles[TARGET_UNIT][\"test\"], noise_levels)\n", + "acc_noise_adaptive = evaluate_with_noise(adaptive_model, bundles[TARGET_UNIT][\"test\"], noise_levels)\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Plotting Ablation 3\n", + "plt.figure(figsize=(8, 5))\n", + "plt.plot(noise_levels, acc_noise_no_transfer, marker='o', label=\"No Transfer\")\n", + "plt.plot(noise_levels, acc_noise_adaptive, marker='s', label=\"Adaptive Transfer (Proposed)\")\n", + "plt.xlabel(\"Gaussian Noise Std Dev\")\n", + "plt.ylabel(\"Test Accuracy\")\n", + "plt.title(\"Ablation 3: Robustness to Input Noise (Multiclass)\")\n", + "plt.legend()\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Extension: Distance Metrics Comparison\n", + "**What is happening in this experiment:**\n", + "We explore how different distance metrics affect the IPD calculation and the downstream transfer performance. The original paper uses Dynamic Time Warping (DTW), but we compare it against other common metrics like Euclidean distance, Manhattan (L1) distance, and Cosine similarity. We visualize the IPD values computed by each metric and train models using the adaptive transfer framework based on those different IPD scores.\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "distance_metrics = {\n", + " \"DTW\": dtw_distance_fn,\n", + " \"Euclidean\": \"euclidean\",\n", + " \"Minkowski (p=3)\": lambda x, y: F.pairwise_distance(x, y, p=3)\n", + "}\n", + "\n", + "metric_results = {}\n", + "ipd_heatmap_data = np.zeros((len(distance_metrics), len(SOURCE_UNITS)))\n", + "\n", + "for i, (name, dist_fn) in enumerate(distance_metrics.items()):\n", + " ipd_vals = {src: compute_mean_ipd(bundles[src][\"val\"], bundles[TARGET_UNIT][\"val\"], distance_fn=dist_fn) for src in SOURCE_UNITS}\n", + " order = sorted(SOURCE_UNITS, key=ipd_vals.get)\n", + " acc = train_and_evaluate(order, ipd_vals, bundles, use_adaptive_lr=True)\n", + " metric_results[name] = acc\n", + " \n", + " for j, src in enumerate(SOURCE_UNITS):\n", + " ipd_heatmap_data[i, j] = ipd_vals[src]\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Plotting Extension Results\n", + "# 1. Bar Chart for Accuracy Comparison\n", + "plt.figure(figsize=(7, 4))\n", + "sns.barplot(x=list(metric_results.keys()), y=list(metric_results.values()), palette=\"coolwarm\")\n", + "plt.ylabel(\"Test Accuracy\")\n", + "plt.title(\"Extension: Impact of Distance Metrics on Accuracy (Multiclass)\")\n", + "plt.ylim(0, 1.0)\n", + "for i, v in enumerate(metric_results.values()):\n", + " plt.text(i, v + 0.02, f\"{v:.3f}\", ha='center')\n", + "plt.show()\n", + "\n", + "# 2. Heatmap for IPD Values across Metrics\n", + "plt.figure(figsize=(8, 4))\n", + "sns.heatmap(ipd_heatmap_data, annot=True, fmt=\".2f\", xticklabels=SOURCE_UNITS, yticklabels=list(distance_metrics.keys()), cmap=\"YlGnBu\")\n", + "plt.xlabel(\"Source Domains\")\n", + "plt.ylabel(\"Distance Metric\")\n", + "plt.title(\"Extension: IPD Values across Different Distance Metrics\")\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..b270aa12a 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -47,6 +47,7 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset +from .dsa import DSADataset from .cardiology import CardiologyDataset from .chestxray14 import ChestXray14Dataset from .clinvar import ClinVarDataset diff --git a/pyhealth/datasets/configs/dsa.yaml b/pyhealth/datasets/configs/dsa.yaml new file mode 100644 index 000000000..e0589c6e2 --- /dev/null +++ b/pyhealth/datasets/configs/dsa.yaml @@ -0,0 +1,10 @@ +version: "5.0" +tables: + segments: + file_path: "dsa-pyhealth.csv" + patient_id: subject_id + timestamp: null + attributes: + - segment_path + - activity_name + - activity_code diff --git a/pyhealth/datasets/dsa.py b/pyhealth/datasets/dsa.py new file mode 100644 index 000000000..00cba1b67 --- /dev/null +++ b/pyhealth/datasets/dsa.py @@ -0,0 +1,269 @@ +"""Daily and Sports Activities (DSA) dataset loader.""" + +import logging +import os +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import pandas as pd + +from pyhealth.datasets.base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Dataset metadata (SleepEDF-style: domain facts live in code, YAML is tables only) +# ----------------------------------------------------------------------------- + +DSA_PYHEALTH_MANIFEST = "dsa-pyhealth.csv" + +_LABEL_MAPPING: Dict[str, str] = { + "A01": "sitting", + "A02": "standing", + "A03": "lying_on_back", + "A04": "lying_on_right_side", + "A05": "ascending_stairs", + "A06": "descending_stairs", + "A07": "standing_in_elevator_still", + "A08": "moving_around_in_elevator", + "A09": "walking_in_parking_lot", + "A10": "walking_on_treadmill_flat", + "A11": "walking_on_treadmill_inclined", + "A12": "running_on_treadmill", + "A13": "exercising_on_stepper", + "A14": "exercising_on_cross_trainer", + "A15": "cycling_on_exercise_bike_horizontal", + "A16": "cycling_on_exercise_bike_vertical", + "A17": "rowing", + "A18": "jumping", + "A19": "playing_basketball", +} + +_UNITS: List[Dict[str, str]] = [ + {"T": "Torso"}, + {"RA": "Right Arm"}, + {"LA": "Left Arm"}, + {"RL": "Right Leg"}, + {"LL": "Left Leg"}, +] + +_SENSORS: List[Dict[str, str]] = [ + {"xacc": "X-axis Accelerometer"}, + {"yacc": "Y-axis Accelerometer"}, + {"zacc": "Z-axis Accelerometer"}, + {"xgyro": "X-axis Gyroscope"}, + {"ygyro": "Y-axis Gyroscope"}, + {"zgyro": "Z-axis Gyroscope"}, + {"xmag": "X-axis Magnetometer"}, + {"ymag": "Y-axis Magnetometer"}, + {"zmag": "Z-axis Magnetometer"}, +] + +_SAMPLING_FREQUENCY = 25 +_NUM_COLUMNS = 45 +_NUM_ROWS = 125 + +_LAYOUT = { + "activity_dir_pattern": r"^a\d{2}$", + "subject_dir_pattern": r"^p\d+$", + "segment_file_pattern": r"^s\d+\.txt$", + "code_regex_pattern": r"^A(\d+)$", + "file_extension": ".txt", +} + +_ACTIVITY_DIR_RE = re.compile(_LAYOUT["activity_dir_pattern"]) +_SUBJECT_DIR_RE = re.compile(_LAYOUT["subject_dir_pattern"]) +_SEGMENT_FILE_RE = re.compile(_LAYOUT["segment_file_pattern"]) +_ACTIVITY_CODE_RE = re.compile(_LAYOUT["code_regex_pattern"]) + +DSA_TABLE_NAME = "segments" + + +class DSADataset(BaseDataset): + """Daily and Sports Activities (DSA) time-series dataset (Barshan & Altun, 2010). + + Recordings use five on-body IMU units (torso, two arms, two legs); each unit + contributes nine columns per row (3-axis accelerometer, gyroscope, and + magnetometer), so each segment row has 45 comma-separated values. The public + release is sampled at 25 Hz; each ``.txt`` segment is typically 125 lines (about + five seconds of data). + + On disk, activities live in folders ``a01`` through ``a19``, subjects in ``p1`` + through ``p8``, and segment files ``s01.txt``, ``s02.txt``, … under each + subject. + + Dataset is available at: + https://archive.ics.uci.edu/dataset/256/daily+and+sports+activities + + Citations: + If you use this dataset, cite: Barshan, B., & Altun, K. (2010). Daily and + Sports Activities [Dataset]. UCI Machine Learning Repository. + https://doi.org/10.24432/C5C59F + + Args: + root str: Dataset root (activity folders; manifest created if missing). + dataset_name: Passed to :class:`BaseDataset`. Default ``"dsa"``. + config_path: Path to ``dsa.yaml`` (default: package ``configs/dsa.yaml``). + cache_dir: Cache directory for :class:`BaseDataset`. + num_workers: Parallel workers for base pipelines. + dev: Passed to :class:`BaseDataset` (limits patients when building events). + + Examples: + >>> from pyhealth.datasets import DSADataset + >>> dataset = DSADataset(root="/path/to/dsa") + >>> dataset.stat() + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + cache_dir: Optional[Union[str, Path]] = None, + num_workers: int = 1, + dev: bool = False, + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "dsa.yaml" + ) + + metadata_path = os.path.join(root, DSA_PYHEALTH_MANIFEST) + if not os.path.exists(metadata_path): + self.prepare_metadata(root) + + super().__init__( + root=root, + tables=[DSA_TABLE_NAME], + dataset_name=dataset_name or "dsa", + config_path=config_path, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + self.label_mapping: Dict[str, str] = dict(_LABEL_MAPPING) + self.units: List[Dict[str, str]] = list(_UNITS) + self.sensors: List[Dict[str, str]] = list(_SENSORS) + self.sampling_frequency: int = _SAMPLING_FREQUENCY + self._num_columns: int = _NUM_COLUMNS + self._num_rows: int = _NUM_ROWS + + self._manifest_df: pd.DataFrame = pd.read_csv(os.path.join(self.root, self.config.tables[DSA_TABLE_NAME].file_path)) + + def prepare_metadata(self, root: str) -> None: + """Scan ``root`` and write ``dsa-pyhealth.csv`` (``tables.segments``).""" + rows = [] + for a_dir in sorted(os.listdir(root)): + if not _ACTIVITY_DIR_RE.match(a_dir): + continue + activity_code = a_dir.upper() + a_path = os.path.join(root, a_dir) + if not os.path.isdir(a_path): + continue + + for p_dir in sorted(os.listdir(a_path)): + if not _SUBJECT_DIR_RE.match(p_dir): + continue + p_path = os.path.join(a_path, p_dir) + if not os.path.isdir(p_path): + continue + + for s_file in sorted(os.listdir(p_path)): + if not _SEGMENT_FILE_RE.match(s_file): + continue + + rows.append( + { + "subject_id": p_dir, + "activity_name": _LABEL_MAPPING[activity_code], + "activity_code": activity_code, + "segment_path": f"{a_dir}/{p_dir}/{s_file}", + } + ) + + if not rows: + raise ValueError( + f"No DSA segments under {root}; expected aXX/pY/sZZ.txt layout." + ) + + metadata_path = os.path.join(root, DSA_PYHEALTH_MANIFEST) + df = pd.DataFrame(rows) + df = df[["subject_id", "activity_name", "activity_code", "segment_path"]] + df.to_csv(metadata_path, index=False) + + def get_subject_ids(self) -> List[str]: + """Return sorted subject IDs from the manifest.""" + return sorted(self._manifest_df["subject_id"].unique().tolist()) + + def get_activity_labels(self) -> Dict[str, int]: + """Map activity name to class index (ordered by activity code).""" + codes = sorted(self.label_mapping.keys()) + return {self.label_mapping[c]: i for i, c in enumerate(codes)} + + def get_subject_data(self, subject_id: str) -> Dict[str, Any]: + """Load all segment arrays for one subject.""" + subject_df = self._manifest_df[self._manifest_df["subject_id"] == subject_id] + if subject_df.empty: + raise ValueError(f"Subject {subject_id!r} not found in manifest") + + subject_data: Dict[str, Any] = {"id": subject_id, "activities": {}} + + for (activity_name, activity_code), group in subject_df.groupby( + ["activity_name", "activity_code"] + ): + segments = [] + for _, row in group.iterrows(): + segment_path = os.path.join(self.root, row["segment_path"]) + segment_data = self._load_segment(segment_path, subject_id, activity_name) + segments.append(segment_data) + + subject_data["activities"][activity_name] = { + "id": activity_code, + "segments": segments, + } + + return subject_data + + def _load_segment( + self, + file_path: str, + subject_id: str, + activity: str, + ) -> Dict[str, Any]: + """Load a single segment file and return as dict.""" + try: + data = np.loadtxt(file_path, delimiter=",", dtype=np.float64) + except Exception as e: + raise ValueError( + f"Failed to parse DSA segment {file_path}; expected a " + f"{self._num_rows}x{self._num_columns} comma-separated numeric file." + ) from e + + if data.ndim == 1: + data = data.reshape(1, -1) + + n_rows, n_cols = data.shape + if n_rows != self._num_rows: + raise ValueError( + f"{file_path} has {n_rows} rows, expected {self._num_rows}" + ) + if n_cols != self._num_columns: + raise ValueError( + f"{file_path} has {n_cols} columns, expected {self._num_columns}" + ) + if not np.isfinite(data).all(): + raise ValueError(f"{file_path} contains non-finite values (NaN or Inf).") + + return { + "file_path": Path(file_path), + "subject_id": subject_id, + "activity": activity, + "data": data, + "num_samples": n_rows, + "sampling_rate": self.sampling_frequency, + "segment_filename": os.path.basename(file_path), + } diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..a5c01ef7b 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,4 +1,5 @@ from .adacare import AdaCare, AdaCareLayer, MultimodalAdaCare +from .adaptive_transfer import AdaptiveTransferModel from .agent import Agent, AgentLayer from .base_model import BaseModel from .biot import BIOT diff --git a/pyhealth/models/adaptive_transfer.py b/pyhealth/models/adaptive_transfer.py new file mode 100644 index 000000000..0a8fe2cb1 --- /dev/null +++ b/pyhealth/models/adaptive_transfer.py @@ -0,0 +1,574 @@ +from __future__ import annotations + +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.datasets import SampleDataset +from pyhealth.models.base_model import BaseModel + +DistanceFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +class AdaptiveTransferModel(BaseModel): + """Adaptive transfer model for multi-source time-series classification. + + This model is inspired by + "Daily Physical Activity Monitoring: Adaptive Learning from Multi-Source + Motion Sensor Data". + + The model supports: + 1. standard supervised forward passes; + 2. paired source-target similarity computation; + 3. similarity-weighted transfer utilities for example scripts; + 4. dependency injection for both the backbone and distance function. + + Args: + dataset: PyHealth sample dataset. + feature_key: Dense time-series feature key. If None, uses the first + available feature key. + hidden_dim: Hidden size for built-in backbones. + num_layers: Number of recurrent layers for built-in backbones. + dropout: Dropout probability. + bidirectional: Whether recurrent backbones are bidirectional. + backbone: Backbone encoder specification. Supported string values are + {"lstm", "gru", "mlp"}, or a custom ``nn.Module`` can be passed. + backbone_output_dim: Output dimension of a custom backbone. Required + when it cannot be inferred from the module itself. + distance_fn: Distance function for IPD-style similarity. One of + {"euclidean", "manhattan", "cosine"} or a callable. Euclidean and + Manhattan use ``torch.nn.functional.pairwise_distance``; cosine + uses ``1 - cosine_similarity``. + use_similarity_weighting: Whether to scale learning rates by similarity. + use_kde_smoothing: Whether to smooth pairwise distances before + averaging. + smoothing_std: Standard deviation of Gaussian smoothing noise. + eps: Small constant for numerical stability. + input_dim: Optional per-time-step input width; inferred from the + dataset when omitted. + Raises: + ValueError: If the dataset does not expose exactly one label key. + + Example: + >>> model = AdaptiveTransferModel(dataset=dataset, feature_key="signal") + >>> output = model(**batch) + >>> output["logit"].shape + """ + + def __init__( + self, + dataset: SampleDataset, + feature_key: Optional[str] = None, + hidden_dim: int = 128, + num_layers: int = 1, + dropout: float = 0.2, + bidirectional: bool = False, + backbone: Union[str, nn.Module] = "lstm", + backbone_output_dim: Optional[int] = None, + distance_fn: Union[str, DistanceFn] = "euclidean", + use_similarity_weighting: bool = True, + use_kde_smoothing: bool = True, + smoothing_std: float = 0.01, + eps: float = 1e-8, + input_dim: Optional[int] = None, + ) -> None: + """Initialize the adaptive transfer model. + + Args: + dataset: PyHealth sample dataset. + feature_key: Dense input feature key. If None, uses the first + available feature key from the dataset. + hidden_dim: Hidden size for built-in backbones. + num_layers: Number of recurrent layers for built-in backbones. + dropout: Dropout probability. + bidirectional: Whether recurrent backbones are bidirectional. + backbone: Backbone encoder specification. Supported string values + are {"lstm", "gru", "mlp"}, or a custom ``nn.Module``. + backbone_output_dim: Output dimension of a custom backbone. + distance_fn: Distance function identifier or callable used for + IPD-style similarity. + use_similarity_weighting: Whether adaptive learning rates should be + scaled by source-target similarity. + use_kde_smoothing: Whether to apply smoothing to pairwise + distances before averaging. + smoothing_std: Standard deviation of the Gaussian smoothing noise. + eps: Small constant for numerical stability. + input_dim: If set, per-time-step input size for built-in backbones + (e.g. number of DSA channels). When ``None``, the model infers + from ``dataset.input_info`` when present, otherwise from the + first training sample. + + Raises: + ValueError: If the dataset exposes more than one label key. + """ + super().__init__(dataset) + + if len(self.label_keys) != 1: + raise ValueError("AdaptiveTransferModel supports exactly one label key.") + + self.label_key = self.label_keys[0] + self.feature_key = feature_key or self.feature_keys[0] + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.bidirectional = bidirectional + self.use_similarity_weighting = use_similarity_weighting + self.use_kde_smoothing = use_kde_smoothing + self.smoothing_std = smoothing_std + self.eps = eps + + encoder_input_dim = ( + max(1, int(input_dim)) + if input_dim is not None + else self._infer_input_dim(self.feature_key) + ) + self.encoder, encoder_output_dim = self._build_encoder( + input_dim=encoder_input_dim, + hidden_dim=hidden_dim, + num_layers=num_layers, + dropout=dropout, + bidirectional=bidirectional, + backbone=backbone, + backbone_output_dim=backbone_output_dim, + ) + + self.dropout = nn.Dropout(dropout) + self.classifier = nn.Linear(encoder_output_dim, self.get_output_size()) + self.distance_fn = self._resolve_distance_fn(distance_fn) + + def _build_encoder( + self, + input_dim: int, + hidden_dim: int, + num_layers: int, + dropout: float, + bidirectional: bool, + backbone: Union[str, nn.Module], + backbone_output_dim: Optional[int], + ) -> Tuple[nn.Module, int]: + """Build the encoder and infer its output size.""" + if isinstance(backbone, nn.Module): + output_dim = self._infer_backbone_output_dim( + backbone, backbone_output_dim + ) + return backbone, output_dim + + backbone_name = backbone.lower() + + if backbone_name == "lstm": + encoder = nn.LSTM( + input_size=input_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0.0, + bidirectional=bidirectional, + ) + output_dim = hidden_dim * (2 if bidirectional else 1) + return encoder, output_dim + + if backbone_name == "gru": + encoder = nn.GRU( + input_size=input_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0.0, + bidirectional=bidirectional, + ) + output_dim = hidden_dim * (2 if bidirectional else 1) + return encoder, output_dim + + if backbone_name == "mlp": + encoder = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim), + ) + return encoder, hidden_dim + + raise ValueError( + f"Unsupported backbone: {backbone}. Expected one of " + "{'lstm', 'gru', 'mlp'} or a custom backbone module." + ) + + def _infer_backbone_output_dim( + self, + backbone: nn.Module, + backbone_output_dim: Optional[int], + ) -> int: + """Infer output size for a custom backbone.""" + if backbone_output_dim is not None: + return backbone_output_dim + + for attr in ["output_dim", "hidden_dim", "hidden_size", "embedding_dim"]: + if hasattr(backbone, attr): + value = getattr(backbone, attr) + if isinstance(value, int) and value > 0: + return value + + raise ValueError( + "Could not infer backbone output dimension. Please provide " + "backbone_output_dim for a custom backbone." + ) + + def _resolve_distance_fn( + self, + distance_fn: Union[str, DistanceFn], + ) -> DistanceFn: + """Resolve a string or callable distance function.""" + if callable(distance_fn): + return distance_fn + + name = distance_fn.lower() + if name == "euclidean": + return lambda x, y: F.pairwise_distance(x, y, p=2) + if name == "manhattan": + return lambda x, y: F.pairwise_distance(x, y, p=1) + if name == "cosine": + return lambda x, y: 1.0 - F.cosine_similarity(x, y, dim=1) + + raise ValueError( + f"Unsupported distance_fn: {distance_fn}. Expected one of " + "{'euclidean', 'manhattan', 'cosine'} or a callable." + ) + + def _infer_input_dim(self, feature_key: str) -> int: + """Infer per-time-step width from ``input_info`` or the first sample.""" + if self.dataset is None: + return 1 + + try: + stats = self.dataset.input_info[feature_key] + if "len" in stats and isinstance(stats["len"], int): + return max(1, int(stats["len"])) + if "dim" in stats and isinstance(stats["dim"], int): + return max(1, int(stats["dim"])) + except (KeyError, TypeError, AttributeError): + pass + + try: + n = len(self.dataset) + if n == 0 or feature_key not in self.dataset[0]: + return 1 + feature = self.dataset[0][feature_key] + if isinstance(feature, torch.Tensor): + value = feature + else: + proc = self.dataset.input_processors[feature_key] + schema = proc.schema() + if "value" not in schema: + return 1 + value = feature[schema.index("value")] + if value.dim() == 1: + return 1 + if value.dim() >= 2: + return max(1, int(value.shape[-1])) + except Exception: + pass + + return 1 + + def _get_feature_value_and_mask( + self, + feature: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Extract normalized value and mask tensors from a feature input.""" + if isinstance(feature, torch.Tensor): + value = feature + mask = None + else: + schema = self.dataset.input_processors[self.feature_key].schema() + if "value" not in schema: + raise ValueError( + f"Feature '{self.feature_key}' must contain 'value'." + ) + + value = feature[schema.index("value")] + mask = feature[schema.index("mask")] if "mask" in schema else None + + if mask is None and len(feature) == len(schema) + 1: + mask = feature[-1] + + value = value.to(self.device).float() + if mask is not None: + mask = mask.to(self.device).float() + + # [B, T] -> [B, T, 1] + if value.dim() == 2: + value = value.unsqueeze(-1) + elif value.dim() != 3: + raise ValueError( + f"Unsupported input shape for '{self.feature_key}': " + f"{tuple(value.shape)}" + ) + + if mask is not None: + if mask.dim() == 3: + mask = mask.any(dim=-1).float() + elif mask.dim() != 2: + raise ValueError( + f"Unsupported mask shape for '{self.feature_key}': " + f"{tuple(mask.shape)}" + ) + + return value, mask + + def _masked_mean_pool( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor], + ) -> torch.Tensor: + """Apply masked mean pooling over the time dimension.""" + if mask is None: + return x.mean(dim=1) + + weights = mask.unsqueeze(-1).float() + denom = weights.sum(dim=1).clamp_min(1.0) + return (x * weights).sum(dim=1) / denom + + def _encode_sequence( + self, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Encode a dense time series into a fixed-size embedding.""" + if isinstance(self.encoder, nn.LSTM): + if mask is not None: + lengths = mask.sum(dim=1).long().clamp(min=1).cpu() + packed = nn.utils.rnn.pack_padded_sequence( + value, + lengths, + batch_first=True, + enforce_sorted=False, + ) + _, (h_n, _) = self.encoder(packed) + else: + _, (h_n, _) = self.encoder(value) + + if self.bidirectional: + emb = torch.cat([h_n[-2], h_n[-1]], dim=-1) + else: + emb = h_n[-1] + return self.dropout(emb) + + if isinstance(self.encoder, nn.GRU): + if mask is not None: + lengths = mask.sum(dim=1).long().clamp(min=1).cpu() + packed = nn.utils.rnn.pack_padded_sequence( + value, + lengths, + batch_first=True, + enforce_sorted=False, + ) + _, h_n = self.encoder(packed) + else: + _, h_n = self.encoder(value) + + if self.bidirectional: + emb = torch.cat([h_n[-2], h_n[-1]], dim=-1) + else: + emb = h_n[-1] + return self.dropout(emb) + + # Non-recurrent encoders may return [B, D] or [B, T, D]. + encoder_out = self.encoder(value) + + if encoder_out.dim() == 3: + emb = self._masked_mean_pool(encoder_out, mask) + elif encoder_out.dim() == 2: + emb = encoder_out + else: + raise ValueError( + "Custom backbone must return a tensor of shape [B, D] " + "or [B, T, D]." + ) + + return self.dropout(emb) + + def forward( + self, + **kwargs: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + ) -> Dict[str, torch.Tensor]: + """Run the forward pass and optionally compute loss. + + Args: + **kwargs: Keyword inputs expected by PyHealth. Must contain the + configured feature key. May also contain the label key. + + Returns: + A dictionary containing model outputs and, when labels are + provided, the loss and ground-truth labels. + + Raises: + ValueError: If the configured feature key is missing from inputs. + """ + if self.feature_key not in kwargs: + raise ValueError( + f"Expected feature key '{self.feature_key}' in model inputs." + ) + + value, mask = self._get_feature_value_and_mask(kwargs[self.feature_key]) + patient_emb = self._encode_sequence(value, mask) + logits = self.classifier(patient_emb) + y_prob = self.prepare_y_prob(logits) + + results: Dict[str, torch.Tensor] = { + "logit": logits, + "y_prob": y_prob, + } + + if self.label_key in kwargs: + y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device) + + if self.mode == "multiclass" and y_true.dim() > 1: + y_true = y_true.squeeze(-1).long() + elif self.mode == "binary": + y_true = y_true.float() + if y_true.dim() == 1: + y_true = y_true.unsqueeze(-1) + + loss = self.get_loss_function()(logits, y_true) + results["loss"] = loss + results["y_true"] = y_true + + return results + + def forward_from_embedding( + self, + **kwargs: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + ) -> Dict[str, torch.Tensor]: + """Forward hook kept for compatibility with PyHealth interfaces.""" + return self.forward(**kwargs) + + @torch.no_grad() + def extract_embedding( + self, + batch: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]], + ) -> torch.Tensor: + """Extract latent embeddings for a batch. + + Args: + batch: Batch dictionary containing the configured feature key. + + Returns: + Batch embedding tensor of shape [B, H]. + """ + value, mask = self._get_feature_value_and_mask(batch[self.feature_key]) + return self._encode_sequence(value, mask) + + @torch.no_grad() + def compute_pairwise_distances( + self, + source_batch: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]], + target_batch: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]], + ) -> torch.Tensor: + """Compute paired distances between source and target embeddings. + + Args: + source_batch: Source-domain batch dictionary. + target_batch: Target-domain batch dictionary. + + Returns: + Distance tensor of shape [B]. + + Raises: + ValueError: If the source and target batch sizes differ. + """ + source_emb = self.extract_embedding(source_batch) + target_emb = self.extract_embedding(target_batch) + + if source_emb.shape[0] != target_emb.shape[0]: + raise ValueError( + "Source and target batches must have the same batch size " + "for paired IPD." + ) + + return self.distance_fn(source_emb, target_emb) + + @torch.no_grad() + def compute_ipd( + self, + source_batch: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]], + target_batch: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]], + ) -> float: + """Compute an IPD-style distance between one source and target batch. + + Args: + source_batch: Source-domain batch dictionary. + target_batch: Target-domain batch dictionary. + + Returns: + Mean paired distance as a float. + """ + distances = self.compute_pairwise_distances(source_batch, target_batch) + + if self.use_kde_smoothing: + noise = torch.randn_like(distances) * self.smoothing_std + distances = (distances + noise).clamp_min(0.0) + + return float(distances.mean().item()) + + @torch.no_grad() + def compute_source_similarities( + self, + source_batches: Sequence[ + Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]] + ], + target_batch: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]], + ) -> List[float]: + """Compute inverse-distance similarities for multiple source batches. + + Args: + source_batches: Sequence of source-domain batches. + target_batch: Target-domain batch dictionary. + + Returns: + List of similarity scores, one per source batch. + """ + similarities: List[float] = [] + for source_batch in source_batches: + ipd = self.compute_ipd(source_batch, target_batch) + similarities.append(1.0 / (ipd + self.eps)) + return similarities + + @torch.no_grad() + def rank_source_domains( + self, + source_batches: Sequence[ + Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]] + ], + target_batch: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, ...]]], + ) -> List[int]: + """Rank source domains by descending similarity to the target. + + Args: + source_batches: Sequence of source-domain batches. + target_batch: Target-domain batch dictionary. + + Returns: + List of source-domain indices sorted from most similar to least + similar. + """ + similarities = self.compute_source_similarities(source_batches, target_batch) + return sorted( + range(len(similarities)), + key=lambda i: similarities[i], + reverse=True, + ) + + def get_adaptive_lr(self, base_lr: float, similarity: float) -> float: + """Scale the base learning rate by similarity if enabled. + + Args: + base_lr: Base learning rate before adaptation. + similarity: Source-target similarity score. + + Returns: + Adapted learning rate. + """ + if not self.use_similarity_weighting: + return base_lr + return base_lr * max(similarity, self.eps) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..2c3509ff0 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -13,6 +13,7 @@ from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 +from .dsa_activity_classification import DSAActivityClassification from .drug_recommendation import ( DrugRecommendationEICU, DrugRecommendationMIMIC3, diff --git a/pyhealth/tasks/dsa_activity_classification.py b/pyhealth/tasks/dsa_activity_classification.py new file mode 100644 index 000000000..53d4ddb2e --- /dev/null +++ b/pyhealth/tasks/dsa_activity_classification.py @@ -0,0 +1,253 @@ +"""Task for classifying Daily and Sports Activities (DSA) sensor segments.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np + +from .base_task import BaseTask + + +class DSAActivityClassification(BaseTask): + """Multi-class activity-recognition task for the DSA dataset. + + The DSA paper models each wearable placement as a domain. This task keeps + that structure explicit by letting the caller choose one or more body-site + units and returning a fixed-length multivariate time series for each + activity segment. + + Each output sample corresponds to one ``sXX.txt`` segment file and contains: + + - ``signal``: a ``(125, 9 * num_selected_units)`` float tensor-like array + - ``label``: an integer in ``[0, 18]`` representing one of the 19 DSA + activities + + The default ``normalization="minmax"`` follows the paper's preprocessing by + rescaling each channel independently to ``[-1, 1]``. Constant-valued + channels are mapped to zeros. + + Args: + dataset_root: Root directory of the DSA dataset. The current DSA + manifest stores segment paths relative to this directory, so the + task needs the root to load segment files. + selected_units: Optional subset of DSA body sites to use. Valid values + are ``"T"``, ``"RA"``, ``"LA"``, ``"RL"``, and ``"LL"``. ``None`` + uses all five units in canonical DSA order. + normalization: Feature normalization strategy. ``"minmax"`` rescales + each selected channel to ``[-1, 1]``; ``"none"`` preserves raw + values. + + Examples: + >>> from pyhealth.datasets import DSADataset + >>> from pyhealth.tasks import DSAActivityClassification + >>> dataset = DSADataset(root="/path/to/dsa") + >>> task = DSAActivityClassification( + ... dataset_root="/path/to/dsa", + ... selected_units=("LA",), + ... ) + >>> sample_dataset = dataset.set_task(task) + >>> sample = sample_dataset[0] + >>> tuple(sample["signal"].shape) + (125, 9) + """ + + task_name: str = "DSAActivityClassification" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"label": "multiclass"} + + VALID_UNITS: ClassVar[Tuple[str, ...]] = ("T", "RA", "LA", "RL", "LL") + SENSOR_KEYS: ClassVar[Tuple[str, ...]] = ( + "xacc", + "yacc", + "zacc", + "xgyro", + "ygyro", + "zgyro", + "xmag", + "ymag", + "zmag", + ) + CHANNELS_PER_UNIT: ClassVar[int] = 9 + SEGMENT_LENGTH: ClassVar[int] = 125 + TOTAL_CHANNELS: ClassVar[int] = 45 + NUM_CLASSES: ClassVar[int] = 19 + + def __init__( + self, + dataset_root: Union[str, Path], + selected_units: Optional[Union[str, Sequence[str]]] = None, + normalization: str = "minmax", + ) -> None: + """Initialize the DSA activity-classification task.""" + if normalization not in {"minmax", "none"}: + raise ValueError( + "Unsupported normalization. Expected 'minmax' or 'none'." + ) + + self.dataset_root = Path(dataset_root).expanduser().resolve() + self.selected_units = self._normalize_units(selected_units) + self.normalization = normalization + self.channel_names = self._build_channel_names(self.selected_units) + + @classmethod + def _normalize_units( + cls, + selected_units: Optional[Union[str, Sequence[str]]], + ) -> Tuple[str, ...]: + """Validate and normalize the requested DSA body-site units.""" + if selected_units is None: + return cls.VALID_UNITS + + if isinstance(selected_units, str): + requested_units = [selected_units] + else: + requested_units = list(selected_units) + + normalized_units: List[str] = [] + seen_units = set() + for unit in requested_units: + normalized = unit.upper() + if normalized not in cls.VALID_UNITS: + raise ValueError( + f"Unsupported DSA unit '{unit}'. " + f"Expected one of {cls.VALID_UNITS}." + ) + if normalized not in seen_units: + normalized_units.append(normalized) + seen_units.add(normalized) + + if not normalized_units: + raise ValueError("selected_units must contain at least one DSA unit.") + + return tuple(normalized_units) + + @classmethod + def _build_channel_names(cls, selected_units: Sequence[str]) -> List[str]: + """Build human-readable channel names for the selected units.""" + return [ + f"{unit}_{sensor_key}" + for unit in selected_units + for sensor_key in cls.SENSOR_KEYS + ] + + @classmethod + def _channel_slice_for_unit(cls, unit: str) -> slice: + """Return the column slice for a single DSA body-site unit.""" + unit_index = cls.VALID_UNITS.index(unit) + start = unit_index * cls.CHANNELS_PER_UNIT + stop = start + cls.CHANNELS_PER_UNIT + return slice(start, stop) + + @classmethod + def _extract_units( + cls, + full_signal: np.ndarray, + selected_units: Sequence[str], + ) -> np.ndarray: + """Extract and concatenate columns for the chosen DSA units.""" + unit_signals = [ + full_signal[:, cls._channel_slice_for_unit(unit)] + for unit in selected_units + ] + return np.concatenate(unit_signals, axis=1) + + @staticmethod + def _minmax_normalize(signal: np.ndarray) -> np.ndarray: + """Scale each channel independently to ``[-1, 1]``.""" + signal_min = signal.min(axis=0, keepdims=True) + signal_max = signal.max(axis=0, keepdims=True) + signal_range = signal_max - signal_min + + safe_range = np.where(signal_range > 0.0, signal_range, 1.0) + normalized = 2.0 * (signal - signal_min) / safe_range - 1.0 + normalized[:, signal_range[0] == 0.0] = 0.0 + return normalized.astype(np.float32, copy=False) + + @classmethod + def _activity_label(cls, activity_code: str) -> int: + """Convert a DSA activity code like ``A01`` into a zero-based label.""" + if not isinstance(activity_code, str) or not activity_code.startswith("A"): + raise ValueError(f"Invalid DSA activity code: {activity_code!r}") + + try: + label = int(activity_code[1:]) - 1 + except ValueError as exc: + raise ValueError( + f"Invalid DSA activity code: {activity_code!r}" + ) from exc + + if label < 0 or label >= cls.NUM_CLASSES: + raise ValueError(f"Invalid DSA activity code: {activity_code!r}") + return label + + def _resolve_segment_path(self, segment_path: str) -> Path: + """Resolve a stored segment path against the configured dataset root.""" + path = Path(segment_path) + if path.is_absolute(): + return path + return self.dataset_root / path + + @classmethod + def _load_segment(cls, file_path: Path) -> np.ndarray: + """Load a DSA segment file and validate its expected shape.""" + try: + signal = np.loadtxt(file_path, delimiter=",", dtype=np.float64) + except Exception as exc: + raise ValueError( + f"Failed to parse DSA segment {file_path}." + ) from exc + + if signal.ndim == 1: + signal = signal.reshape(1, -1) + + expected_shape = (cls.SEGMENT_LENGTH, cls.TOTAL_CHANNELS) + if signal.shape != expected_shape: + raise ValueError( + f"{file_path} has shape {signal.shape}, expected {expected_shape}." + ) + if not np.isfinite(signal).all(): + raise ValueError(f"{file_path} contains non-finite values.") + return signal + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Create one activity-classification sample per DSA segment event. + + Args: + patient: A PyHealth patient object created from ``DSADataset``. + + Returns: + A list of task samples. Each sample includes metadata for the + subject and segment alongside the extracted multivariate signal and + multiclass label. + """ + samples: List[Dict[str, Any]] = [] + + for event in patient.get_events(event_type="segments"): + segment_file = self._resolve_segment_path(event.segment_path) + full_signal = self._load_segment(segment_file) + signal = self._extract_units(full_signal, self.selected_units) + + if self.normalization == "minmax": + signal = self._minmax_normalize(signal) + else: + signal = signal.astype(np.float32, copy=False) + + samples.append( + { + "patient_id": patient.patient_id, + "subject_id": patient.patient_id, + "sample_id": f"{patient.patient_id}:{event.segment_path}", + "segment_path": event.segment_path, + "activity_name": event.activity_name, + "activity_code": event.activity_code, + "unit_combo": "+".join(self.selected_units), + "num_channels": signal.shape[1], + "channel_names": self.channel_names, + "signal": signal, + "label": self._activity_label(event.activity_code), + } + ) + + return samples diff --git a/tests/core/test_adaptive_transfer.py b/tests/core/test_adaptive_transfer.py new file mode 100644 index 000000000..c2ef461d2 --- /dev/null +++ b/tests/core/test_adaptive_transfer.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn + +from pyhealth.models.adaptive_transfer import AdaptiveTransferModel + + +class _DummyInputProcessor: + """Input processor stub for model tests.""" + + def schema(self): + return ("value",) + + +class _DummyOutputProcessor: + """Output processor stub for model tests.""" + + def __init__(self, size: int): + self._size = size + + def size(self): + return self._size + + +class _DummyDataset: + """Dataset stub satisfying BaseModel requirements.""" + + def __init__(self, num_classes: int = 3, input_dim: int = 4): + self.input_schema = {"signal": "tensor"} + self.output_schema = {"label": "multiclass"} + self.input_processors = {"signal": _DummyInputProcessor()} + self.output_processors = {"label": _DummyOutputProcessor(num_classes)} + self.input_info = {"signal": {"dim": input_dim}} + + +class _MeanPoolBackbone(nn.Module): + """Custom backbone for dependency-injection tests.""" + + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.output_dim = output_dim + self.proj = nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [B, T, D] + pooled = x.mean(dim=1) + return self.proj(pooled) + + +def _make_batch( + batch_size: int = 4, + seq_len: int = 6, + input_dim: int = 4, + num_classes: int = 3, +): + """Small synthetic batch for testing.""" + x = torch.randn(batch_size, seq_len, input_dim) + y = torch.randint(0, num_classes, (batch_size,)) + return {"signal": x, "label": y} + + +def test_adaptive_transfer_instantiation_default(): + """Test model instantiation with the default backbone.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + model = AdaptiveTransferModel(dataset=dataset, feature_key="signal") + + assert isinstance(model, AdaptiveTransferModel) + assert model.feature_key == "signal" + assert model.label_key == "label" + assert isinstance(model.encoder, nn.LSTM) + + +def test_adaptive_transfer_forward_shapes(): + """Test forward pass output keys and tensor shapes.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + model = AdaptiveTransferModel(dataset=dataset, feature_key="signal") + + batch = _make_batch(batch_size=5, seq_len=7, input_dim=4, num_classes=3) + output = model(**batch) + + assert "logit" in output + assert "y_prob" in output + assert "loss" in output + assert "y_true" in output + + assert output["logit"].shape == (5, 3) + assert output["y_prob"].shape == (5, 3) + assert output["y_true"].shape == (5,) + assert output["loss"].dim() == 0 + + +def test_adaptive_transfer_backward_computes_gradients(): + """Test backward pass and gradient propagation.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + model = AdaptiveTransferModel(dataset=dataset, feature_key="signal") + + batch = _make_batch(batch_size=4, seq_len=5, input_dim=4, num_classes=3) + output = model(**batch) + loss = output["loss"] + loss.backward() + + grads = [ + param.grad + for param in model.parameters() + if param.requires_grad + ] + assert any(grad is not None for grad in grads) + + +def test_adaptive_transfer_custom_backbone_forward(): + """Test dependency injection with a custom backbone.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + backbone = _MeanPoolBackbone(input_dim=4, output_dim=8) + + model = AdaptiveTransferModel( + dataset=dataset, + feature_key="signal", + backbone=backbone, + backbone_output_dim=8, + ) + + batch = _make_batch(batch_size=3, seq_len=6, input_dim=4, num_classes=3) + output = model(**batch) + + assert output["logit"].shape == (3, 3) + assert output["y_prob"].shape == (3, 3) + + +def test_adaptive_transfer_compute_ipd_with_string_distance(): + """Test IPD computation with a built-in string distance function.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + model = AdaptiveTransferModel( + dataset=dataset, + feature_key="signal", + distance_fn="cosine", + use_kde_smoothing=False, + ) + + source_batch = {"signal": torch.randn(4, 5, 4)} + target_batch = {"signal": torch.randn(4, 5, 4)} + + ipd = model.compute_ipd(source_batch, target_batch) + + assert isinstance(ipd, float) + assert ipd >= 0.0 + + +def test_adaptive_transfer_compute_ipd_with_callable_distance(): + """Test IPD computation with a custom callable distance function.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + + def l1_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.norm(x - y, p=1, dim=1) + + model = AdaptiveTransferModel( + dataset=dataset, + feature_key="signal", + distance_fn=l1_distance, + use_kde_smoothing=False, + ) + + source_batch = {"signal": torch.randn(4, 5, 4)} + target_batch = {"signal": torch.randn(4, 5, 4)} + + distances = model.compute_pairwise_distances(source_batch, target_batch) + + assert distances.shape == (4,) + assert torch.all(distances >= 0) + + +def test_adaptive_transfer_rank_source_domains(): + """Test source-domain ranking by similarity.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + model = AdaptiveTransferModel( + dataset=dataset, + feature_key="signal", + use_kde_smoothing=False, + ) + + target_batch = {"signal": torch.zeros(4, 5, 4)} + + source_batches = [ + {"signal": torch.zeros(4, 5, 4)}, + {"signal": torch.ones(4, 5, 4)}, + {"signal": torch.full((4, 5, 4), 2.0)}, + ] + + ranked = model.rank_source_domains(source_batches, target_batch) + + assert len(ranked) == 3 + assert sorted(ranked) == [0, 1, 2] + assert ranked[0] == 0 + + +def test_adaptive_transfer_get_adaptive_lr(): + """Test similarity-weighted learning-rate scaling.""" + dataset = _DummyDataset(num_classes=3, input_dim=4) + + model_weighted = AdaptiveTransferModel( + dataset=dataset, + feature_key="signal", + use_similarity_weighting=True, + ) + model_unweighted = AdaptiveTransferModel( + dataset=dataset, + feature_key="signal", + use_similarity_weighting=False, + ) + + base_lr = 1e-3 + similarity = 2.0 + + assert model_weighted.get_adaptive_lr(base_lr, similarity) == base_lr * similarity + assert model_unweighted.get_adaptive_lr(base_lr, similarity) == base_lr + + +def test_adaptive_transfer_mlp_backbone_forward(): + dataset = _DummyDataset(num_classes=3, input_dim=4) + model = AdaptiveTransferModel( + dataset=dataset, + feature_key="signal", + backbone="mlp", + hidden_dim=32, + ) + batch = _make_batch(batch_size=3, seq_len=8, input_dim=4, num_classes=3) + out = model(**batch) + assert out["logit"].shape == (3, 3) diff --git a/tests/core/test_dsa.py b/tests/core/test_dsa.py new file mode 100644 index 000000000..a99305b0f --- /dev/null +++ b/tests/core/test_dsa.py @@ -0,0 +1,289 @@ +"""Tests for Daily and Sports Activities (DSA) dataset.""" + +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import pandas as pd + +from pyhealth.datasets import DSADataset + +EXPECTED_MANIFEST_COLUMNS = ( + "subject_id", + "activity_name", + "activity_code", + "segment_path", +) + +LOAD_TABLE_COLUMNS = frozenset( + { + "patient_id", + "event_type", + "timestamp", + "segments/segment_path", + "segments/activity_name", + "segments/activity_code", + } +) + +ACTIVITY_RECORD_KEYS = frozenset({"id", "segments"}) + +SEGMENT_RECORD_KEYS = frozenset( + { + "activity", + "data", + "file_path", + "num_samples", + "sampling_rate", + "segment_filename", + "subject_id", + } +) + +EXPECTED_UNIT_KEYS_IN_ORDER = ("T", "RA", "LA", "RL", "LL") +EXPECTED_SENSOR_KEYS_IN_ORDER = ( + "xacc", "yacc", "zacc", + "xgyro", "ygyro", "zgyro", + "xmag", "ymag", "zmag", +) + + +def _write_segment(path: Path, n_rows: int = 125, n_cols: int = 45) -> None: + """Write a synthetic DSA segment file.""" + line = ",".join(["0.0"] * n_cols) + path.write_text("\n".join([line] * n_rows) + "\n", encoding="utf-8") + + +def _make_minimal_dsa_tree(root: Path, activities=None, subjects=None, segments=None) -> Path: + """Create minimal DSA directory structure with configurable layout.""" + activities = activities or ["a01"] + subjects = subjects or ["p1"] + segments = segments or ["s01.txt"] + + first_seg = None + for activity in activities: + for subject in subjects: + for segment in segments: + seg_dir = root / activity / subject + seg_dir.mkdir(parents=True, exist_ok=True) + seg_path = seg_dir / segment + _write_segment(seg_path) + if first_seg is None: + first_seg = seg_path + return first_seg + + +class TestDSADataset(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmpdir = tempfile.TemporaryDirectory() + cls.root_path = cls._tmpdir.name + # Create minimal tree; DSADataset creates manifest if missing + seg_dir = Path(cls.root_path) / "a01" / "p1" + seg_dir.mkdir(parents=True) + _write_segment(seg_dir / "s01.txt") + + cls.dataset = DSADataset(root=cls.root_path) + + @classmethod + def tearDownClass(cls): + cls._tmpdir.cleanup() + + def test_dataset_initialization(self): + self.assertIsNotNone(self.dataset) + self.assertEqual(self.dataset.dataset_name, "dsa") + self.assertIsNotNone(self.dataset.config) + manifest = Path(self.root_path) / "dsa-pyhealth.csv" + self.assertTrue(manifest.is_file()) + + def test_config_attributes_loaded(self): + """Verify loader attributes (SleepEDF-style constants) are populated.""" + ds = self.dataset + self.assertIsInstance(ds.label_mapping, dict) + self.assertIsInstance(ds.units, list) + self.assertIsInstance(ds.sensors, list) + self.assertEqual(ds.sampling_frequency, 25) + self.assertEqual(ds._num_columns, 45) + self.assertEqual(ds._num_rows, 125) + + def test_get_subject_ids(self): + subject_ids = self.dataset.get_subject_ids() + self.assertIsInstance(subject_ids, list) + self.assertEqual(subject_ids, ["p1"]) + + def test_get_activity_labels(self): + activity_labels = self.dataset.get_activity_labels() + self.assertIsInstance(activity_labels, dict) + self.assertEqual(len(activity_labels), 19) + self.assertEqual(activity_labels.get("sitting"), 0) + + def test_subject_data_loading(self): + subject_ids = self.dataset.get_subject_ids() + self.assertTrue(subject_ids) + subject_id = subject_ids[0] + subject_data = self.dataset.get_subject_data(subject_id) + + self.assertIsInstance(subject_data, dict) + self.assertEqual(subject_data["id"], subject_id) + self.assertIn("activities", subject_data) + self.assertIn("sitting", subject_data["activities"]) + + activity_data = subject_data["activities"]["sitting"] + self.assertIsInstance(activity_data["segments"], list) + self.assertTrue(activity_data["segments"]) + + segment = activity_data["segments"][0] + self.assertIsInstance(segment["data"], np.ndarray) + self.assertEqual(segment["sampling_rate"], 25) + self.assertEqual(segment["data"].shape, (125, 45)) + + def test_segment_schema(self): + """Each segment dict exposes a stable key schema.""" + subject_id = self.dataset.get_subject_ids()[0] + subject_data = self.dataset.get_subject_data(subject_id) + + for activity_name, activity_data in subject_data["activities"].items(): + self.assertEqual(frozenset(activity_data.keys()), ACTIVITY_RECORD_KEYS) + self.assertIsInstance(activity_data["id"], str) + self.assertIsInstance(activity_data["segments"], list) + + for seg in activity_data["segments"]: + self.assertEqual(frozenset(seg.keys()), SEGMENT_RECORD_KEYS) + self.assertEqual(seg["subject_id"], subject_id) + self.assertEqual(seg["activity"], activity_name) + self.assertEqual(seg["segment_filename"], seg["file_path"].name) + self.assertIsInstance(seg["file_path"], Path) + arr = seg["data"] + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.ndim, 2) + self.assertEqual(arr.shape[1], 45, "45 channels per DSA segment row") + self.assertEqual(arr.shape[0], seg["num_samples"]) + + def test_sensor_and_unit_channel_metadata(self): + """Sensors/units lists match module metadata.""" + ds = self.dataset + + self.assertEqual(len(ds.units), len(EXPECTED_UNIT_KEYS_IN_ORDER)) + self.assertEqual(len(ds.sensors), len(EXPECTED_SENSOR_KEYS_IN_ORDER)) + + unit_keys = [list(u.keys())[0] for u in ds.units] + self.assertEqual(tuple(unit_keys), EXPECTED_UNIT_KEYS_IN_ORDER) + + sensor_keys = [list(s.keys())[0] for s in ds.sensors] + self.assertEqual(tuple(sensor_keys), EXPECTED_SENSOR_KEYS_IN_ORDER) + + def test_manifest_csv_columns(self): + manifest = Path(self.root_path) / "dsa-pyhealth.csv" + df = pd.read_csv(manifest) + self.assertEqual(list(df.columns), list(EXPECTED_MANIFEST_COLUMNS)) + self.assertEqual(len(df), 1) + row = df.iloc[0] + self.assertEqual(row["subject_id"], "p1") + self.assertEqual(row["activity_name"], "sitting") + self.assertEqual(row["activity_code"], "A01") + self.assertEqual(row["segment_path"], "a01/p1/s01.txt") + + def test_load_table_manifest_via_base_dataset(self): + df = self.dataset.load_table("segments").compute() + self.assertEqual(frozenset(df.columns), LOAD_TABLE_COLUMNS) + self.assertEqual(len(df), 1) + row = df.iloc[0] + self.assertEqual(row["patient_id"], "p1") + self.assertEqual(row["event_type"], "segments") + self.assertEqual(row["segments/activity_name"], "sitting") + self.assertEqual(row["segments/activity_code"], "A01") + self.assertEqual(row["segments/segment_path"], "a01/p1/s01.txt") + self.assertTrue(pd.isna(row["timestamp"])) + + def test_segment_raises_on_wrong_row_count(self): + with tempfile.TemporaryDirectory() as tmpdir: + seg_path = _make_minimal_dsa_tree(Path(tmpdir)) + _write_segment(seg_path, n_rows=124, n_cols=45) + ds = DSADataset(root=tmpdir) + with self.assertRaisesRegex(ValueError, "has 124 rows, expected 125"): + ds.get_subject_data("p1") + + def test_segment_raises_on_non_numeric_content(self): + with tempfile.TemporaryDirectory() as tmpdir: + seg_path = _make_minimal_dsa_tree(Path(tmpdir)) + bad_line = ",".join(["oops"] + ["0.0"] * 44) + seg_path.write_text("\n".join([bad_line] * 125) + "\n", encoding="utf-8") + ds = DSADataset(root=tmpdir) + with self.assertRaisesRegex(ValueError, "Failed to parse DSA segment"): + ds.get_subject_data("p1") + + def test_segment_raises_on_non_finite_values(self): + with tempfile.TemporaryDirectory() as tmpdir: + seg_path = _make_minimal_dsa_tree(Path(tmpdir)) + nan_line = ",".join(["nan"] + ["0.0"] * 44) + seg_path.write_text("\n".join([nan_line] * 125) + "\n", encoding="utf-8") + ds = DSADataset(root=tmpdir) + with self.assertRaisesRegex(ValueError, "contains non-finite values"): + ds.get_subject_data("p1") + + def test_multiple_subjects_and_activities(self): + """Manifest correctly scans multiple activities and subjects.""" + with tempfile.TemporaryDirectory() as tmpdir: + _make_minimal_dsa_tree( + Path(tmpdir), + activities=["a01", "a02"], + subjects=["p1", "p2"], + segments=["s01.txt", "s02.txt"], + ) + ds = DSADataset(root=tmpdir) + + subjects = ds.get_subject_ids() + self.assertEqual(sorted(subjects), ["p1", "p2"]) + + manifest = Path(tmpdir) / "dsa-pyhealth.csv" + df = pd.read_csv(manifest) + self.assertEqual(len(df), 8) # 2 activities × 2 subjects × 2 segments + + # Check activity names are mapped correctly + self.assertIn("sitting", df["activity_name"].values) + self.assertIn("standing", df["activity_name"].values) + self.assertIn("A01", df["activity_code"].values) + self.assertIn("A02", df["activity_code"].values) + + def test_subject_not_found_raises(self): + """Requesting data for a non-existent subject raises ValueError.""" + with self.assertRaisesRegex(ValueError, "Subject 'nonexistent' not found"): + self.dataset.get_subject_data("nonexistent") + + def test_manifest_raises_on_empty_directory(self): + """Empty directory with no valid segments raises ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + with self.assertRaisesRegex(ValueError, "No DSA segments under"): + DSADataset(root=tmpdir) + + def test_segment_column_count_mismatch(self): + """Segment with wrong number of columns raises ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + seg_path = _make_minimal_dsa_tree(Path(tmpdir)) + _write_segment(seg_path, n_rows=125, n_cols=44) + ds = DSADataset(root=tmpdir) + with self.assertRaisesRegex(ValueError, "has 44 columns, expected 45"): + ds.get_subject_data("p1") + + def test_prepare_metadata_scans_standard_layout(self): + """prepare_metadata finds segments using the built-in layout rules.""" + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "a03" / "p5").mkdir(parents=True) + _write_segment(root / "a03" / "p5" / "s10.txt") + + manifest_path = root / "dsa-pyhealth.csv" + config_path = Path(__file__).parent.parent.parent / "pyhealth" / "datasets" / "configs" / "dsa.yaml" + + DSADataset(root=str(root), config_path=str(config_path)) + df = pd.read_csv(manifest_path) + self.assertEqual(len(df), 1) + row = df.iloc[0] + self.assertEqual(row["subject_id"], "p5") + self.assertEqual(row["activity_code"], "A03") + self.assertEqual(row["activity_name"], "lying_on_back") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_dsa_task.py b/tests/core/test_dsa_task.py new file mode 100644 index 000000000..67036e95a --- /dev/null +++ b/tests/core/test_dsa_task.py @@ -0,0 +1,195 @@ +"""Tests for the DSA standalone activity-classification task.""" + +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path + +import numpy as np + +from pyhealth.datasets.dsa import DSADataset +from pyhealth.tasks.dsa_activity_classification import DSAActivityClassification + + +def _write_segment_from_array(path: Path, values: np.ndarray) -> None: + """Write a synthetic DSA segment file from a 2D array.""" + lines = [] + for row in values: + lines.append(",".join(f"{value:.6f}" for value in row)) + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _write_patterned_segment( + path: Path, + n_rows: int = 125, + n_cols: int = 45, + row_scale: float = 1.0, + col_scale: float = 0.1, +) -> np.ndarray: + """Write a segment with deterministic, non-constant channel values.""" + row_offsets = np.arange(n_rows, dtype=np.float64).reshape(-1, 1) * row_scale + col_offsets = np.arange(n_cols, dtype=np.float64).reshape(1, -1) * col_scale + values = row_offsets + col_offsets + _write_segment_from_array(path, values) + return values + + +def _write_constant_segment( + path: Path, + value: float = 7.0, + n_rows: int = 125, + n_cols: int = 45, +) -> np.ndarray: + """Write a segment where every value is identical.""" + values = np.full((n_rows, n_cols), value, dtype=np.float64) + _write_segment_from_array(path, values) + return values + + +def _make_dsa_tree( + root: Path, + activities: tuple[str, ...] = ("a01",), + subjects: tuple[str, ...] = ("p1",), + segments: tuple[str, ...] = ("s01.txt",), + constant: bool = False, +) -> np.ndarray: + """Create a minimal DSA directory tree and return the last segment array.""" + last_values = np.empty((125, 45), dtype=np.float64) + for activity in activities: + for subject in subjects: + for segment in segments: + seg_dir = root / activity / subject + seg_dir.mkdir(parents=True, exist_ok=True) + seg_path = seg_dir / segment + if constant: + last_values = _write_constant_segment(seg_path) + else: + last_values = _write_patterned_segment(seg_path) + return last_values + + +class TestDSAActivityClassification(unittest.TestCase): + """Task-level coverage for DSA activity classification.""" + + def test_invalid_selected_unit_raises(self) -> None: + with self.assertRaisesRegex(ValueError, "Unsupported DSA unit"): + DSAActivityClassification(dataset_root="/tmp/dsa", selected_units=("bad",)) + + def test_invalid_normalization_raises(self) -> None: + with self.assertRaisesRegex(ValueError, "Unsupported normalization"): + DSAActivityClassification( + dataset_root="/tmp/dsa", + normalization="zscore", + ) + + def test_task_generates_expected_samples_for_all_units(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + _make_dsa_tree( + root, + activities=("a01", "a02"), + subjects=("p1",), + segments=("s01.txt", "s02.txt"), + ) + dataset = DSADataset(root=tmpdir) + patient = next(dataset.iter_patients()) + task = DSAActivityClassification(dataset_root=tmpdir) + + samples = task(patient) + + self.assertEqual(len(samples), 4) + sample = samples[0] + self.assertEqual(sample["patient_id"], "p1") + self.assertEqual(sample["subject_id"], "p1") + self.assertEqual(sample["signal"].shape, (125, 45)) + self.assertEqual(sample["num_channels"], 45) + self.assertEqual(sample["unit_combo"], "T+RA+LA+RL+LL") + self.assertIn(sample["activity_name"], {"sitting", "standing"}) + self.assertIn(sample["activity_code"], {"A01", "A02"}) + self.assertIn(sample["label"], {0, 1}) + + def test_selected_unit_extracts_expected_columns_without_normalization( + self, + ) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + raw_values = _make_dsa_tree(root) + dataset = DSADataset(root=tmpdir) + patient = next(dataset.iter_patients()) + task = DSAActivityClassification( + dataset_root=tmpdir, + selected_units=("RA",), + normalization="none", + ) + + samples = task(patient) + + self.assertEqual(len(samples), 1) + expected = raw_values[:, 9:18] + np.testing.assert_allclose(samples[0]["signal"], expected) + self.assertEqual(samples[0]["unit_combo"], "RA") + self.assertEqual(samples[0]["num_channels"], 9) + + def test_minmax_normalization_scales_each_channel_to_minus1_plus1(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + _make_dsa_tree(root) + dataset = DSADataset(root=tmpdir) + patient = next(dataset.iter_patients()) + task = DSAActivityClassification( + dataset_root=tmpdir, + selected_units=("T",), + normalization="minmax", + ) + + samples = task(patient) + signal = samples[0]["signal"] + + np.testing.assert_allclose(signal.min(axis=0), -1.0) + np.testing.assert_allclose(signal.max(axis=0), 1.0) + + def test_constant_channels_become_zeros_after_minmax_normalization(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + _make_dsa_tree(root, constant=True) + dataset = DSADataset(root=tmpdir) + patient = next(dataset.iter_patients()) + task = DSAActivityClassification( + dataset_root=tmpdir, + selected_units=("T",), + normalization="minmax", + ) + + samples = task(patient) + + np.testing.assert_allclose(samples[0]["signal"], 0.0) + + def test_set_task_builds_sample_dataset(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + _make_dsa_tree( + root, + activities=("a01", "a02"), + subjects=("p1", "p2"), + segments=("s01.txt",), + ) + dataset = DSADataset(root=tmpdir) + task = DSAActivityClassification( + dataset_root=tmpdir, + selected_units=("T", "RA"), + ) + + sample_dataset = dataset.set_task(task, num_workers=1) + + self.assertEqual(len(sample_dataset), 4) + self.assertIn("signal", sample_dataset.input_processors) + self.assertIn("label", sample_dataset.output_processors) + + sample = sample_dataset[0] + self.assertEqual(tuple(sample["signal"].shape), (125, 18)) + self.assertIn(int(sample["label"]), {0, 1}) + + +if __name__ == "__main__": + unittest.main()