diff --git a/ablation_results.json b/ablation_results.json new file mode 100644 index 000000000..f65d0d81a --- /dev/null +++ b/ablation_results.json @@ -0,0 +1,100 @@ +{ + "main_reproduction": { + "accuracy": 0.8732, + "accuracy_std": 0.031371, + "precision_macro": 0.890223, + "precision_std": 0.022084, + "recall_macro": 0.8732, + "recall_std": 0.031371, + "f1_macro": 0.871541, + "f1_std": 0.032488, + "n_seeds": 10, + "pretrain_time_s": 5038.404446363449, + "finetune_time_s": 7.1890788078308105, + "n_params": 468546 + }, + "encoder_transformer": { + "accuracy": 0.8732, + "accuracy_std": 0.031371, + "precision_macro": 0.890223, + "precision_std": 0.022084, + "recall_macro": 0.8732, + "recall_std": 0.031371, + "f1_macro": 0.871541, + "f1_std": 0.032488, + "n_seeds": 10, + "pretrain_time_s": 5038.404446363449, + "finetune_time_s": 7.1890788078308105, + "n_params": 468546 + }, + "fusion_attention": { + "accuracy": 0.8732, + "accuracy_std": 0.031371, + "precision_macro": 0.890223, + "precision_std": 0.022084, + "recall_macro": 0.8732, + "recall_std": 0.031371, + "f1_macro": 0.871541, + "f1_std": 0.032488, + "n_seeds": 10, + "pretrain_time_s": 5038.404446363449, + "finetune_time_s": 7.1890788078308105, + "n_params": 468546 + }, + "encoder_cnn": { + "accuracy": 0.9224, + "accuracy_std": 0.022553, + "precision_macro": 0.927741, + "precision_std": 0.015151, + "recall_macro": 0.9224, + "recall_std": 0.022553, + "f1_macro": 0.922029, + "f1_std": 0.023281, + "n_seeds": 10, + "pretrain_time_s": 835.3022849559784, + "finetune_time_s": 3.418339490890503, + "n_params": 392130 + }, + "encoder_gru": { + "accuracy": 0.7704, + "accuracy_std": 0.085971, + "precision_macro": 0.841265, + "precision_std": 0.042425, + "recall_macro": 0.7704, + "recall_std": 0.085971, + "f1_macro": 0.752609, + "f1_std": 0.100316, + "n_seeds": 10, + "pretrain_time_s": 1774.1677017211914, + "finetune_time_s": 3.1605217456817627, + "n_params": 355266 + }, + "fusion_concat": { + "accuracy": 0.839, + "accuracy_std": 0.044247, + "precision_macro": 0.868079, + "precision_std": 0.025843, + "recall_macro": 0.839, + "recall_std": 0.044247, + "f1_macro": 0.834898, + "f1_std": 0.047727, + "n_seeds": 10, + "pretrain_time_s": 5038.404446363449, + "finetune_time_s": 7.5263755321502686, + "n_params": 419010 + }, + "fusion_mean": { + "accuracy": 0.8658, + "accuracy_std": 0.025416, + "precision_macro": 0.885845, + "precision_std": 0.014511, + "recall_macro": 0.8658, + "recall_std": 0.025416, + "f1_macro": 0.863777, + "f1_std": 0.027056, + "n_seeds": 10, + "pretrain_time_s": 5038.404446363449, + "finetune_time_s": 7.912641763687134, + "n_params": 418498 + } +} \ No newline at end of file diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..c19ba6903 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -187,6 +187,7 @@ API Reference models/pyhealth.models.EHRMamba models/pyhealth.models.JambaEHR models/pyhealth.models.ContraWR + models/pyhealth.models.MultiViewContrastive models/pyhealth.models.SparcNet models/pyhealth.models.StageNet models/pyhealth.models.StageAttentionNet diff --git a/docs/api/models/pyhealth.models.MultiViewContrastive.rst b/docs/api/models/pyhealth.models.MultiViewContrastive.rst new file mode 100644 index 000000000..77f711bfc --- /dev/null +++ b/docs/api/models/pyhealth.models.MultiViewContrastive.rst @@ -0,0 +1,14 @@ +pyhealth.models.MultiViewContrastive +====================================== + +Multi-View Contrastive Learning for Domain Adaptation in Medical Time Series. + +Paper: Oh, Y.; and Bui, A. 2025. Multi-View Contrastive Learning for Robust +Domain Adaptation in Medical Time Series Analysis. In *Proceedings of the +Sixth Conference on Health, Inference, and Learning*, volume 287, 502--526. +PMLR. + +.. autoclass:: pyhealth.models.MultiViewContrastive + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/sleepEEG_epilepsy_multiview_contrastive.py b/examples/sleepEEG_epilepsy_multiview_contrastive.py new file mode 100644 index 000000000..a073f87f3 --- /dev/null +++ b/examples/sleepEEG_epilepsy_multiview_contrastive.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python3 +""" +Multi-View Contrastive Learning: SleepEEG -> Epilepsy Domain Adaptation +========================================================================= + +Reproduces Oh & Bui (2025), CHIL 2025 Best Paper: + "Multi-View Contrastive Learning for Robust Domain Adaptation + in Medical Time Series Analysis" + +Paper: https://proceedings.mlr.press/v287/oh25a.html + +This script demonstrates: + 1. Downloading preprocessed SleepEEG (source) and Epilepsy (target) data + 2. Contrastive pre-training on the source domain + 3. Fine-tuning and evaluation on the target domain + 4. Ablation Study 1: Encoder backbone comparison + (Transformer vs. 1D-CNN vs. GRU) + 5. Ablation Study 2: Fusion strategy comparison + (Attention vs. Concatenation vs. Mean Pooling) + +Metrics reported match the paper's Table 2 (Accuracy, Precision, Recall, +F1-macro). A single run is produced by this script for reproducibility; the +checked-in ablation_results.json reflects a 10-seed aggregation (mean +/- std) +of the same configurations for more robust comparison. + +Usage: + python sleepEEG_epilepsy_multiview_contrastive.py + + Runs on CPU by default; uses CUDA if available. + Full training on Colab T4 takes ~30 min for pre-training. + Set QUICK_MODE = True below for a fast demo (~2 min on CPU). + +Results (reported in paper Table 2, SleepEEG -> Epilepsy, Proposed method): + Accuracy 0.956 +/- 0.002 + Precision 0.936 +/- 0.004 + Recall 0.935 +/- 0.004 + F1 0.931 +/- 0.003 + TFC baseline (Table 2): Acc 0.950, F1 0.915 +""" + +import os +import sys +import time +import json +from collections import defaultdict +from functools import partial + +# Force unbuffered output so we can monitor progress +print = partial(print, flush=True) + +import requests +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft as fft +from torch.utils.data import DataLoader, TensorDataset +from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, +) + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MultiViewContrastive + +# ===================================================================== +# Configuration +# ===================================================================== + +QUICK_MODE = False # Set False for full reproduction + +PRETRAIN_EPOCHS = 5 if QUICK_MODE else 200 +FINETUNE_EPOCHS = 5 if QUICK_MODE else 100 +PRETRAIN_BATCH = 128 +FINETUNE_BATCH = 16 +PRETRAIN_LR = 3e-4 +FINETUNE_LR = 1e-3 +WEIGHT_DECAY = 1e-5 +TEMPERATURE = 0.07 +SEED = 42 + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "datasets") +RESULTS_FILE = "ablation_results.json" + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +FIGSHARE_IDS = { + "SleepEEG": 19930178, + "Epilepsy": 19930199, +} + + +# ===================================================================== +# Data download and loading +# ===================================================================== + + +def download_dataset(name: str, article_id: int, data_dir: str) -> str: + """Download a dataset from figshare if not already present.""" + ds_dir = os.path.join(data_dir, name) + os.makedirs(ds_dir, exist_ok=True) + + if os.path.exists(os.path.join(ds_dir, "train.pt")): + print(f" {name}: already downloaded.") + return ds_dir + + print(f" Downloading {name}...") + api_url = f"https://api.figshare.com/v2/articles/{article_id}/files" + resp = requests.get(api_url, timeout=15) + resp.raise_for_status() + for finfo in resp.json(): + fname = finfo["name"] + furl = finfo["download_url"] + dest = os.path.join(ds_dir, fname) + if os.path.exists(dest): + continue + r = requests.get(furl, stream=True, timeout=300) + r.raise_for_status() + with open(dest, "wb") as f: + for chunk in r.iter_content(65536): + f.write(chunk) + print(f" {fname} done.") + return ds_dir + + +def load_tensors(ds_dir: str, split: str = "train"): + """Load .pt file and return (X, y) tensors.""" + data = torch.load( + os.path.join(ds_dir, f"{split}.pt"), + map_location="cpu", + weights_only=False, + ) + if isinstance(data, dict): + X = data.get("samples", data.get("X")) + y = data.get("labels", data.get("y")) + else: + raise ValueError(f"Unexpected data format in {split}.pt") + X = X.float() + # Normalize per-feature (zero mean, unit std) to prevent NaN + mean = X.mean(dim=(0, 2), keepdim=True) + std = X.std(dim=(0, 2), keepdim=True).clamp(min=1e-8) + X = (X - mean) / std + return X, y.long() + + +# ===================================================================== +# Contrastive loss +# ===================================================================== + + +def nt_xent_loss(z_list, temperature=0.07): + """NT-Xent contrastive loss across all view pairs. + + Uses L2-normalized embeddings. Temperature 0.07 matches the original + paper's setting (sharper similarity distribution for harder negatives). + """ + loss = 0.0 + n_pairs = 0 + for i in range(len(z_list)): + for j in range(i + 1, len(z_list)): + zi = F.normalize(z_list[i], dim=1) + zj = F.normalize(z_list[j], dim=1) + B = zi.size(0) + sim = torch.mm(zi, zj.t()) / temperature + labels = torch.arange(B, device=sim.device) + loss += ( + F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels) + ) / 2 + n_pairs += 1 + return loss / max(n_pairs, 1) + + +# ===================================================================== +# Training loops +# ===================================================================== + + +def pretrain(model, src_X, epochs, batch_size, lr, weight_decay): + """Contrastive pre-training on source domain.""" + # Build DataLoader from raw tensors + loader = DataLoader( + TensorDataset(src_X), + batch_size=batch_size, + shuffle=True, + drop_last=True, + ) + + # Only train encoder + projection, not classifier head + optimizer = torch.optim.Adam( + model.parameters(), lr=lr, weight_decay=weight_decay + ) + + model.train() + for epoch in range(epochs): + total_loss = 0 + nan_detected = False + for (x_batch,) in loader: + x_batch = x_batch.to(DEVICE) + latents = model.encode_views(x_batch) + z_list = list(latents.values()) + loss = nt_xent_loss(z_list, TEMPERATURE) + + if torch.isnan(loss) or torch.isinf(loss): + print(f" WARNING: NaN/Inf loss at epoch {epoch+1}, skipping batch") + nan_detected = True + optimizer.zero_grad() + continue + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + total_loss += loss.item() + + if nan_detected: + print(f" Epoch {epoch+1}: had NaN batches, continuing...") + + if (epoch + 1) % max(1, epochs // 10) == 0 or epoch == 0: + avg = total_loss / max(len(loader), 1) + print(f" Pre-train epoch {epoch+1}/{epochs}: loss={avg:.4f}") + + +def finetune_and_eval(model, tgt_train_X, tgt_train_y, tgt_test_X, tgt_test_y, + epochs, batch_size, lr, num_classes): + """Fine-tune on target train set and evaluate on target test set.""" + # Create PyHealth dataset for fine-tuning + train_samples = [] + for i in range(tgt_train_X.shape[0]): + train_samples.append({ + "patient_id": f"p{i}", + "visit_id": "v0", + "signal": tgt_train_X[i].numpy(), + "label": int(tgt_train_y[i].item()), + }) + + train_ds = create_sample_dataset( + samples=train_samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="epilepsy_train", + ) + train_loader = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + model.train() + for epoch in range(epochs): + total_loss = 0 + for batch in train_loader: + ret = model(**batch) + optimizer.zero_grad() + ret["loss"].backward() + optimizer.step() + total_loss += ret["loss"].item() + + if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == 0: + avg = total_loss / len(train_loader) + print(f" Fine-tune epoch {epoch+1}/{epochs}: loss={avg:.4f}") + + # Evaluate + model.eval() + all_preds = [] + all_true = [] + + # Process test data in batches + test_loader = DataLoader( + TensorDataset(tgt_test_X, tgt_test_y), + batch_size=64, + shuffle=False, + ) + + feat_key = model.feature_keys[0] + label_key = model.label_keys[0] + + with torch.no_grad(): + for x_batch, y_batch in test_loader: + # Use the same forward path as fine-tuning to guarantee + # identical preprocessing, view construction, and fusion. + batch = {feat_key: x_batch, label_key: y_batch} + ret = model(**batch) + logits = ret["logit"] + all_preds.append(logits.argmax(dim=-1).cpu()) + all_true.append(y_batch) + + preds = torch.cat(all_preds).numpy() + true = torch.cat(all_true).numpy() + + acc = accuracy_score(true, preds) + prec = precision_score(true, preds, average="macro", zero_division=0) + rec = recall_score(true, preds, average="macro", zero_division=0) + f1 = f1_score(true, preds, average="macro") + + return { + "accuracy": acc, + "precision_macro": prec, + "recall_macro": rec, + "f1_macro": f1, + } + + +# ===================================================================== +# Main +# ===================================================================== + + +def run_experiment( + src_X, tgt_train_X, tgt_train_y, tgt_test_X, tgt_test_y, + encoder_type, view_type, fusion_type, num_classes +): + """Run one full experiment: pretrain + finetune + eval.""" + print(f"\n Config: encoder={encoder_type}, view={view_type}, " + f"fusion={fusion_type}") + + # Create a minimal dataset for model init (must include all classes) + init_samples = [] + seen_labels = set() + for i in range(tgt_train_X.shape[0]): + lbl = int(tgt_train_y[i].item()) + if lbl not in seen_labels or len(init_samples) < num_classes * 2: + init_samples.append({ + "patient_id": f"p{i}", + "visit_id": "v0", + "signal": tgt_train_X[i].numpy(), + "label": lbl, + }) + seen_labels.add(lbl) + if len(seen_labels) == num_classes and len(init_samples) >= num_classes * 2: + break + init_ds = create_sample_dataset( + samples=init_samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="init", + ) + + model = MultiViewContrastive( + dataset=init_ds, + encoder_type=encoder_type, + view_type=view_type, + fusion_type=fusion_type, + num_embedding=64, + num_hidden=128, + num_head=4, + num_layers=3, + dropout=0.2, + ).to(DEVICE) + + n_params = sum(p.numel() for p in model.parameters()) + print(f" Parameters: {n_params:,}") + + t0 = time.time() + pretrain(model, src_X, PRETRAIN_EPOCHS, PRETRAIN_BATCH, PRETRAIN_LR, WEIGHT_DECAY) + pretrain_time = time.time() - t0 + + t0 = time.time() + metrics = finetune_and_eval( + model, tgt_train_X, tgt_train_y, tgt_test_X, tgt_test_y, + FINETUNE_EPOCHS, FINETUNE_BATCH, FINETUNE_LR, num_classes + ) + finetune_time = time.time() - t0 + + metrics["pretrain_time_s"] = round(pretrain_time, 1) + metrics["finetune_time_s"] = round(finetune_time, 1) + metrics["n_params"] = n_params + + print(f" Results: acc={metrics['accuracy']:.4f}, " + f"prec={metrics['precision_macro']:.4f}, " + f"rec={metrics['recall_macro']:.4f}, " + f"f1={metrics['f1_macro']:.4f}") + print(f" Time: pretrain={pretrain_time:.1f}s, finetune={finetune_time:.1f}s") + return metrics + + +def main(): + print("=" * 70) + print("Multi-View Contrastive Learning Reproduction") + print("SleepEEG (source) -> Epilepsy (target)") + print(f"Device: {DEVICE}") + print(f"Mode: {'QUICK (demo)' if QUICK_MODE else 'FULL reproduction'}") + print("=" * 70) + + torch.manual_seed(SEED) + np.random.seed(SEED) + + # Download data + print("\n--- Step 1: Downloading data ---") + os.makedirs(DATA_DIR, exist_ok=True) + for name, aid in FIGSHARE_IDS.items(): + download_dataset(name, aid, DATA_DIR) + + # Load data + print("\n--- Step 2: Loading data ---") + src_dir = os.path.join(DATA_DIR, "SleepEEG") + tgt_dir = os.path.join(DATA_DIR, "Epilepsy") + + src_X, _ = load_tensors(src_dir, "train") + tgt_train_X, tgt_train_y = load_tensors(tgt_dir, "train") + tgt_test_X, tgt_test_y = load_tensors(tgt_dir, "test") + + # Subsample source for quick mode + if QUICK_MODE: + src_X = src_X[:2048] + + num_classes = int(tgt_train_y.max().item()) + 1 + + print(f" Source (SleepEEG): {src_X.shape}") + print(f" Target train: {tgt_train_X.shape}, classes={num_classes}") + print(f" Target test: {tgt_test_X.shape}") + + all_results = {} + + # ----------------------------------------------------------------- + # Main reproduction: ALL views, Transformer, Attention fusion + # ----------------------------------------------------------------- + print("\n" + "=" * 70) + print("MAIN REPRODUCTION: Transformer + ALL views + Attention fusion") + print("=" * 70) + + main_metrics = run_experiment( + src_X, tgt_train_X, tgt_train_y, tgt_test_X, tgt_test_y, + encoder_type="transformer", + view_type="ALL", + fusion_type="attention", + num_classes=num_classes, + ) + all_results["main_reproduction"] = main_metrics + + # ----------------------------------------------------------------- + # Ablation 1: Encoder backbone comparison + # ----------------------------------------------------------------- + print("\n" + "=" * 70) + print("ABLATION 1: Encoder Backbone Comparison") + print(" (Transformer vs. 1D-CNN vs. GRU, ALL views, Attention fusion)") + print("=" * 70) + + for enc in ["transformer", "cnn", "gru"]: + key = f"encoder_{enc}" + if enc == "transformer": + all_results[key] = main_metrics # reuse + print(f"\n {enc}: (reusing main result)") + continue + metrics = run_experiment( + src_X, tgt_train_X, tgt_train_y, tgt_test_X, tgt_test_y, + encoder_type=enc, + view_type="ALL", + fusion_type="attention", + num_classes=num_classes, + ) + all_results[key] = metrics + + # ----------------------------------------------------------------- + # Ablation 2: Fusion strategy comparison + # ----------------------------------------------------------------- + print("\n" + "=" * 70) + print("ABLATION 2: Fusion Strategy Comparison") + print(" (Attention vs. Concat vs. Mean, ALL views, Transformer)") + print("=" * 70) + + for fus in ["attention", "concat", "mean"]: + key = f"fusion_{fus}" + if fus == "attention": + all_results[key] = main_metrics # reuse + print(f"\n {fus}: (reusing main result)") + continue + metrics = run_experiment( + src_X, tgt_train_X, tgt_train_y, tgt_test_X, tgt_test_y, + encoder_type="transformer", + view_type="ALL", + fusion_type=fus, + num_classes=num_classes, + ) + all_results[key] = metrics + + # ----------------------------------------------------------------- + # Summary + # ----------------------------------------------------------------- + print("\n" + "=" * 70) + print("RESULTS SUMMARY") + print("=" * 70) + + print( + f"\n{'Configuration':<45} {'Acc':>8} {'Prec':>8} {'Rec':>8} " + f"{'F1':>8} {'Params':>10}" + ) + print("-" * 93) + for key, m in all_results.items(): + print( + f"{key:<45} {m['accuracy']:>8.4f} " + f"{m['precision_macro']:>8.4f} {m['recall_macro']:>8.4f} " + f"{m['f1_macro']:>8.4f} {m['n_params']:>10,}" + ) + + # Save results + results_path = os.path.join( + os.path.dirname(__file__), "..", RESULTS_FILE + ) + serializable = {} + for k, v in all_results.items(): + serializable[k] = { + kk: float(vv) if isinstance(vv, (float, np.floating)) else vv + for kk, vv in v.items() + } + with open(results_path, "w") as f: + json.dump(serializable, f, indent=2) + print(f"\nResults saved to {results_path}") + + print("\n" + "=" * 70) + print("Paper reference (Oh & Bui 2025, Table 2, SleepEEG -> Epilepsy):") + print(" Proposed: acc 0.956 prec 0.936 rec 0.935 f1 0.931") + print(" TFC baseline: acc 0.950 prec 0.946 rec 0.891 f1 0.915") + if not QUICK_MODE: + print( + f" Our reproduction: acc {main_metrics['accuracy']:.3f} " + f"prec {main_metrics['precision_macro']:.3f} " + f"rec {main_metrics['recall_macro']:.3f} " + f"f1 {main_metrics['f1_macro']:.3f}" + ) + else: + print(" (Quick mode - not comparable to paper results)") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..bc9d0e5ca 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -43,4 +43,5 @@ from .text_embedding import TextEmbedding from .sdoh import SdohClassifier from .medlink import MedLink +from .multiview_contrastive import MultiViewContrastive from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding diff --git a/pyhealth/models/multiview_contrastive.py b/pyhealth/models/multiview_contrastive.py new file mode 100644 index 000000000..5bb57d8cd --- /dev/null +++ b/pyhealth/models/multiview_contrastive.py @@ -0,0 +1,596 @@ +"""Multi-View Contrastive Learning for Domain Adaptation in Medical Time Series. + +This module implements the multi-view contrastive learning framework from: + + Oh, Y.; and Bui, A. 2025. Multi-View Contrastive Learning for Robust + Domain Adaptation in Medical Time Series Analysis. In Proceedings of the + Sixth Conference on Health, Inference, and Learning, volume 287, 502-526. + PMLR. + +The model constructs three views of a raw time-series signal -- temporal, +derivative, and frequency -- encodes each with an independent backbone, fuses +them via hierarchical attention, and classifies the fused representation. + +Two configuration axes are exposed for ablation studies: + +* **encoder_type** (``"transformer"`` | ``"cnn"`` | ``"gru"``): the per-view + backbone architecture. +* **fusion_type** (``"attention"`` | ``"concat"`` | ``"mean"``): how the three + view embeddings are aggregated before the classifier head. + +A **view_type** parameter (``"T"`` | ``"D"`` | ``"F"`` | ``"TD"`` | ``"TF"`` +| ``"DF"`` | ``"ALL"``) selects which subset of views is active. +""" + +import math +from typing import Dict, Literal, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft as fft + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +# --------------------------------------------------------------------------- +# Internal building blocks +# --------------------------------------------------------------------------- + + +class _PositionalEncoding(nn.Module): + """Sinusoidal positional encoding (Vaswani et al. 2017).""" + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 8192): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp( + torch.arange(0, d_model, 2).float() + * (-math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe.unsqueeze(0)) # (1, max_len, d_model) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Add positional encoding and apply dropout. + + Args: + x: Tensor of shape ``(batch, seq_len, d_model)``. + + Returns: + Tensor of the same shape with positional information added. + """ + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class _InteractionLayer(nn.Module): + """Cross-view multi-head attention interaction layer.""" + + def __init__(self, hidden_size: int, num_heads: int): + super().__init__() + self.multihead_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True + ) + self.norm = nn.LayerNorm(hidden_size) + + def forward( + self, ht: torch.Tensor, hd: torch.Tensor, hf: torch.Tensor + ) -> tuple: + """Apply cross-view attention. + + Args: + ht: Temporal hidden states ``(B, L, D)``. + hd: Derivative hidden states ``(B, L, D)``. + hf: Frequency hidden states ``(B, L, D)``. + + Returns: + Tuple of three tensors, each ``(B, L, D)``. + """ + B, L, D = ht.size() + h = torch.stack([ht, hd, hf], dim=2) # (B, L, 3, D) + h = h.permute(0, 2, 1, 3).contiguous().view(B * 3, L, D) + attn_output, _ = self.multihead_attn(h, h, h) + output = self.norm(h + attn_output) + output = output.view(B, 3, L, D).permute(0, 2, 1, 3) + return output[:, :, 0, :], output[:, :, 1, :], output[:, :, 2, :] + + +class _SelfAttentionFusion(nn.Module): + """Learned self-attention fusion with residual connection.""" + + def __init__(self, hidden_dim: int): + super().__init__() + self.query = nn.Linear(hidden_dim, hidden_dim) + self.key = nn.Linear(hidden_dim, hidden_dim) + self.value = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x: torch.Tensor) -> tuple: + """Apply self-attention over stacked view embeddings. + + Args: + x: Tensor of shape ``(B, n_views, D)``. + + Returns: + Tuple of (attended tensor ``(B, n_views, D)``, attention weights). + """ + q = self.query(x) + k = self.key(x) + v = self.value(x) + scores = torch.matmul(q, k.transpose(-2, -1)) / (x.size(-1) ** 0.5) + attn_w = F.softmax(scores, dim=-1) + return torch.matmul(attn_w, v), attn_w + + +# --------------------------------------------------------------------------- +# Per-view backbone factories +# --------------------------------------------------------------------------- + + +def _make_transformer_encoder( + num_feature: int, + num_embedding: int, + num_hidden: int, + num_head: int, + num_layers: int, + dropout: float, +) -> nn.Module: + """Build a single-view Transformer encoder branch.""" + input_proj = nn.Linear(num_feature, num_embedding) + pos_enc = _PositionalEncoding(num_embedding, dropout) + enc_layer = nn.TransformerEncoderLayer( + d_model=num_embedding, + dim_feedforward=num_hidden, + nhead=num_head, + dropout=dropout, + batch_first=True, + ) + transformer = nn.TransformerEncoder(enc_layer, num_layers) + return nn.ModuleDict( + {"input_proj": input_proj, "pos_enc": pos_enc, "encoder": transformer} + ) + + +def _make_cnn_encoder( + num_feature: int, + num_embedding: int, + num_hidden: int, + num_layers: int, + dropout: float, +) -> nn.Module: + """Build a single-view 1D-CNN encoder branch.""" + layers = [] + in_ch = num_feature + for i in range(num_layers): + out_ch = num_hidden if i < num_layers - 1 else num_embedding + layers.extend( + [ + nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm1d(out_ch), + nn.ReLU(), + nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True), + nn.Dropout(dropout), + ] + ) + in_ch = out_ch + # Adaptive pool to guarantee fixed-length output regardless of input L + layers.append(nn.AdaptiveAvgPool1d(1)) + return nn.Sequential(*layers) + + +def _make_gru_encoder( + num_feature: int, + num_embedding: int, + num_layers: int, + dropout: float, +) -> nn.Module: + """Build a single-view GRU encoder branch.""" + return nn.GRU( + input_size=num_feature, + hidden_size=num_embedding, + num_layers=num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0.0, + ) + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + +# Valid view and encoder/fusion type literals +VIEW_TYPES = {"T", "D", "F", "TD", "TF", "DF", "ALL"} +ENCODER_TYPES = {"transformer", "cnn", "gru"} +FUSION_TYPES = {"attention", "concat", "mean"} + + +class MultiViewContrastive(BaseModel): + """Multi-View Contrastive model for medical time-series classification. + + Constructs three views (temporal, derivative, frequency) of a raw + signal, encodes each with an independent backbone, fuses them, and + classifies. + + Paper: Oh & Bui (2025), *Multi-View Contrastive Learning for Robust + Domain Adaptation in Medical Time Series Analysis*, CHIL 2025. + + Args: + dataset: A ``SampleDataset`` produced by ``dataset.set_task()``. + num_embedding: Embedding / hidden dimension for each view encoder. + Default ``64``. + num_hidden: Feed-forward hidden dimension (Transformer) or + intermediate channels (CNN). Default ``128``. + num_head: Number of attention heads. Default ``4``. + num_layers: Number of encoder layers per view. Default ``3``. + dropout: Dropout rate. Default ``0.2``. + encoder_type: Backbone architecture per view. One of + ``"transformer"`` (default), ``"cnn"``, ``"gru"``. + view_type: Which views to use. One of ``"T"``, ``"D"``, ``"F"``, + ``"TD"``, ``"TF"``, ``"DF"``, ``"ALL"`` (default). + fusion_type: How to aggregate view embeddings. One of + ``"attention"`` (default), ``"concat"``, ``"mean"``. + + Examples: + >>> import numpy as np + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> from pyhealth.models import MultiViewContrastive + >>> samples = [ + ... {"patient_id": "p0", "visit_id": "v0", + ... "signal": np.random.randn(1, 178).astype(np.float32), + ... "label": 0}, + ... {"patient_id": "p1", "visit_id": "v0", + ... "signal": np.random.randn(1, 178).astype(np.float32), + ... "label": 1}, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"signal": "tensor"}, + ... output_schema={"label": "multiclass"}, + ... dataset_name="test", + ... ) + >>> model = MultiViewContrastive(dataset=dataset) + >>> loader = get_dataloader(dataset, batch_size=2, shuffle=False) + >>> batch = next(iter(loader)) + >>> ret = model(**batch) + >>> ret.keys() + dict_keys(['loss', 'y_prob', 'y_true', 'logit']) + """ + + def __init__( + self, + dataset: SampleDataset, + num_embedding: int = 64, + num_hidden: int = 128, + num_head: int = 4, + num_layers: int = 3, + dropout: float = 0.2, + encoder_type: str = "transformer", + view_type: str = "ALL", + fusion_type: str = "attention", + ): + super().__init__(dataset=dataset) + + # Validate arguments + assert encoder_type in ENCODER_TYPES, ( + f"encoder_type must be one of {ENCODER_TYPES}, got '{encoder_type}'" + ) + assert view_type in VIEW_TYPES, ( + f"view_type must be one of {VIEW_TYPES}, got '{view_type}'" + ) + assert fusion_type in FUSION_TYPES, ( + f"fusion_type must be one of {FUSION_TYPES}, got '{fusion_type}'" + ) + assert len(self.label_keys) == 1, "Only one label key is supported." + self.feature_keys = [ + k for k in self.feature_keys if k != "stft" + ] + assert len(self.feature_keys) == 1, "Only one feature key is supported." + + self.num_embedding = num_embedding + self.num_hidden = num_hidden + self.num_head = num_head + self.num_layers = num_layers + self.dropout = dropout + self.encoder_type = encoder_type + self.view_type = view_type + self.fusion_type = fusion_type + + # Determine input shape + num_feature = self._infer_num_features() + self.num_feature = num_feature + + # Determine active views + self._active_views = self._parse_view_type(view_type) + n_views = len(self._active_views) + + # Build per-view encoder branches + self.encoders = nn.ModuleDict() + for v in self._active_views: + self.encoders[v] = self._build_encoder(num_feature) + + # Interaction layer (only when all 3 views are active) + self.interaction_layer = None + if n_views == 3: + self.interaction_layer = _InteractionLayer(num_embedding, num_head) + + # Per-view output projection + proj_input = num_embedding * 2 if n_views == 3 else num_embedding + self.output_projs = nn.ModuleDict() + for v in self._active_views: + self.output_projs[v] = nn.Sequential( + nn.Linear(proj_input, num_hidden), + nn.LayerNorm(num_hidden), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(num_hidden, num_hidden), + ) + + # Fusion + classifier + if fusion_type == "attention": + self.self_attention = _SelfAttentionFusion(num_hidden) + classifier_input = n_views * num_hidden + elif fusion_type == "concat": + self.self_attention = None + classifier_input = n_views * num_hidden + else: # mean + self.self_attention = None + classifier_input = num_hidden + + output_size = self.get_output_size() + self.fc = nn.Linear(classifier_input, output_size) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _infer_num_features(self) -> int: + """Infer the number of input channels from the dataset.""" + for sample in self.dataset: + sig = sample.get(self.feature_keys[0]) + if sig is None: + continue + if isinstance(sig, np.ndarray): + sig = torch.from_numpy(sig) + if sig.dim() == 1: + return 1 + elif sig.dim() == 2: + return sig.shape[0] # (C, T) + else: + raise ValueError(f"Unexpected signal shape: {sig.shape}") + raise ValueError("Could not infer num_features from dataset.") + + @staticmethod + def _parse_view_type(view_type: str) -> list: + """Return list of active view keys.""" + mapping = { + "T": ["t"], + "D": ["d"], + "F": ["f"], + "TD": ["t", "d"], + "TF": ["t", "f"], + "DF": ["d", "f"], + "ALL": ["t", "d", "f"], + } + return mapping[view_type] + + def _build_encoder(self, num_feature: int) -> nn.Module: + """Build one encoder branch based on encoder_type.""" + if self.encoder_type == "transformer": + return _make_transformer_encoder( + num_feature, + self.num_embedding, + self.num_hidden, + self.num_head, + self.num_layers, + self.dropout, + ) + elif self.encoder_type == "cnn": + return _make_cnn_encoder( + num_feature, + self.num_embedding, + self.num_hidden, + self.num_layers, + self.dropout, + ) + else: + return _make_gru_encoder( + num_feature, + self.num_embedding, + self.num_layers, + self.dropout, + ) + + # ------------------------------------------------------------------ + # View computation + # ------------------------------------------------------------------ + + @staticmethod + def compute_views( + x: torch.Tensor, + ) -> tuple: + """Compute temporal, derivative, and frequency views. + + Args: + x: Raw signal tensor of shape ``(B, C, T)`` (channels first). + + Returns: + Tuple of ``(x_temporal, x_derivative, x_frequency)``, each of + shape ``(B, T, C)`` (sequence-first for encoders). + """ + # (B, C, T) -> (B, T, C) + xt = x.permute(0, 2, 1).contiguous() + + # Derivative: finite difference, padded to preserve length + dx = torch.diff(xt, dim=1) + dx = torch.cat([dx, dx[:, -1:, :]], dim=1) + + # Frequency: FFT magnitude + xf = torch.abs(fft.fft(xt, dim=1)) + + return xt, dx, xf + + # ------------------------------------------------------------------ + # Encoding + # ------------------------------------------------------------------ + + def _encode_view(self, enc: nn.Module, x_view: torch.Tensor) -> torch.Tensor: + """Run one view through its encoder and return hidden states. + + Args: + enc: The encoder module for this view. + x_view: ``(B, T, C)`` for transformer/gru or ``(B, C, T)`` for cnn. + + Returns: + Hidden states of shape ``(B, T, D)`` for transformer, + ``(B, 1, D)`` for cnn, or ``(B, T, D)`` for gru. + """ + if self.encoder_type == "transformer": + h = enc["input_proj"](x_view) + h = enc["pos_enc"](h) + h = enc["encoder"](h) + return h + elif self.encoder_type == "cnn": + # (B, T, C) -> (B, C, T) for Conv1d + h = x_view.permute(0, 2, 1).contiguous() + h = enc(h) # (B, D, 1) after adaptive pool + h = h.permute(0, 2, 1) # (B, 1, D) + return h + else: # gru + h, _ = enc(x_view) # (B, T, D) + return h + + def _encode_all_views( + self, xt: torch.Tensor, dx: torch.Tensor, xf: torch.Tensor + ) -> Dict[str, torch.Tensor]: + """Encode all active views and return dict of hidden states.""" + view_inputs = {"t": xt, "d": dx, "f": xf} + hiddens = {} + for v in self._active_views: + h = self._encode_view(self.encoders[v], view_inputs[v]) + hiddens[v] = h + return hiddens + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation. + + Args: + **kwargs: Must contain the signal feature key and the label key. + + Returns: + Dictionary with keys ``loss``, ``y_prob``, ``y_true``, ``logit``. + """ + x = kwargs[self.feature_keys[0]].to(self.device).float() + + # Ensure (B, C, T) shape + if x.dim() == 2: + x = x.unsqueeze(1) + + # Clean NaN/Inf + x = torch.nan_to_num(x) + + # Compute views + xt, dx, xf = self.compute_views(x) + + # Encode + hiddens = self._encode_all_views(xt, dx, xf) + + # Interaction (only for ALL views) + if self.interaction_layer is not None and len(self._active_views) == 3: + ht_i, hd_i, hf_i = self.interaction_layer( + hiddens["t"], hiddens["d"], hiddens["f"] + ) + interaction = {"t": ht_i, "d": hd_i, "f": hf_i} + else: + interaction = None + + # Project each view to latent embedding + embeddings = [] + for v in self._active_views: + h_mean = hiddens[v].mean(dim=1) # (B, D) + if interaction is not None: + h_i_mean = interaction[v].mean(dim=1) + proj_input = torch.cat([h_mean, h_i_mean], dim=-1) + else: + proj_input = h_mean + z = self.output_projs[v](proj_input) + embeddings.append(z) + + # Fuse + if self.fusion_type == "attention": + stacked = torch.stack(embeddings, dim=1) # (B, n_views, D) + attn_out, _ = self.self_attention(stacked) + fused = (attn_out + stacked).reshape(stacked.shape[0], -1) + elif self.fusion_type == "concat": + fused = torch.cat(embeddings, dim=-1) + else: # mean + fused = torch.stack(embeddings, dim=0).mean(dim=0) + + # Classify + logits = self.fc(fused) + + y_true = kwargs[self.label_keys[0]].to(self.device) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + + results = { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } + if kwargs.get("embed", False): + results["embed"] = fused + return results + + # ------------------------------------------------------------------ + # Contrastive pre-training helper + # ------------------------------------------------------------------ + + def encode_views( + self, x: torch.Tensor + ) -> Dict[str, torch.Tensor]: + """Encode raw signal into per-view latent embeddings. + + This is a convenience method for contrastive pre-training, where + you need the per-view embeddings (not the fused logits). + + Args: + x: Raw signal ``(B, C, T)``. + + Returns: + Dict mapping active view keys to latent embeddings ``(B, D)``. + """ + x = torch.nan_to_num(x.float()) + if x.dim() == 2: + x = x.unsqueeze(1) + + xt, dx, xf = self.compute_views(x) + hiddens = self._encode_all_views(xt, dx, xf) + + if self.interaction_layer is not None and len(self._active_views) == 3: + ht_i, hd_i, hf_i = self.interaction_layer( + hiddens["t"], hiddens["d"], hiddens["f"] + ) + interaction = {"t": ht_i, "d": hd_i, "f": hf_i} + else: + interaction = None + + latents = {} + for v in self._active_views: + h_mean = hiddens[v].mean(dim=1) + if interaction is not None: + h_i_mean = interaction[v].mean(dim=1) + proj_input = torch.cat([h_mean, h_i_mean], dim=-1) + else: + proj_input = h_mean + latents[v] = self.output_projs[v](proj_input) + return latents diff --git a/tests/core/test_multiview_contrastive.py b/tests/core/test_multiview_contrastive.py new file mode 100644 index 000000000..aa7dd4e41 --- /dev/null +++ b/tests/core/test_multiview_contrastive.py @@ -0,0 +1,273 @@ +"""Tests for MultiViewContrastive model. + +Uses small synthetic tensors (2-5 samples) so all tests complete in under +one second. No real datasets are downloaded or used. +""" + +import unittest + +import numpy as np +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MultiViewContrastive + + +def _make_dataset(n_samples: int = 4, n_channels: int = 1, length: int = 178): + """Create a tiny synthetic dataset for testing.""" + rng = np.random.RandomState(42) + samples = [] + for i in range(n_samples): + samples.append( + { + "patient_id": f"p{i}", + "visit_id": "v0", + "signal": rng.randn(n_channels, length).astype(np.float32), + "label": i % 3, + } + ) + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="test_mvc", + ) + + +class TestMultiViewContrastiveInit(unittest.TestCase): + """Test model initialization with various configurations.""" + + def setUp(self): + self.dataset = _make_dataset() + + def test_default_init(self): + model = MultiViewContrastive(dataset=self.dataset) + self.assertEqual(model.encoder_type, "transformer") + self.assertEqual(model.view_type, "ALL") + self.assertEqual(model.fusion_type, "attention") + self.assertEqual(model.num_embedding, 64) + + def test_cnn_init(self): + model = MultiViewContrastive( + dataset=self.dataset, encoder_type="cnn" + ) + self.assertEqual(model.encoder_type, "cnn") + + def test_gru_init(self): + model = MultiViewContrastive( + dataset=self.dataset, encoder_type="gru" + ) + self.assertEqual(model.encoder_type, "gru") + + def test_invalid_encoder_type(self): + with self.assertRaises(AssertionError): + MultiViewContrastive( + dataset=self.dataset, encoder_type="lstm" + ) + + def test_invalid_view_type(self): + with self.assertRaises(AssertionError): + MultiViewContrastive( + dataset=self.dataset, view_type="XYZ" + ) + + +class TestMultiViewContrastiveForward(unittest.TestCase): + """Test forward pass for all encoder x view x fusion combos.""" + + def setUp(self): + self.dataset = _make_dataset() + + def _run_forward(self, encoder_type, view_type, fusion_type): + model = MultiViewContrastive( + dataset=self.dataset, + encoder_type=encoder_type, + view_type=view_type, + fusion_type=fusion_type, + num_embedding=16, + num_hidden=32, + num_head=2, + num_layers=1, + dropout=0.0, + ) + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + with torch.no_grad(): + ret = model(**batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + self.assertEqual(ret["y_prob"].shape[0], 4) + self.assertEqual(ret["loss"].dim(), 0) + return ret + + def test_transformer_all_attention(self): + self._run_forward("transformer", "ALL", "attention") + + def test_transformer_all_concat(self): + self._run_forward("transformer", "ALL", "concat") + + def test_transformer_all_mean(self): + self._run_forward("transformer", "ALL", "mean") + + def test_cnn_all_attention(self): + self._run_forward("cnn", "ALL", "attention") + + def test_gru_all_attention(self): + self._run_forward("gru", "ALL", "attention") + + def test_transformer_single_view_T(self): + self._run_forward("transformer", "T", "concat") + + def test_transformer_single_view_D(self): + self._run_forward("transformer", "D", "mean") + + def test_transformer_single_view_F(self): + self._run_forward("transformer", "F", "concat") + + def test_transformer_dual_view_TD(self): + self._run_forward("transformer", "TD", "attention") + + def test_transformer_dual_view_TF(self): + self._run_forward("transformer", "TF", "concat") + + def test_transformer_dual_view_DF(self): + self._run_forward("transformer", "DF", "mean") + + def test_cnn_single_view(self): + self._run_forward("cnn", "T", "concat") + + def test_gru_dual_view(self): + self._run_forward("gru", "TD", "attention") + + +class TestMultiViewContrastiveBackward(unittest.TestCase): + """Test gradient computation.""" + + def setUp(self): + self.dataset = _make_dataset() + + def test_gradient_flow_transformer(self): + model = MultiViewContrastive( + dataset=self.dataset, + encoder_type="transformer", + num_embedding=16, + num_hidden=32, + num_head=2, + num_layers=1, + ) + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + + ret = model(**batch) + ret["loss"].backward() + + has_grad = any( + p.requires_grad and p.grad is not None + for p in model.parameters() + ) + self.assertTrue(has_grad) + + def test_gradient_flow_cnn(self): + model = MultiViewContrastive( + dataset=self.dataset, + encoder_type="cnn", + num_embedding=16, + num_hidden=32, + num_layers=2, + ) + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + + ret = model(**batch) + ret["loss"].backward() + + has_grad = any( + p.requires_grad and p.grad is not None + for p in model.parameters() + ) + self.assertTrue(has_grad) + + def test_gradient_flow_gru(self): + model = MultiViewContrastive( + dataset=self.dataset, + encoder_type="gru", + num_embedding=16, + num_hidden=32, + num_layers=2, + ) + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + + ret = model(**batch) + ret["loss"].backward() + + has_grad = any( + p.requires_grad and p.grad is not None + for p in model.parameters() + ) + self.assertTrue(has_grad) + + +class TestMultiViewContrastiveEmbed(unittest.TestCase): + """Test embedding output and encode_views helper.""" + + def setUp(self): + self.dataset = _make_dataset() + + def test_embed_output(self): + model = MultiViewContrastive( + dataset=self.dataset, + num_embedding=16, + num_hidden=32, + num_head=2, + num_layers=1, + ) + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + batch["embed"] = True + with torch.no_grad(): + ret = model(**batch) + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 4) + self.assertEqual(ret["embed"].dim(), 2) + + def test_encode_views(self): + model = MultiViewContrastive( + dataset=self.dataset, + num_embedding=16, + num_hidden=32, + num_head=2, + num_layers=1, + ) + x = torch.randn(2, 1, 178) + with torch.no_grad(): + latents = model.encode_views(x) + self.assertEqual(len(latents), 3) + for v in ["t", "d", "f"]: + self.assertIn(v, latents) + self.assertEqual(latents[v].shape, (2, 32)) + + +class TestComputeViews(unittest.TestCase): + """Test static view computation.""" + + def test_shapes(self): + x = torch.randn(4, 1, 178) + xt, dx, xf = MultiViewContrastive.compute_views(x) + self.assertEqual(xt.shape, (4, 178, 1)) + self.assertEqual(dx.shape, (4, 178, 1)) + self.assertEqual(xf.shape, (4, 178, 1)) + + def test_multichannel(self): + x = torch.randn(2, 3, 206) + xt, dx, xf = MultiViewContrastive.compute_views(x) + self.assertEqual(xt.shape, (2, 206, 3)) + self.assertEqual(dx.shape, (2, 206, 3)) + self.assertEqual(xf.shape, (2, 206, 3)) + + +if __name__ == "__main__": + unittest.main()