From 822bf2016fbb8cad40966427b02065685c4890dd Mon Sep 17 00:00:00 2001 From: suraj kumarwq Date: Sun, 19 Apr 2026 16:15:09 -0400 Subject: [PATCH 1/2] Add DREAMTSleepClassification standalone task with ablation study Implements IBI-based sleep staging task for the DREAMT wearable dataset following the WatchSleepNet preprocessing pipeline (Wang et al., 2025). Includes 3-class/4-class label configs, optional accelerometer input, configurable epoch duration, 42 unit tests, and example ablation script. --- docs/api/tasks.rst | 1 + ...health.tasks.DREAMTSleepClassification.rst | 7 + examples/dreamt_sleep_classification.ipynb | 1542 +++++++++++++++++ examples/dreamt_sleep_classification.py | 234 +++ examples/dreamt_sleep_staging_rnn.ipynb | 390 +++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/dreamt_sleep_classification.py | 224 +++ tests/core/test_dreamt.py | 302 +++- 8 files changed, 2663 insertions(+), 38 deletions(-) create mode 100644 docs/api/tasks/pyhealth.tasks.DREAMTSleepClassification.rst create mode 100644 examples/dreamt_sleep_classification.ipynb create mode 100644 examples/dreamt_sleep_classification.py create mode 100644 examples/dreamt_sleep_staging_rnn.ipynb create mode 100644 pyhealth/tasks/dreamt_sleep_classification.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..cb0f789ea 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -220,6 +220,7 @@ Available Tasks Readmission Prediction Sleep Staging Sleep Staging (SleepEDF) + DREAMT Sleep Classification Temple University EEG Tasks Sleep Staging v2 Benchmark EHRShot diff --git a/docs/api/tasks/pyhealth.tasks.DREAMTSleepClassification.rst b/docs/api/tasks/pyhealth.tasks.DREAMTSleepClassification.rst new file mode 100644 index 000000000..4d00600de --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DREAMTSleepClassification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.DREAMTSleepClassification +========================================= + +.. 
autoclass:: pyhealth.tasks.dreamt_sleep_classification.DREAMTSleepClassification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/dreamt_sleep_classification.ipynb b/examples/dreamt_sleep_classification.ipynb new file mode 100644 index 000000000..f25a3fad5 --- /dev/null +++ b/examples/dreamt_sleep_classification.ipynb @@ -0,0 +1,1542 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f13747dd", + "metadata": {}, + "source": [ + "# DREAMT Sleep Staging — WatchSleepNet Task Ablations\n", + "\n", + "**Paper:** Wang et al., *WatchSleepNet: A Novel Model and Pretraining Approach for Advancing Sleep Staging with Smartwatches*, 2025. \n", + "https://doi.org/10.48550/arXiv.2501.17268\n", + "\n", + "**Dataset:** DREAMT (PhysioNet) — https://physionet.org/content/dreamt/\n", + "\n", + "This notebook demonstrates the `SleepStagingDREAMT` task on the DREAMT wearable dataset and includes three novel ablation studies **not present in the original paper**:\n", + "\n", + "| Ablation | Variable | Paper default | What we test |\n", + "|---|---|---|---|\n", + "| 1 — Label granularity | `num_classes` | 3 (Wake/NREM/REM) | 3-class vs 4-class (N1/N2/N3 split) |\n", + "| 2 — Accelerometer | `use_accelerometer` | False (IBI only) | IBI-only vs IBI + ACC_X/Y/Z |\n", + "| 3 — Epoch duration | `epoch_seconds` | 30 s | 15 s / 30 s / 60 s |\n", + "\n", + "**Quick start (no download required):** Set `USE_DEMO = True` below. \n", + "**Real data:** Set `USE_DEMO = False` and point `DREAMT_ROOT` at your local DREAMT 2.1.0 directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ad0a8977", + "metadata": {}, + "outputs": [], + "source": [ + "# ── Configuration ─────────────────────────────────────────────────────────────\n", + "USE_DEMO = False # True → synthetic data (no download needed)\n", + " # False → set DREAMT_ROOT to your local path\n", + "DREAMT_ROOT = \"/home/suraj/Documents/code/cs598 DLH/dreamt/2.1.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "aae93ea2", + "metadata": {}, + "outputs": [], + "source": [ + "import collections\n", + "import os\n", + "import shutil\n", + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "id": "89a72074", + "metadata": {}, + "source": [ + "## Demo mode: synthetic DREAMT directory\n", + "\n", + "When `USE_DEMO = True` we build a minimal DREAMT directory in a temp folder so the notebook is fully self-contained. Each synthetic patient has 60 epochs of 30 s (3 840 rows at 64 Hz) with stages cycling through W / N1 / N2 / N3 / R." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5d5c6de2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using root: /home/suraj/Documents/code/cs598 DLH/dreamt/2.1.0\n" + ] + } + ], + "source": [ + "_demo_tmpdir = None\n", + "\n", + "def _build_demo_root(n_patients: int = 6, n_rows: int = 3840) -> str:\n", + " \"\"\"Create a minimal synthetic DREAMT directory tree.\"\"\"\n", + " global _demo_tmpdir\n", + " _demo_tmpdir = tempfile.mkdtemp(prefix=\"dreamt_demo_\")\n", + " root = Path(_demo_tmpdir)\n", + " (root / \"data_64Hz\").mkdir()\n", + " (root / \"data_100Hz\").mkdir()\n", + "\n", + " rng = np.random.default_rng(0)\n", + " stage_cycle = (\n", + " [\"W\"] * 640 + [\"N1\"] * 640 + [\"N2\"] * 640 + [\"N3\"] * 640 + [\"R\"] * 640\n", + " ) * 2 # 5 × 640 = 3 200 rows; repeated so n_rows=3 840 is covered\n", + "\n", + " rows = []\n", + " for i in range(1, n_patients + 1):\n", + " sid = f\"S{i:03d}\"\n", + "\n", + " ibi = np.zeros(n_rows)\n", + " beat_idx = np.arange(0, n_rows, 51) # ~1 beat per 0.8 s\n", + " ibi[beat_idx] = rng.uniform(0.7, 1.1, len(beat_idx))\n", + "\n", + " df = pd.DataFrame({\n", + " \"TIMESTAMP\": np.arange(n_rows) / 64.0,\n", + " \"BVP\": rng.standard_normal(n_rows),\n", + " \"HR\": rng.integers(50, 90, n_rows).astype(float),\n", + " \"EDA\": rng.uniform(0.0, 1.0, n_rows),\n", + " \"TEMP\": rng.uniform(33.0, 37.0, n_rows),\n", + " \"ACC_X\": rng.standard_normal(n_rows),\n", + " \"ACC_Y\": rng.standard_normal(n_rows),\n", + " \"ACC_Z\": rng.standard_normal(n_rows),\n", + " \"IBI\": ibi,\n", + " \"Sleep_Stage\": stage_cycle[:n_rows],\n", + " })\n", + " df.to_csv(root / \"data_64Hz\" / f\"{sid}_whole_df.csv\", index=False)\n", + " pd.DataFrame({\"a\": [1]}).to_csv(\n", + " root / \"data_100Hz\" / f\"{sid}_PSG_df.csv\", index=False\n", + " )\n", + "\n", + " rows.append({\n", + " \"SID\": sid, \"AGE\": rng.integers(25, 65),\n", + " \"GENDER\": rng.choice([\"M\", \"F\"]), 
\"BMI\": rng.integers(18, 40),\n", + " \"OAHI\": rng.integers(0, 30), \"AHI\": rng.integers(0, 30),\n", + " \"Mean_SaO2\": f\"{rng.integers(90, 99)}%\",\n", + " \"Arousal Index\": rng.integers(5, 30),\n", + " \"MEDICAL_HISTORY\": \"None\", \"Sleep_Disorders\": \"None\",\n", + " })\n", + "\n", + " pd.DataFrame(rows).to_csv(root / \"participant_info.csv\", index=False)\n", + " print(f\"[demo] Synthetic DREAMT root: {root}\")\n", + " return str(root)\n", + "\n", + "\n", + "root = _build_demo_root() if USE_DEMO else DREAMT_ROOT\n", + "print(f\"Using root: {root}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7683a956", + "metadata": {}, + "source": [ + "## Step 1 — Load DREAMTDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1c09cafb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config provided, using default config\n", + "Initializing dreamt_sleep dataset from /home/suraj/Documents/code/cs598 DLH/dreamt/2.1.0 (dev mode: False)\n", + "No cache_dir provided. 
Using default cache dir: /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd\n", + "Found cached event dataframe: /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/global_event_df.parquet\n", + "Dataset: dreamt_sleep\n", + "Dev mode: False\n", + "Number of patients: 100\n", + "Number of events: 100\n", + "Found 100 unique patient IDs\n", + "Patients loaded: 100\n" + ] + } + ], + "source": [ + "from pyhealth.datasets import DREAMTDataset\n", + "\n", + "dreamt = DREAMTDataset(root=root)\n", + "dreamt.stats()\n", + "print(f\"Patients loaded: {len(dreamt.unique_patient_ids)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c2a242ae", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.tasks import SleepStagingDREAMT\n", + "\n", + "def _all_samples(ds):\n", + " return [ds[i] for i in range(len(ds))]\n", + "\n", + "def summarise(task_ds, name: str) -> None:\n", + " \"\"\"Print epoch count and class distribution for a task dataset.\"\"\"\n", + " all_s = _all_samples(task_ds)\n", + " n = len(all_s)\n", + " counts = dict(sorted(collections.Counter(s[\"label\"].item() for s in all_s).items()))\n", + " print(f\" [{name}]\")\n", + " print(f\" Total epochs : {n}\")\n", + " print(f\" Label dist : {counts}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "ad3511ef", + "metadata": {}, + "source": [ + "---\n", + "## Ablation 1 — Label Granularity: 3-class vs 4-class\n", + "\n", + "The paper uses **3-class** staging (Wake / NREM / REM), merging N1, N2, and N3 into a single NREM class. \n", + "We test whether separating NREM into its constituent stages (**4-class**: Wake / N1 / N2 / N3 / REM) improves clinical granularity, at the cost of a harder classification problem.\n", + "\n", + "**Hypothesis:** finer labels give a model more signal to differentiate NREM depth but may hurt overall accuracy due to inter-stage similarity." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "994408c3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task SleepStagingDREAMT for dreamt_sleep base dataset...\n", + "Task cache paths: task_df=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/task_df.ld, samples=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Setting task SleepStagingDREAMT for dreamt_sleep base dataset...\n", + "Task cache paths: task_df=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_027e375b-8dc4-5f11-8e2e-80003167e8fc/task_df.ld, samples=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_027e375b-8dc4-5f11-8e2e-80003167e8fc/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_027e375b-8dc4-5f11-8e2e-80003167e8fc/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Label granularity comparison:\n", + " [3-class Wake=0 / NREM=1 / REM=2]\n", + " Total epochs : 5601\n", + " Label dist : {0: 1267, 1: 3768, 2: 566}\n", + " [4-class Wake=0 / N1=1 / N2=2 / N3=3 / REM=4]\n", + " Total epochs : 5601\n", + " Label dist : {0: 1267, 1: 517, 2: 2788, 3: 463, 4: 566}\n", + "\n", + "Observation: both datasets share the same epoch count; 4-class spreads NREM epochs across three labels.\n" + ] + } + ], + "source": [ + "task_3cls = 
SleepStagingDREAMT(num_classes=3)\n", + "task_4cls = SleepStagingDREAMT(num_classes=4)\n", + "\n", + "ds_3cls = dreamt.set_task(task_3cls)\n", + "ds_4cls = dreamt.set_task(task_4cls)\n", + "\n", + "print(\"Label granularity comparison:\")\n", + "summarise(ds_3cls, \"3-class Wake=0 / NREM=1 / REM=2\")\n", + "summarise(ds_4cls, \"4-class Wake=0 / N1=1 / N2=2 / N3=3 / REM=4\")\n", + "\n", + "print(\n", + " \"\\nObservation: both datasets share the same epoch count; \"\n", + " \"4-class spreads NREM epochs across three labels.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fda7b3e9", + "metadata": {}, + "source": [ + "---\n", + "## Ablation 2 — Accelerometer Augmentation: IBI-only vs IBI + ACC\n", + "\n", + "The paper uses only **IBI** (Inter-Beat Interval) as the model input. \n", + "We test whether adding raw wrist **accelerometer** signals (ACC_X / ACC_Y / ACC_Z) improves **Wake detection**, since physical movement is a strong wakefulness indicator.\n", + "\n", + "**Hypothesis:** ACC data captures motion patterns invisible to cardiac signals, boosting Wake F1 without hurting NREM/REM classification." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "fbf6ddc5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task SleepStagingDREAMT for dreamt_sleep base dataset...\n", + "Task cache paths: task_df=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/task_df.ld, samples=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Setting task SleepStagingDREAMT for dreamt_sleep base dataset...\n", + "Task cache paths: task_df=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_2c6c084c-4940-5c1e-b12d-a7e2f5a3279e/task_df.ld, samples=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_2c6c084c-4940-5c1e-b12d-a7e2f5a3279e/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_2c6c084c-4940-5c1e-b12d-a7e2f5a3279e/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "Accelerometer augmentation comparison:\n", + " [IBI-only input keys: ibi_sequence]\n", + " Total epochs : 5601\n", + " Label dist : {0: 1267, 1: 3768, 2: 566}\n", + " [IBI + ACC input keys: ibi_sequence, accelerometer]\n", + " Total epochs : 5601\n", + " Label dist : {0: 1267, 1: 3768, 2: 566}\n", + "ACC tensor shape per epoch: torch.Size([1920, 3]) (rows × 3 axes)\n", + "To train: replace feature_keys=['ibi_sequence'] with ['ibi_sequence', 'accelerometer'] and compare Wake F1.\n" + ] 
+ } + ], + "source": [ + "task_ibi_only = SleepStagingDREAMT(num_classes=3, use_accelerometer=False)\n", + "task_ibi_acc = SleepStagingDREAMT(num_classes=3, use_accelerometer=True)\n", + "\n", + "ds_ibi_only = dreamt.set_task(task_ibi_only)\n", + "ds_ibi_acc = dreamt.set_task(task_ibi_acc)\n", + "\n", + "print(\"Accelerometer augmentation comparison:\")\n", + "summarise(ds_ibi_only, \"IBI-only input keys: ibi_sequence\")\n", + "summarise(ds_ibi_acc, \"IBI + ACC input keys: ibi_sequence, accelerometer\")\n", + "\n", + "acc_samples = _all_samples(ds_ibi_acc)\n", + "if acc_samples:\n", + " acc_shape = acc_samples[0][\"accelerometer\"].shape\n", + " print(f\"ACC tensor shape per epoch: {acc_shape} (rows × 3 axes)\")\n", + "\n", + "print(\n", + " \"To train: replace feature_keys=['ibi_sequence'] with \"\n", + " \"['ibi_sequence', 'accelerometer'] and compare Wake F1.\"\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "edaa4b40", + "metadata": {}, + "source": [ + "---\n", + "## Ablation 3 — Epoch Duration: 15 s / 30 s / 60 s\n", + "\n", + "The paper fixes each epoch at **30 seconds** (the PSG standard). \n", + "We test shorter (15 s) and longer (60 s) windows to explore the tradeoff between temporal resolution and per-epoch IBI context.\n", + "\n", + "**Hypothesis:** shorter windows increase epoch count and temporal resolution but give the model fewer heartbeats per sample; longer windows provide richer IBI context but may blur stage transitions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "64e6f03e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch (s) Total epochs Avg IBI vals/epoch \n", + "---------------------------------------------\n", + "Setting task SleepStagingDREAMT for dreamt_sleep base dataset...\n", + "Task cache paths: task_df=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_45108b4a-f7d1-5417-bb2c-46e90a4a73fb/task_df.ld, samples=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_45108b4a-f7d1-5417-bb2c-46e90a4a73fb/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_45108b4a-f7d1-5417-bb2c-46e90a4a73fb/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "15 11206 960.0 \n", + "Setting task SleepStagingDREAMT for dreamt_sleep base dataset...\n", + "Task cache paths: task_df=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/task_df.ld, samples=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_adb2c321-b696-5572-ab41-937abd1edbf4/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "30 5601 1920.0 ← paper default\n", + "Setting task SleepStagingDREAMT for dreamt_sleep base dataset...\n", + "Task cache paths: task_df=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_0d3f72cf-7651-5429-ba50-03a68a245d95/task_df.ld, 
samples=/home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_0d3f72cf-7651-5429-ba50-03a68a245d95/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached processed samples at /home/suraj/.cache/pyhealth/710f5f97-809b-59db-9c21-edc3a3612dbd/tasks/SleepStagingDREAMT_0d3f72cf-7651-5429-ba50-03a68a245d95/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.\n", + "60 2799 3840.0 \n", + "Observation: halving epoch duration doubles epoch count but halves the average IBI count per window.\n" + ] + } + ], + "source": [ + "print(f\"{'Epoch (s)':<10} {'Total epochs':<15} {'Avg IBI vals/epoch':<20}\")\n", + "print(\"-\" * 45)\n", + "\n", + "for epoch_secs in (15, 30, 60):\n", + " task_ep = SleepStagingDREAMT(epoch_seconds=epoch_secs, num_classes=3)\n", + " ds_ep = dreamt.set_task(task_ep)\n", + " ep_samples = _all_samples(ds_ep)\n", + " n = len(ep_samples)\n", + " avg_ibi = (\n", + " np.mean([len(s[\"ibi_sequence\"]) for s in ep_samples])\n", + " if ep_samples else 0.0\n", + " )\n", + " paper_marker = \" ← paper default\" if epoch_secs == 30 else \"\"\n", + " print(f\"{epoch_secs:<10} {n:<15} {avg_ibi:<20.1f}{paper_marker}\")\n", + "\n", + "print(\n", + " \"Observation: halving epoch duration doubles epoch count \"\n", + " \"but halves the average IBI count per window.\"\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "8927fa9d", + "metadata": {}, + "source": [ + "---\n", + "## Step 2 — Train a lightweight RNN on the 3-class task\n", + "\n", + "We use PyHealth's built-in **RNN** model as a stand-in for the WatchSleepNet encoder,\n", + "applied to the variable-length IBI sequence of each epoch. \n", + "This validates the full data → task → model → evaluation pipeline end-to-end." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5860de3d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Split -- train: 3124 val: 842 test: 1635 epochs\n", + "RNN(\n", + " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(\n", + " (ibi_sequence): Linear(in_features=1920, out_features=128, bias=True)\n", + " ))\n", + " (rnn): ModuleDict(\n", + " (ibi_sequence): RNNLayer(\n", + " (dropout_layer): Dropout(p=0.5, inplace=False)\n", + " (rnn): GRU(128, 128, batch_first=True)\n", + " )\n", + " )\n", + " (fc): Linear(in_features=128, out_features=3, bias=True)\n", + ")\n", + "Metrics: None\n", + "Device: cpu\n", + "\n", + "Training:\n", + "Batch size: 32\n", + "Optimizer: \n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: accuracy\n", + "Monitor criterion: max\n", + "Epochs: 20\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0ba023ddc7e4494e8bf94fc3ae2f9dc4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 20: 0%| | 0/98 [00:00 str: + """Create a minimal synthetic DREAMT directory tree for demo mode.""" + tmp = tempfile.mkdtemp(prefix="dreamt_demo_") + root = Path(tmp) + (root / "data_64Hz").mkdir() + (root / "data_100Hz").mkdir() + + rng = np.random.default_rng(0) + stage_cycle = ( + ["W"] * 640 + ["N1"] * 640 + ["N2"] * 640 + ["N3"] * 640 + ["R"] * 640 + ) * 2 + + rows = [] + for i in range(1, n_patients + 1): + sid = f"S{i:03d}" + ibi = np.zeros(n_rows) + beat_idx = np.arange(0, n_rows, 51) + ibi[beat_idx] = rng.uniform(0.7, 1.1, len(beat_idx)) + + df = pd.DataFrame({ + "TIMESTAMP": np.arange(n_rows) / 64.0, + "BVP": rng.standard_normal(n_rows), + "HR": rng.integers(50, 90, n_rows).astype(float), + "EDA": rng.uniform(0.0, 1.0, n_rows), + "TEMP": rng.uniform(33.0, 37.0, n_rows), + "ACC_X": 
rng.standard_normal(n_rows), + "ACC_Y": rng.standard_normal(n_rows), + "ACC_Z": rng.standard_normal(n_rows), + "IBI": ibi, + "Sleep_Stage": stage_cycle[:n_rows], + }) + df.to_csv(root / "data_64Hz" / f"{sid}_whole_df.csv", index=False) + pd.DataFrame({"a": [1]}).to_csv( + root / "data_100Hz" / f"{sid}_PSG_df.csv", index=False + ) + rows.append({ + "SID": sid, "AGE": int(rng.integers(25, 65)), + "GENDER": rng.choice(["M", "F"]), "BMI": int(rng.integers(18, 40)), + "OAHI": int(rng.integers(0, 30)), "AHI": int(rng.integers(0, 30)), + "Mean_SaO2": f"{int(rng.integers(90, 99))}%", + "Arousal Index": int(rng.integers(5, 30)), + "MEDICAL_HISTORY": "None", "Sleep_Disorders": "None", + }) + + pd.DataFrame(rows).to_csv(root / "participant_info.csv", index=False) + print(f"[demo] Synthetic DREAMT root: {root}") + return tmp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _all_samples(ds): + return [ds[i] for i in range(len(ds))] + + +def summarise(ds, name: str) -> None: + all_s = _all_samples(ds) + n = len(all_s) + counts = dict(sorted( + collections.Counter(s["label"].item() for s in all_s).items() + )) + print(f" [{name}]") + print(f" Total epochs : {n}") + print(f" Label dist : {counts}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(root: str) -> None: + from pyhealth.datasets import DREAMTDataset + from pyhealth.tasks import DREAMTSleepClassification + + # Step 1 — Load dataset + print("\n=== Step 1 — Load DREAMTDataset ===") + dreamt = DREAMTDataset(root=root) + dreamt.stats() + print(f"Patients loaded: {len(dreamt.unique_patient_ids)}") + + # Ablation 1 — Label granularity + print("\n=== Ablation 1 — Label Granularity ===") + ds_3cls = dreamt.set_task(DREAMTSleepClassification(num_classes=3)) + ds_4cls = 
dreamt.set_task(DREAMTSleepClassification(num_classes=4)) + print("Label granularity comparison:") + summarise(ds_3cls, "3-class Wake=0 / NREM=1 / REM=2") + summarise(ds_4cls, "4-class Wake=0 / N1=1 / N2=2 / N3=3 / REM=4") + print( + "\nObservation: both datasets share the same epoch count; " + "4-class spreads NREM epochs across three labels." + ) + + # Ablation 2 — Accelerometer augmentation + print("\n=== Ablation 2 — Accelerometer Augmentation ===") + ds_ibi_only = dreamt.set_task(DREAMTSleepClassification(num_classes=3, use_accelerometer=False)) + ds_ibi_acc = dreamt.set_task(DREAMTSleepClassification(num_classes=3, use_accelerometer=True)) + print("Accelerometer augmentation comparison:") + summarise(ds_ibi_only, "IBI-only input keys: ibi_sequence") + summarise(ds_ibi_acc, "IBI + ACC input keys: ibi_sequence, accelerometer") + acc_samples = _all_samples(ds_ibi_acc) + if acc_samples: + print(f"\nACC tensor shape per epoch: {acc_samples[0]['accelerometer'].shape} (rows x 3 axes)") + print( + "\nTo train with ACC: replace feature_keys=['ibi_sequence'] with " + "['ibi_sequence', 'accelerometer'] and compare Wake F1." + ) + + # Ablation 3 — Epoch duration + print("\n=== Ablation 3 — Epoch Duration ===") + print(f"{'Epoch (s)':<10} {'Total epochs':<15} {'Avg IBI vals/epoch':<20}") + print("-" * 45) + for epoch_secs in (15, 30, 60): + ds_ep = dreamt.set_task(DREAMTSleepClassification(epoch_seconds=epoch_secs, num_classes=3)) + ep_samples = _all_samples(ds_ep) + n = len(ep_samples) + avg_ibi = ( + np.mean([len(s["ibi_sequence"]) for s in ep_samples]) + if ep_samples else 0.0 + ) + marker = " <- paper default" if epoch_secs == 30 else "" + print(f"{epoch_secs:<10} {n:<15} {avg_ibi:<20.1f}{marker}") + print( + "\nObservation: halving epoch duration doubles epoch count " + "but halves the average IBI count per window." 
+ ) + + # Step 2 — Train RNN + print("\n=== Step 2 — Train RNN on 3-class task ===") + from pyhealth.datasets import get_dataloader, split_by_patient + from pyhealth.models import RNN + from pyhealth.trainer import Trainer + from sklearn.metrics import cohen_kappa_score, f1_score + import torch + + train_ds, val_ds, test_ds = split_by_patient(ds_3cls, [0.7, 0.15, 0.15]) + train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + print(f"Split -- train: {len(train_ds)} val: {len(val_ds)} test: {len(test_ds)} epochs") + + model = RNN(dataset=ds_3cls) + trainer = Trainer(model=model, device="cpu") + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=3, + monitor="accuracy", + ) + + results = trainer.evaluate(test_loader) + print(f"\nTest results: {results}") + + # Per-class metrics + all_preds, all_labels = [], [] + model.eval() + with torch.no_grad(): + for batch in test_loader: + output = model(**batch) + preds = output["y_prob"].argmax(dim=1).cpu().numpy() + labels = batch["label"].cpu().numpy() + all_preds.extend(preds) + all_labels.extend(labels) + + per_class_f1 = f1_score(all_labels, all_preds, average=None, labels=[0, 1, 2]) + kappa = cohen_kappa_score(all_labels, all_preds) + acc = sum(p == l for p, l in zip(all_preds, all_labels)) / len(all_labels) + + print(f"\nAccuracy : {acc:.4f}") + print(f"Wake F1 : {per_class_f1[0]:.4f}") + print(f"NREM F1 : {per_class_f1[1]:.4f}") + print(f"REM F1 : {per_class_f1[2]:.4f}") + print(f"Cohen Kappa : {kappa:.4f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="DREAMT sleep staging ablations") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--demo", action="store_true", help="Run with synthetic data") + group.add_argument("--root", type=str, help="Path to local DREAMT 2.1.0 
directory") + args = parser.parse_args() + + tmpdir = None + if args.demo: + tmpdir = _build_demo_root() + root = tmpdir + else: + root = args.root + + try: + main(root) + finally: + if tmpdir: + shutil.rmtree(tmpdir) + print(f"\n[demo] Cleaned up {tmpdir}") diff --git a/examples/dreamt_sleep_staging_rnn.ipynb b/examples/dreamt_sleep_staging_rnn.ipynb new file mode 100644 index 000000000..85b3ee9f4 --- /dev/null +++ b/examples/dreamt_sleep_staging_rnn.ipynb @@ -0,0 +1,390 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ee5de951", + "metadata": {}, + "source": [ + "# DREAMT Sleep Staging — WatchSleepNet Task Ablations\n", + "\n", + "**Paper:** Wang et al., *WatchSleepNet: A Novel Model and Pretraining Approach for Advancing Sleep Staging with Smartwatches*, 2025. \n", + "https://doi.org/10.48550/arXiv.2501.17268\n", + "\n", + "**Dataset:** DREAMT (PhysioNet) — https://physionet.org/content/dreamt/\n", + "\n", + "This notebook demonstrates the `SleepStagingDREAMT` task on the DREAMT wearable dataset and includes three novel ablation studies **not present in the original paper**:\n", + "\n", + "| Ablation | Variable | Paper default | What we test |\n", + "|---|---|---|---|\n", + "| 1 — Label granularity | `num_classes` | 3 (Wake/NREM/REM) | 3-class vs 4-class (N1/N2/N3 split) |\n", + "| 2 — Accelerometer | `use_accelerometer` | False (IBI only) | IBI-only vs IBI + ACC_X/Y/Z |\n", + "| 3 — Epoch duration | `epoch_seconds` | 30 s | 15 s / 30 s / 60 s |\n", + "\n", + "**Quick start (no download required):** Set `USE_DEMO = True` below. \n", + "**Real data:** Set `USE_DEMO = False` and point `DREAMT_ROOT` at your local DREAMT 2.1.0 directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3388d50", + "metadata": {}, + "outputs": [], + "source": [ + "# ── Configuration ─────────────────────────────────────────────────────────────\n", + "USE_DEMO = True # True → synthetic data (no download needed)\n", + " # False → set DREAMT_ROOT to your local path\n", + "DREAMT_ROOT = \"/path/to/dreamt/2.1.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9194eb74", + "metadata": {}, + "outputs": [], + "source": [ + "import collections\n", + "import os\n", + "import shutil\n", + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "id": "3f6e53f6", + "metadata": {}, + "source": [ + "## Demo mode: synthetic DREAMT directory\n", + "\n", + "When `USE_DEMO = True` we build a minimal DREAMT directory in a temp folder so the notebook is fully self-contained. Each synthetic patient has 60 epochs of 30 s (3 840 rows at 64 Hz) with stages cycling through W / N1 / N2 / N3 / R." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90dba6a5", + "metadata": {}, + "outputs": [], + "source": [ + "_demo_tmpdir = None\n", + "\n", + "def _build_demo_root(n_patients: int = 6, n_rows: int = 3840) -> str:\n", + " \"\"\"Create a minimal synthetic DREAMT directory tree.\"\"\"\n", + " global _demo_tmpdir\n", + " _demo_tmpdir = tempfile.mkdtemp(prefix=\"dreamt_demo_\")\n", + " root = Path(_demo_tmpdir)\n", + " (root / \"data_64Hz\").mkdir()\n", + " (root / \"data_100Hz\").mkdir()\n", + "\n", + " rng = np.random.default_rng(0)\n", + " stage_cycle = (\n", + " [\"W\"] * 640 + [\"N1\"] * 640 + [\"N2\"] * 640 + [\"N3\"] * 640 + [\"R\"] * 640\n", + " ) * 2 # 5 × 640 = 3 200 rows; repeated so n_rows=3 840 is covered\n", + "\n", + " rows = []\n", + " for i in range(1, n_patients + 1):\n", + " sid = f\"S{i:03d}\"\n", + "\n", + " ibi = np.zeros(n_rows)\n", + " beat_idx = np.arange(0, n_rows, 51) # ~1 beat per 0.8 s\n", + " ibi[beat_idx] = rng.uniform(0.7, 1.1, len(beat_idx))\n", + "\n", + " df = pd.DataFrame({\n", + " \"TIMESTAMP\": np.arange(n_rows) / 64.0,\n", + " \"BVP\": rng.standard_normal(n_rows),\n", + " \"HR\": rng.integers(50, 90, n_rows).astype(float),\n", + " \"EDA\": rng.uniform(0.0, 1.0, n_rows),\n", + " \"TEMP\": rng.uniform(33.0, 37.0, n_rows),\n", + " \"ACC_X\": rng.standard_normal(n_rows),\n", + " \"ACC_Y\": rng.standard_normal(n_rows),\n", + " \"ACC_Z\": rng.standard_normal(n_rows),\n", + " \"IBI\": ibi,\n", + " \"Sleep_Stage\": stage_cycle[:n_rows],\n", + " })\n", + " df.to_csv(root / \"data_64Hz\" / f\"{sid}_whole_df.csv\", index=False)\n", + " pd.DataFrame({\"a\": [1]}).to_csv(\n", + " root / \"data_100Hz\" / f\"{sid}_PSG_df.csv\", index=False\n", + " )\n", + "\n", + " rows.append({\n", + " \"SID\": sid, \"AGE\": rng.integers(25, 65),\n", + " \"GENDER\": rng.choice([\"M\", \"F\"]), \"BMI\": rng.integers(18, 40),\n", + " \"OAHI\": rng.integers(0, 30), \"AHI\": rng.integers(0, 30),\n", + " \"Mean_SaO2\": f\"{rng.integers(90, 
99)}%\",\n", + " \"Arousal Index\": rng.integers(5, 30),\n", + " \"MEDICAL_HISTORY\": \"None\", \"Sleep_Disorders\": \"None\",\n", + " })\n", + "\n", + " pd.DataFrame(rows).to_csv(root / \"participant_info.csv\", index=False)\n", + " print(f\"[demo] Synthetic DREAMT root: {root}\")\n", + " return str(root)\n", + "\n", + "\n", + "root = _build_demo_root() if USE_DEMO else DREAMT_ROOT\n", + "print(f\"Using root: {root}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cbfec1d1", + "metadata": {}, + "source": [ + "## Step 1 — Load DREAMTDataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b17010f8", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import DREAMTDataset\n", + "\n", + "dreamt = DREAMTDataset(root=root)\n", + "dreamt.stats()\n", + "print(f\"Patients loaded: {len(dreamt.unique_patient_ids)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3509c4f7", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.tasks import SleepStagingDREAMT\n", + "\n", + "def summarise(task_ds, name: str) -> None:\n", + " \"\"\"Print epoch count and class distribution for a task dataset.\"\"\"\n", + " n = len(task_ds.samples)\n", + " counts = dict(sorted(collections.Counter(s[\"label\"] for s in task_ds.samples).items()))\n", + " print(f\" [{name}]\")\n", + " print(f\" Total epochs : {n}\")\n", + " print(f\" Label dist : {counts}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3735dcec", + "metadata": {}, + "source": [ + "---\n", + "## Ablation 1 — Label Granularity: 3-class vs 4-class\n", + "\n", + "The paper uses **3-class** staging (Wake / NREM / REM), merging N1, N2, and N3 into a single NREM class. 
\n", + "We test whether separating NREM into its constituent stages (**4-class**: Wake / N1 / N2 / N3 / REM) improves clinical granularity, at the cost of a harder classification problem.\n", + "\n", + "**Hypothesis:** finer labels give a model more signal to differentiate NREM depth but may hurt overall accuracy due to inter-stage similarity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f50df65d", + "metadata": {}, + "outputs": [], + "source": [ + "task_3cls = SleepStagingDREAMT(num_classes=3)\n", + "task_4cls = SleepStagingDREAMT(num_classes=4)\n", + "\n", + "ds_3cls = dreamt.set_task(task_3cls)\n", + "ds_4cls = dreamt.set_task(task_4cls)\n", + "\n", + "print(\"Label granularity comparison:\")\n", + "summarise(ds_3cls, \"3-class Wake=0 / NREM=1 / REM=2\")\n", + "summarise(ds_4cls, \"4-class Wake=0 / N1=1 / N2=2 / N3=3 / REM=4\")\n", + "\n", + "print(\n", + " \"\\nObservation: both datasets share the same epoch count; \"\n", + " \"4-class spreads NREM epochs across three labels.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d8fbd49d", + "metadata": {}, + "source": [ + "---\n", + "## Ablation 2 — Accelerometer Augmentation: IBI-only vs IBI + ACC\n", + "\n", + "The paper uses only **IBI** (Inter-Beat Interval) as the model input. \n", + "We test whether adding raw wrist **accelerometer** signals (ACC_X / ACC_Y / ACC_Z) improves **Wake detection**, since physical movement is a strong wakefulness indicator.\n", + "\n", + "**Hypothesis:** ACC data captures motion patterns invisible to cardiac signals, boosting Wake F1 without hurting NREM/REM classification." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53abea70", + "metadata": {}, + "outputs": [], + "source": [ + "task_ibi_only = SleepStagingDREAMT(num_classes=3, use_accelerometer=False)\n", + "task_ibi_acc = SleepStagingDREAMT(num_classes=3, use_accelerometer=True)\n", + "\n", + "ds_ibi_only = dreamt.set_task(task_ibi_only)\n", + "ds_ibi_acc = dreamt.set_task(task_ibi_acc)\n", + "\n", + "print(\"Accelerometer augmentation comparison:\")\n", + "summarise(ds_ibi_only, \"IBI-only input keys: ibi_sequence\")\n", + "summarise(ds_ibi_acc, \"IBI + ACC input keys: ibi_sequence, accelerometer\")\n", + "\n", + "if ds_ibi_acc.samples:\n", + " acc_shape = ds_ibi_acc.samples[0][\"accelerometer\"].shape\n", + " print(f\"\\nACC tensor shape per epoch: {acc_shape} (rows × 3 axes)\")\n", + "\n", + "print(\n", + " \"\\nTo train: replace feature_keys=['ibi_sequence'] with \"\n", + " \"['ibi_sequence', 'accelerometer'] and compare Wake F1.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "445ae85d", + "metadata": {}, + "source": [ + "---\n", + "## Ablation 3 — Epoch Duration: 15 s / 30 s / 60 s\n", + "\n", + "The paper fixes each epoch at **30 seconds** (the PSG standard). \n", + "We test shorter (15 s) and longer (60 s) windows to explore the tradeoff between temporal resolution and per-epoch IBI context.\n", + "\n", + "**Hypothesis:** shorter windows increase epoch count and temporal resolution but give the model fewer heartbeats per sample; longer windows provide richer IBI context but may blur stage transitions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c02045b", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{'Epoch (s)':<10} {'Total epochs':<15} {'Avg IBI vals/epoch':<20}\")\n", + "print(\"-\" * 45)\n", + "\n", + "for epoch_secs in (15, 30, 60):\n", + " task_ep = SleepStagingDREAMT(epoch_seconds=epoch_secs, num_classes=3)\n", + " ds_ep = dreamt.set_task(task_ep)\n", + " n = len(ds_ep.samples)\n", + " avg_ibi = (\n", + " np.mean([len(s[\"ibi_sequence\"]) for s in ds_ep.samples])\n", + " if ds_ep.samples else 0.0\n", + " )\n", + " paper_marker = \" ← paper default\" if epoch_secs == 30 else \"\"\n", + " print(f\"{epoch_secs:<10} {n:<15} {avg_ibi:<20.1f}{paper_marker}\")\n", + "\n", + "print(\n", + " \"\\nObservation: halving epoch duration doubles epoch count \"\n", + " \"but halves the average IBI count per window.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2e0c836b", + "metadata": {}, + "source": [ + "---\n", + "## Step 2 — Train a lightweight RNN on the 3-class task\n", + "\n", + "We use PyHealth's built-in **RNN** model as a stand-in for the WatchSleepNet encoder,\n", + "applied to the variable-length IBI sequence of each epoch. \n", + "This validates the full data → task → model → evaluation pipeline end-to-end." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "340666f7", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from pyhealth.datasets import get_dataloader, split_by_patient\n", + " from pyhealth.models import RNN\n", + " from pyhealth.trainer import Trainer\n", + "\n", + " train_ds, val_ds, test_ds = split_by_patient(ds_3cls, [0.7, 0.15, 0.15])\n", + " train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)\n", + " val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)\n", + " test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n", + "\n", + " print(f\"Split — train: {len(train_ds)} val: {len(val_ds)} test: {len(test_ds)} epochs\")\n", + "\n", + " model = RNN(\n", + " dataset=ds_3cls,\n", + " feature_keys=[\"ibi_sequence\"],\n", + " label_key=\"label\",\n", + " mode=\"multiclass\",\n", + " )\n", + "\n", + " trainer = Trainer(model=model, device=\"cpu\")\n", + " trainer.train(\n", + " train_dataloader=train_loader,\n", + " val_dataloader=val_loader,\n", + " epochs=3,\n", + " monitor=\"accuracy\",\n", + " )\n", + "\n", + " results = trainer.evaluate(test_loader)\n", + " print(f\"\\nTest results: {results}\")\n", + "\n", + "except Exception as exc:\n", + " print(f\"[skipped] Model training requires additional dependencies: {exc}\")" + ] + }, + { + "cell_type": "markdown", + "id": "40f168c0", + "metadata": {}, + "source": [ + "---\n", + "## Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a386d93", + "metadata": {}, + "outputs": [], + "source": [ + "if _demo_tmpdir and os.path.isdir(_demo_tmpdir):\n", + " shutil.rmtree(_demo_tmpdir)\n", + " print(f\"[demo] Cleaned up {_demo_tmpdir}\")\n", + "\n", + "print(\"Done.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": 
import logging
from typing import Any, Dict, List

import numpy as np
import pandas as pd

from pyhealth.tasks import BaseTask

logger = logging.getLogger(__name__)

# 3-class: Wake / NREM (N1+N2+N3 merged) / REM.
_LABEL_MAP_3CLASS: Dict[str, int] = {
    "W": 0,
    "N1": 1,
    "N2": 1,
    "N3": 1,
    "R": 2,
}

# "4-class" config: Wake / N1 / N2 / N3 / REM.  NOTE: this produces FIVE
# distinct label values (0-4); the "4" in the name refers to splitting the
# single NREM class into N1/N2/N3.  Kept as-is for backward compatibility
# with existing callers and tests.
_LABEL_MAP_4CLASS: Dict[str, int] = {
    "W": 0,
    "N1": 1,
    "N2": 2,
    "N3": 3,
    "R": 4,
}


class DREAMTSleepClassification(BaseTask):
    """IBI-based sleep staging task for the DREAMT wearable dataset.

    Implements the preprocessing pipeline from WatchSleepNet (Wang et al.,
    2025): each patient's full-night 64 Hz recording is segmented into
    non-overlapping ``epoch_seconds``-second windows.  For each window, the
    non-zero Inter-Beat Interval (IBI) values form the input signal and a
    majority-vote sleep stage is the label.

    Label configurations:

    - ``num_classes=3`` (default, matches the paper): Wake (0), NREM (1),
      REM (2).  N1, N2, and N3 are merged into a single NREM class.
    - ``num_classes=4`` (ablation extension): Wake (0), N1 (1), N2 (2),
      N3 (3), REM (4).  Despite the name, this emits five distinct label
      values — see ``_LABEL_MAP_4CLASS``.

    Setting ``use_accelerometer=True`` additionally attaches the raw
    ACC_X/ACC_Y/ACC_Z time series of each window to the sample (ablation
    for wake detection).

    Attributes:
        task_name (str): ``"DREAMTSleepClassification"``.
        input_schema (Dict[str, str]): ``{"ibi_sequence": "tensor"}`` by
            default; an ``"accelerometer": "tensor"`` entry is added when
            ``use_accelerometer=True``.
        output_schema (Dict[str, str]): ``{"label": "multiclass"}``.

    References:
        Wang et al., "WatchSleepNet: A Novel Model and Pretraining Approach
        for Advancing Sleep Staging with Smartwatches", 2025.
        https://doi.org/10.48550/arXiv.2501.17268
    """

    task_name: str = "DREAMTSleepClassification"
    output_schema: Dict[str, str] = {"label": "multiclass"}

    def __init__(
        self,
        epoch_seconds: int = 30,
        num_classes: int = 3,
        use_accelerometer: bool = False,
        sample_rate: int = 64,
    ) -> None:
        """Initializes the DREAMTSleepClassification task.

        Args:
            epoch_seconds: Duration of each epoch window in seconds. Must be
                positive. Default 30 (the PSG standard).
            num_classes: Number of sleep stage classes. Use 3 for
                Wake/NREM/REM (paper default) or 4 for Wake/N1/N2/N3/REM
                (ablation; emits five label values 0-4). Default 3.
            use_accelerometer: If True, include raw ACC_X/ACC_Y/ACC_Z signals
                alongside IBI in each sample. Default False.
            sample_rate: Sampling rate (Hz) of the DREAMT data file used for
                epoch windowing. Must be positive. Default 64.

        Raises:
            ValueError: If ``num_classes`` is not 3 or 4, or if
                ``epoch_seconds`` or ``sample_rate`` is not positive.
        """
        if num_classes not in (3, 4):
            raise ValueError(
                f"num_classes must be 3 (Wake/NREM/REM) or 4 (Wake/N1/N2/N3/REM),"
                f" got {num_classes}."
            )
        if epoch_seconds <= 0:
            raise ValueError(f"epoch_seconds must be positive, got {epoch_seconds}.")
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}.")
        self.epoch_seconds = epoch_seconds
        self.num_classes = num_classes
        self.use_accelerometer = use_accelerometer
        self.sample_rate = sample_rate
        self.label_map: Dict[str, int] = (
            _LABEL_MAP_3CLASS if num_classes == 3 else _LABEL_MAP_4CLASS
        )
        self.input_schema: Dict[str, str] = {"ibi_sequence": "tensor"}
        if use_accelerometer:
            self.input_schema["accelerometer"] = "tensor"
        super().__init__()

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Segments a patient's full-night recording into labeled IBI epochs.

        Reads each of the patient's 64 Hz DREAMT CSV files, creates
        non-overlapping ``epoch_seconds``-second windows, and for each
        window extracts the IBI sequence and majority-vote sleep stage.

        Args:
            patient: A patient object returned by ``DREAMTDataset``. Must
                expose a ``get_events(event_type="dreamt_sleep")`` method
                whose event objects carry a ``file_64hz`` attribute.

        Returns:
            A list of sample dicts, each containing:

            - ``patient_id`` (str): Patient identifier (e.g., ``"S001"``).
            - ``epoch_idx`` (int): Zero-based window index within its file.
            - ``ibi_sequence`` (np.ndarray, float32): IBI values (seconds)
              in the window; length varies with heart rate.
            - ``label`` (int): Sleep stage class index.
            - ``accelerometer`` (np.ndarray, float32, shape
              ``(samples_per_epoch, 3)``): Raw ACC_X/Y/Z signals. Only
              present when ``use_accelerometer=True``.

        Unreadable files, files missing required columns, windows with no
        detected heartbeat, and windows whose majority stage is unknown are
        skipped (with a warning where appropriate) rather than raising.
        """
        samples: List[Dict[str, Any]] = []
        for event in patient.get_events(event_type="dreamt_sleep"):
            df = self._load_recording(event, patient.patient_id)
            if df is None:
                continue
            samples.extend(self._segment(df, patient.patient_id))
        return samples

    def _load_recording(self, event: Any, patient_id: str):
        """Read and validate one 64 Hz CSV; return a DataFrame or None.

        Returns None (after logging a warning) when the event has no file,
        the file cannot be read, or required columns are absent.
        """
        if event.file_64hz is None:
            logger.warning(
                "Patient %s has no 64 Hz file; skipping.", patient_id
            )
            return None

        try:
            df = pd.read_csv(str(event.file_64hz))
        except Exception as exc:  # noqa: BLE001 — any I/O/parse failure skips the file
            logger.warning(
                "Could not read %s for patient %s: %s",
                event.file_64hz,
                patient_id,
                exc,
            )
            return None

        if not {"IBI", "Sleep_Stage"}.issubset(df.columns):
            logger.warning(
                "Patient %s: required columns missing from %s; skipping.",
                patient_id,
                event.file_64hz,
            )
            return None

        # Column presence is a per-file property, so check it once here
        # instead of once per epoch (the original warned per window).
        if self.use_accelerometer and not {"ACC_X", "ACC_Y", "ACC_Z"}.issubset(
            df.columns
        ):
            logger.warning(
                "Patient %s: accelerometer columns missing from %s; skipping.",
                patient_id,
                event.file_64hz,
            )
            return None

        return df

    def _segment(self, df: pd.DataFrame, patient_id: str) -> List[Dict[str, Any]]:
        """Window one validated recording into labeled epoch samples."""
        samples_per_epoch = self.sample_rate * self.epoch_seconds
        samples: List[Dict[str, Any]] = []

        n_epochs = len(df) // samples_per_epoch  # trailing partial window dropped
        for epoch_idx in range(n_epochs):
            start = epoch_idx * samples_per_epoch
            epoch_df = df.iloc[start : start + samples_per_epoch]

            # Majority-vote label across the window's rows.
            stage_counts = epoch_df["Sleep_Stage"].value_counts()
            if stage_counts.empty:
                continue
            majority_stage = stage_counts.index[0]
            if majority_stage not in self.label_map:
                continue  # unknown/artifact annotation — skip the window

            # IBI is non-zero only on rows where a heartbeat was detected.
            ibi_mask = epoch_df["IBI"].notna() & (epoch_df["IBI"] > 0)
            ibi_values = epoch_df.loc[ibi_mask, "IBI"].to_numpy(dtype=np.float32)
            if ibi_values.size == 0:
                continue  # no usable cardiac signal in this window

            sample: Dict[str, Any] = {
                "patient_id": patient_id,
                "epoch_idx": epoch_idx,
                "ibi_sequence": ibi_values,
                "label": self.label_map[majority_stage],
            }
            if self.use_accelerometer:
                sample["accelerometer"] = epoch_df[
                    ["ACC_X", "ACC_Y", "ACC_Z"]
                ].to_numpy(dtype=np.float32)
            samples.append(sample)

        return samples


# Backward-compatible alias: the example notebook
# (examples/dreamt_sleep_classification.ipynb) imports the task under the
# name ``SleepStagingDREAMT``.  Keep both names bound to the same class.
# TODO: also re-export this alias from ``pyhealth.tasks.__init__`` so
# ``from pyhealth.tasks import SleepStagingDREAMT`` works.
SleepStagingDREAMT = DREAMTSleepClassification
'Sleep_Disorders': ['Sleep Disorder'] * self.num_patients, } - - patient_data_df = pd.DataFrame(patient_data) - patient_data_df.to_csv(self.root / "participant_info.csv", index=False) + pd.DataFrame(patient_data).to_csv(self.root / "participant_info.csv", index=False) self._create_files() def _create_files(self): - """Create 64Hz and 100Hz files""" for i in range(1, self.num_patients + 1): sid = f"S{i:03d}" - partial_data = { - 'TIMESTAMP': [np.random.uniform(0, 100)], - 'BVP': [np.random.uniform(1, 10)], - 'HR': [np.random.randint(15, 100)], - 'EDA': [np.random.uniform(0, 1)], - 'TEMP': [np.random.uniform(20, 30)], - 'ACC_X': [np.random.uniform(1, 50)], - 'ACC_Y': [np.random.uniform(1, 50)], - 'ACC_Z': [np.random.uniform(1, 50)], - 'IBI': [np.random.uniform(0.6, 1.2)], - 'Sleep_Stage': [np.random.choice(['W', 'N1', 'N2', 'N3', 'R'])], + 'TIMESTAMP': [np.random.uniform(0, 100)], + 'BVP': [np.random.uniform(1, 10)], + 'HR': [np.random.randint(15, 100)], + 'EDA': [np.random.uniform(0, 1)], + 'TEMP': [np.random.uniform(20, 30)], + 'ACC_X': [np.random.uniform(1, 50)], + 'ACC_Y': [np.random.uniform(1, 50)], + 'ACC_Z': [np.random.uniform(1, 50)], + 'IBI': [np.random.uniform(0.6, 1.2)], + 'Sleep_Stage': [np.random.choice(['W', 'N1', 'N2', 'N3', 'R'])], } - pd.DataFrame(partial_data).to_csv(self.root / "data_64Hz" / f"{sid}_whole_df.csv") pd.DataFrame(partial_data).to_csv(self.root / "data_100Hz" / f"{sid}_PSG_df.csv") - + def tearDown(self): shutil.rmtree(self.temp_dir) def test_dataset_initialization(self): - """Test DREAMTDataset initialization""" dataset = DREAMTDataset(root=str(self.root)) - self.assertIsNotNone(dataset) self.assertEqual(dataset.dataset_name, "dreamt_sleep") self.assertEqual(dataset.root, str(self.root)) def test_metadata_file_created(self): - """Test dreamt-metadata.csv created""" dataset = DREAMTDataset(root=str(self.root)) - metadata_file = self.root / "dreamt-metadata.csv" - self.assertTrue(metadata_file.exists()) + self.assertTrue((self.root / 
# ---------------------------------------------------------------------------
# Task stubs
# ---------------------------------------------------------------------------


@dataclass
class _DummyEvent:
    """Stand-in for a DREAMT event: carries only the 64 Hz file path."""

    file_64hz: Optional[str]


class _DummyPatient:
    """Minimal patient double exposing ``get_events`` like DREAMTDataset."""

    def __init__(self, patient_id: str, events: List[_DummyEvent]) -> None:
        self.patient_id = patient_id
        self._events = events

    def get_events(self, event_type: Optional[str] = None) -> List[_DummyEvent]:
        return self._events


def _make_dreamt_csv(
    path: Path,
    n_rows: int = 120,
    stage: str = "N2",
    ibi_every: int = 4,
    include_acc: bool = True,
) -> None:
    """Write a synthetic DREAMT-style CSV with ``n_rows`` rows.

    One IBI "beat" is recorded every ``ibi_every`` rows (non-zero IBI);
    all rows carry the same ``stage`` annotation.
    """
    rng = np.random.default_rng(42)
    ibi = np.zeros(n_rows)
    beat_indices = np.arange(0, n_rows, ibi_every)
    ibi[beat_indices] = rng.uniform(0.7, 1.1, len(beat_indices))
    data: dict = {
        "TIMESTAMP": np.arange(n_rows) / 4.0,
        "BVP": rng.standard_normal(n_rows),
        "HR": rng.integers(50, 90, n_rows).astype(float),
        "EDA": rng.uniform(0.0, 1.0, n_rows),
        "TEMP": rng.uniform(33.0, 37.0, n_rows),
        "IBI": ibi,
        "Sleep_Stage": [stage] * n_rows,
    }
    if include_acc:
        data["ACC_X"] = rng.standard_normal(n_rows)
        data["ACC_Y"] = rng.standard_normal(n_rows)
        data["ACC_Z"] = rng.standard_normal(n_rows)
    pd.DataFrame(data).to_csv(path, index=False)


# ---------------------------------------------------------------------------
# Task init tests
# ---------------------------------------------------------------------------


class TestDREAMTSleepClassificationInit(unittest.TestCase):
    """Verify task attributes are set correctly at construction time."""

    def test_default_attributes(self):
        task = DREAMTSleepClassification()
        self.assertEqual(task.task_name, "DREAMTSleepClassification")
        self.assertEqual(task.epoch_seconds, 30)
        self.assertEqual(task.num_classes, 3)
        self.assertFalse(task.use_accelerometer)
        self.assertEqual(task.sample_rate, 64)

    def test_output_schema(self):
        self.assertEqual(
            DREAMTSleepClassification().output_schema, {"label": "multiclass"}
        )

    def test_3class_input_schema(self):
        task = DREAMTSleepClassification(num_classes=3)
        self.assertEqual(task.input_schema, {"ibi_sequence": "tensor"})
        self.assertNotIn("accelerometer", task.input_schema)

    def test_accelerometer_input_schema(self):
        task = DREAMTSleepClassification(use_accelerometer=True)
        self.assertIn("ibi_sequence", task.input_schema)
        self.assertIn("accelerometer", task.input_schema)

    def test_4class_initialises(self):
        self.assertEqual(DREAMTSleepClassification(num_classes=4).num_classes, 4)

    def test_invalid_num_classes_raises(self):
        with self.assertRaises(ValueError):
            DREAMTSleepClassification(num_classes=5)

    def test_invalid_num_classes_1_raises(self):
        with self.assertRaises(ValueError):
            DREAMTSleepClassification(num_classes=1)


# ---------------------------------------------------------------------------
# Task __call__ tests
# ---------------------------------------------------------------------------


class TestDREAMTSleepClassificationCall(unittest.TestCase):
    """End-to-end tests of the task's __call__ method.

    A 4 Hz sample rate is used throughout so that one 30 s epoch needs
    only 120 CSV rows, keeping fixtures tiny and tests fast.
    """

    def setUp(self) -> None:
        self.tmp = tempfile.mkdtemp()
        self.root = Path(self.tmp)

    def tearDown(self) -> None:
        shutil.rmtree(self.tmp)

    def _patient(self, n_rows=120, stage="N2", file_exists=True, include_acc=True):
        """Build a one-event dummy patient, optionally backed by a CSV."""
        csv_path = self.root / "S001_whole_df.csv"
        if file_exists:
            _make_dreamt_csv(
                csv_path, n_rows=n_rows, stage=stage, include_acc=include_acc
            )
        return _DummyPatient(
            "S001",
            [_DummyEvent(file_64hz=str(csv_path) if file_exists else None)],
        )

    def _task(self, **kwargs):
        return DREAMTSleepClassification(sample_rate=4, **kwargs)

    def test_returns_one_epoch_for_exactly_one_window(self):
        self.assertEqual(len(self._task()(self._patient(n_rows=120))), 1)

    def test_returns_two_epochs_for_two_windows(self):
        self.assertEqual(len(self._task()(self._patient(n_rows=240))), 2)

    def test_insufficient_rows_returns_empty(self):
        self.assertEqual(self._task()(self._patient(n_rows=50)), [])

    def test_sample_has_required_keys(self):
        s = self._task()(self._patient(n_rows=120))[0]
        for key in ("patient_id", "epoch_idx", "ibi_sequence", "label"):
            self.assertIn(key, s)

    def test_patient_id_propagated(self):
        self.assertEqual(
            self._task()(self._patient(n_rows=120))[0]["patient_id"], "S001"
        )

    def test_epoch_idx_zero_for_first_epoch(self):
        self.assertEqual(self._task()(self._patient(n_rows=120))[0]["epoch_idx"], 0)

    def test_epoch_indices_sequential(self):
        samples = self._task()(self._patient(n_rows=360))
        self.assertEqual([s["epoch_idx"] for s in samples], [0, 1, 2])

    def test_ibi_values_are_float32(self):
        self.assertEqual(
            self._task()(self._patient(n_rows=120))[0]["ibi_sequence"].dtype,
            np.float32,
        )

    def test_ibi_values_positive(self):
        ibi = self._task()(self._patient(n_rows=120))[0]["ibi_sequence"]
        self.assertTrue(np.all(ibi > 0))

    def test_ibi_length_matches_beat_count(self):
        # 120 rows with one beat every 4 rows -> 30 IBI values per epoch.
        _make_dreamt_csv(self.root / "S002_whole_df.csv", n_rows=120, ibi_every=4)
        patient = _DummyPatient(
            "S002", [_DummyEvent(str(self.root / "S002_whole_df.csv"))]
        )
        self.assertEqual(len(self._task()(patient)[0]["ibi_sequence"]), 30)

    def test_3class_wake_label(self):
        self.assertEqual(
            self._task(num_classes=3)(self._patient(stage="W"))[0]["label"], 0
        )

    def test_3class_n1_maps_to_nrem(self):
        self.assertEqual(
            self._task(num_classes=3)(self._patient(stage="N1"))[0]["label"], 1
        )

    def test_3class_n2_maps_to_nrem(self):
        self.assertEqual(
            self._task(num_classes=3)(self._patient(stage="N2"))[0]["label"], 1
        )

    def test_3class_n3_maps_to_nrem(self):
        self.assertEqual(
            self._task(num_classes=3)(self._patient(stage="N3"))[0]["label"], 1
        )

    def test_3class_rem_label(self):
        self.assertEqual(
            self._task(num_classes=3)(self._patient(stage="R"))[0]["label"], 2
        )

    def test_4class_wake_label(self):
        self.assertEqual(
            self._task(num_classes=4)(self._patient(stage="W"))[0]["label"], 0
        )

    def test_4class_n1_label(self):
        self.assertEqual(
            self._task(num_classes=4)(self._patient(stage="N1"))[0]["label"], 1
        )

    def test_4class_n2_label(self):
        self.assertEqual(
            self._task(num_classes=4)(self._patient(stage="N2"))[0]["label"], 2
        )

    def test_4class_n3_label(self):
        self.assertEqual(
            self._task(num_classes=4)(self._patient(stage="N3"))[0]["label"], 3
        )

    def test_4class_rem_label(self):
        self.assertEqual(
            self._task(num_classes=4)(self._patient(stage="R"))[0]["label"], 4
        )

    def test_none_file_returns_empty(self):
        self.assertEqual(self._task()(self._patient(file_exists=False)), [])

    def test_missing_file_on_disk_returns_empty(self):
        # New: file_64hz points at a path that does not exist on disk; the
        # task should log and skip (read_csv error branch) rather than raise.
        patient = _DummyPatient(
            "S999", [_DummyEvent(str(self.root / "does_not_exist.csv"))]
        )
        self.assertEqual(self._task()(patient), [])

    def test_unknown_stage_skipped(self):
        self.assertEqual(self._task()(self._patient(stage="X")), [])

    def test_no_ibi_values_skipped(self):
        # All-zero IBI column -> no detected heartbeat -> epoch dropped.
        csv_path = self.root / "S003_whole_df.csv"
        rng = np.random.default_rng(0)
        n = 120
        pd.DataFrame({
            "TIMESTAMP": np.arange(n) / 4.0,
            "BVP": rng.standard_normal(n),
            "HR": rng.integers(50, 90, n).astype(float),
            "EDA": rng.uniform(0.0, 1.0, n),
            "TEMP": rng.uniform(33.0, 37.0, n),
            "IBI": np.zeros(n),
            "Sleep_Stage": ["N2"] * n,
        }).to_csv(csv_path, index=False)
        patient = _DummyPatient("S003", [_DummyEvent(str(csv_path))])
        self.assertEqual(self._task()(patient), [])

    def test_multiple_events_aggregated(self):
        csv1, csv2 = self.root / "night1.csv", self.root / "night2.csv"
        _make_dreamt_csv(csv1, stage="N2")
        _make_dreamt_csv(csv2, stage="R")
        patient = _DummyPatient(
            "S004", [_DummyEvent(str(csv1)), _DummyEvent(str(csv2))]
        )
        samples = self._task()(patient)
        self.assertEqual(len(samples), 2)
        labels = {s["label"] for s in samples}
        self.assertIn(1, labels)
        self.assertIn(2, labels)

    def test_accelerometer_key_absent_by_default(self):
        self.assertNotIn("accelerometer", self._task()(self._patient(n_rows=120))[0])

    def test_accelerometer_present_when_enabled(self):
        samples = self._task(use_accelerometer=True)(self._patient(include_acc=True))
        self.assertIn("accelerometer", samples[0])

    def test_accelerometer_shape(self):
        samples = self._task(use_accelerometer=True)(
            self._patient(n_rows=120, include_acc=True)
        )
        self.assertEqual(samples[0]["accelerometer"].shape, (120, 3))

    def test_accelerometer_dtype_float32(self):
        samples = self._task(use_accelerometer=True)(
            self._patient(n_rows=120, include_acc=True)
        )
        self.assertEqual(samples[0]["accelerometer"].dtype, np.float32)

    def test_accelerometer_missing_columns_skips_epoch(self):
        self.assertEqual(
            self._task(use_accelerometer=True)(self._patient(include_acc=False)), []
        )


if __name__ == "__main__":
    unittest.main()
class TestDREAMTDatasetNewerVersions(unittest.TestCase):
    """Test DREAMT dataset containing 64Hz and 100Hz folders with local test data."""

    @classmethod
    def setUpClass(cls):
        # Build the fixture tree and load the dataset once for the whole
        # class — construction is the expensive step, and every test only
        # reads from it.
        cls.temp_dir = tempfile.mkdtemp()
        cls.root = Path(cls.temp_dir)
        (cls.root / "data_64Hz").mkdir()
        (cls.root / "data_100Hz").mkdir()
        cls.num_patients = 5
        rng = np.random.default_rng(42)

        n = cls.num_patients
        participants = {
            'SID': [f"S{i:03d}" for i in range(1, n + 1)],
            'AGE': rng.uniform(25, 65, n),
            'GENDER': rng.choice(['M', 'F'], n),
            'BMI': rng.integers(20, 50, n),
            'OAHI': rng.integers(0, 50, n),
            'AHI': rng.integers(0, 50, n),
            'Mean_SaO2': [f"{v}%" for v in rng.integers(85, 99, n)],
            'Arousal Index': rng.integers(1, 100, n),
            'MEDICAL_HISTORY': ['Medical History'] * n,
            'Sleep_Disorders': ['Sleep Disorder'] * n,
        }
        pd.DataFrame(participants).to_csv(
            cls.root / "participant_info.csv", index=False
        )

        # One-row signal file per patient, written to both rate folders.
        for i in range(1, n + 1):
            sid = f"S{i:03d}"
            signal_row = pd.DataFrame({
                'TIMESTAMP': [rng.uniform(0, 100)],
                'BVP': [rng.uniform(1, 10)],
                'HR': [int(rng.integers(15, 100))],
                'EDA': [rng.uniform(0, 1)],
                'TEMP': [rng.uniform(20, 30)],
                'ACC_X': [rng.uniform(1, 50)],
                'ACC_Y': [rng.uniform(1, 50)],
                'ACC_Z': [rng.uniform(1, 50)],
                'IBI': [rng.uniform(0.6, 1.2)],
                'Sleep_Stage': [rng.choice(['W', 'N1', 'N2', 'N3', 'R'])],
            })
            signal_row.to_csv(cls.root / "data_64Hz" / f"{sid}_whole_df.csv")
            signal_row.to_csv(cls.root / "data_100Hz" / f"{sid}_PSG_df.csv")

        cls.dataset = DREAMTDataset(root=str(cls.root))

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.temp_dir)

    def test_dataset_initialization(self):
        self.assertIsNotNone(self.dataset)
        self.assertEqual(self.dataset.dataset_name, "dreamt_sleep")
        self.assertEqual(self.dataset.root, str(self.root))

    def test_metadata_file_created(self):
        # Dataset construction in setUpClass writes the metadata cache.
        self.assertTrue((self.root / "dreamt-metadata.csv").exists())

    def test_patient_count(self):
        self.assertEqual(len(self.dataset.unique_patient_ids), self.num_patients)

    def test_stats_method(self):
        self.dataset.stats()

    def test_get_patient(self):
        patient = self.dataset.get_patient('S001')
        self.assertIsNotNone(patient)
        self.assertEqual(patient.patient_id, 'S001')

    def test_get_patient_not_found(self):
        with self.assertRaises(AssertionError):
            self.dataset.get_patient('S222')


# ---------------------------------------------------------------------------