From 1037758875ac44c21203a867230766584d372359 Mon Sep 17 00:00:00 2001 From: jordynhayden Date: Wed, 25 Mar 2026 18:44:42 -0400 Subject: [PATCH] Add Random Forest Model and LOS Task Contribution for CS598 Final Project --- docs/api/models.rst | 1 + .../models/pyhealth.models.RandomForest.rst | 10 + ...health.tasks.length_of_stay_prediction.rst | 5 + .../mimic3_length_of_stay_random_forest.py | 173 +++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/random_forest.py | 479 ++++++++++++++++++ pyhealth/models/utils.py | 179 ++++++- pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/length_of_stay_prediction.py | 179 ++++++- tests/core/test_data_loader_to_numpy_util.py | 217 ++++++++ tests/core/test_mimic3_threshold_los.py | 208 ++++++++ tests/core/test_random_forest.py | 360 +++++++++++++ 12 files changed, 1809 insertions(+), 4 deletions(-) create mode 100644 docs/api/models/pyhealth.models.RandomForest.rst create mode 100644 examples/length_of_stay/mimic3_length_of_stay_random_forest.py create mode 100644 pyhealth/models/random_forest.py create mode 100644 tests/core/test_data_loader_to_numpy_util.py create mode 100644 tests/core/test_mimic3_threshold_los.py create mode 100644 tests/core/test_random_forest.py diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..d2aaa402c 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -178,6 +178,7 @@ API Reference models/pyhealth.models.Transformer models/pyhealth.models.TransformersModel models/pyhealth.models.TransformerDeID + models/pyhealth.models.RandomForest models/pyhealth.models.RETAIN models/pyhealth.models.GAMENet models/pyhealth.models.GraphCare diff --git a/docs/api/models/pyhealth.models.RandomForest.rst b/docs/api/models/pyhealth.models.RandomForest.rst new file mode 100644 index 000000000..8d6c9ca58 --- /dev/null +++ b/docs/api/models/pyhealth.models.RandomForest.rst @@ -0,0 +1,10 @@ +pyhealth.models.RandomForest +=================================== + +Wraps sklearn's 
RandomForestClassifier for classification tasks and sklearn's RandomForestRegressor for regression tasks +for use with PyHealth pipelines. + +.. autoclass:: pyhealth.models.RandomForest + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.length_of_stay_prediction.rst b/docs/api/tasks/pyhealth.tasks.length_of_stay_prediction.rst index 727a61767..af69f2799 100644 --- a/docs/api/tasks/pyhealth.tasks.length_of_stay_prediction.rst +++ b/docs/api/tasks/pyhealth.tasks.length_of_stay_prediction.rst @@ -9,6 +9,11 @@ Task Classes :undoc-members: :show-inheritance: +.. autoclass:: pyhealth.tasks.length_of_stay_prediction.LengthOfStayThresholdPredictionMIMIC3 + :members: + :undoc-members: + :show-inheritance: + .. autoclass:: pyhealth.tasks.length_of_stay_prediction.LengthOfStayPredictionMIMIC4 :members: :undoc-members: diff --git a/examples/length_of_stay/mimic3_length_of_stay_random_forest.py b/examples/length_of_stay/mimic3_length_of_stay_random_forest.py new file mode 100644 index 000000000..00dad22f1 --- /dev/null +++ b/examples/length_of_stay/mimic3_length_of_stay_random_forest.py @@ -0,0 +1,173 @@ +""" +This study uses a RandomForest model on the MIMIC-III dataset to predict whether length +of stay exceeds 3 days, with hyperparameter tuning to maximize AUROC. + +Setup: +- Dataset: MIMIC-III +- Task: Binary classification (Length of Stay > 3 days) +- Patient Data Split: 70% train / 10% validation / 20% test +- Evaluation metric: AUROC + +Hyperparameter Tuning: +Grid search over the following hyperparameters: +- n_estimators: [100, 200, 300] +- max_depth: [5, 7, 10] +- min_samples_leaf: [1, 2] +- min_samples_split: [2, 3] +- class_weight: ["balanced", None] +- bootstrap: [True, False] + +Each hyperparameter configuration was trained on the training set and evaluated on the +validation set. Results were ranked by AUROC. 
+ +Hyperparameter Tuning Findings +- The found best-performing random forest classifier model configuration achieved an +AUROC of ~0.77 using: bootstrap = True, class_weight = None, max_depth = 5, +min_samples_leaf = 1, min_samples_split = 2, and n_estimators = 200. Average AUROC +over all tuned parameters was ~0.70, Min: 0.52 +- Shallow trees improve performance. Likely due to limited number of patients in dataset +- Increasing n_estimators (trees) improved performance +- Using class_weight="balanced" reduced AUROC + +Final Model +The best hyperparameter configuration was used to train a final model, which was then +evaluated on the test set. + +This experiment also serves as am example for how to use the PyHealth +RandomForest model and the Length of Stay Threshold binary prediction task +demonstrating: +1. Loading MIMIC-III data +2. Setting the Length of Stay Greater Than X Days Prediction task +3. Splitting the dataset and getting the dataloaders +4. Tuning hyperparameters +5. Creating and Fitting a RandomForest model +6. Evaluating a Random Forest model +""" +import tempfile + +from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient +from pyhealth.models import RandomForest +from pyhealth.tasks.length_of_stay_prediction import \ + LengthOfStayThresholdPredictionMIMIC3 + + +def print_results_table(performance_results): + """ Helper to print the results of performance comparison across different + configurations. 
+ """ + if not performance_results: + print("No results to display.") + return + + print("\n" + "*" * 25) + print("Hyperparameter Tuning Results...") + print("*" * 25 + "\n") + + # Alphabetize + param_keys = sorted(performance_results[0]["params"].keys()) + + # Create a table header of the hyperparams and a column for the resulting auroc + # score + header = param_keys + ["auroc score"] + print(" | ".join(f"{h:^15}" for h in header)) + print("-" * (18 * len(header))) + + # Print each row + for p in performance_results: + row = [p["params"].get(k) for k in param_keys] + [p["score"]] + print(" | ".join(f"{str(v):^15}" for v in row)) + + +if __name__ == "__main__": + + # Constants + BATCH_SIZE = 32 + + # STEP 1: Load Dataset + print("\n" + "*" * 25) + print("Loading MIMIC3 Dataset...") + print("*" * 25 + "\n") + base_dataset = MIMIC3Dataset( + root = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III", + tables = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], + cache_dir = tempfile.TemporaryDirectory().name, + dev = True, + ) + base_dataset.stats() + + # STEP 2: Set Task + # Define the Length of Stay > 3 prediction classification task + print("\n" + "*" * 25) + print("Setting a LOS Threshold Binary Prediction Task...") + print("*" * 25 + "\n") + task = LengthOfStayThresholdPredictionMIMIC3(3, exclude_minors = False) + sample_dataset = base_dataset.set_task(task) + + # STEP 3: Split Datasets + print("\n" + "*" * 25) + print("Creating Dataset Splits...") + print("*" * 25 + "\n") + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.7, 0.1, 0.2] + ) + + train_loader = get_dataloader(train_dataset, batch_size = BATCH_SIZE, + shuffle = True) + val_loader = get_dataloader(val_dataset, batch_size = BATCH_SIZE, shuffle = False) + test_loader = get_dataloader(test_dataset, batch_size = BATCH_SIZE, shuffle = False) + + # STEP 4: Define hyperparameters to tune and conduct hyperparameter tuning loop + # Here we define a 
hyperparameter dictionary of values to try out in order to + # determine what model configuration yields the best metrics over the validation + # dataset + tuning_params = { + "n_estimators": [100, 200, 300], + "max_depth": [5, 7, 10], + "min_samples_leaf": [1, 2], + "min_samples_split": [2, 3], + "class_weight": ["balanced", None], + "bootstrap": [True, False], + } + + print("\n" + "*" * 20) + print("Tuning Parameters...") + print("*" * 20 + "\n") + + best_params, best_score, results = RandomForest.tune( + sample_dataset, + train_loader, + val_loader, + tuning_params, + return_all = True + ) + + # Sort results so best auroc is on top + results = sorted( + results, + key = lambda x: x["score"] if x["score"] is not None else -float("inf"), + reverse = True, + ) + print_results_table(results) + + print("\nBest Hyperparameter Combination:", best_params, best_score) + + # STEP 5: Create Final Random Forest Model + final_model = RandomForest( + dataset = sample_dataset, + **best_params if best_score is not None else {}, + ) + final_model.fit(train_loader) + + # STEP 6: Final Evaluation on Test Dataset + print("*" * 41) + print("Final Evaluation Using Best Parameters...") + print("*" * 41 + "\n") + + test_metrics = final_model.evaluate(test_loader) + + print("\n" + "*" * 58) + print("Final Metrics Using Best Parameters on the Test Dataset...") + print("*" * 58 + "\n") + + for k, v in test_metrics.items(): + print(f"{k}: {v:.4f}" if v is not None else f"{k}: None") diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..e8a797ad3 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -21,6 +21,7 @@ from .mlp import MLP from .molerec import MoleRec, MoleRecLayer from .retain import MultimodalRETAIN, RETAIN, RETAINLayer +from .random_forest import RandomForest from .rnn import MultimodalRNN, RNN, RNNLayer from .safedrug import SafeDrug, SafeDrugLayer from .sparcnet import DenseBlock, DenseLayer, SparcNet, 
TransitionLayer diff --git a/pyhealth/models/random_forest.py b/pyhealth/models/random_forest.py new file mode 100644 index 000000000..be1e7a56c --- /dev/null +++ b/pyhealth/models/random_forest.py @@ -0,0 +1,479 @@ +""" +Provides an implementation of the Random Forest model that is compatible with PyHealth +pipelines. +""" +import logging +from itertools import product +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from torch import nn +from torch.utils.data import DataLoader + +from pyhealth.datasets import SampleEHRDataset +from pyhealth.metrics import binary_metrics_fn, regression_metrics_fn +from pyhealth.models import BaseModel +from pyhealth.models.utils import DataLoaderToNumpy + +logger = logging.getLogger(__name__) + + +class RandomForest(BaseModel): + """Random Forest model. + + Wraps sklearn's RandomForestClassifier for classification tasks and + sklearn's RandomForestRegressor for regression tasks. + + Args: + dataset: dataset to train the model. + pad_batches: whether to pad batches such that their dimensions are equal. + Default is True. If False, a ValueError will be raised if padding is + needed and dimensions are not equal. + f1_average: averaging strategy for F1 in evaluate(). Can be "macro", + "weighted", or "binary". Default is "macro". + n_estimators: number of trees in the forest. Default is 100. + criterion: function to measure the quality of a split. For classification + tasks this can be either: "gini", "entropy", "log_loss". For regression, + this can be either: "friedman_mse", "squared_error", "absolute_error", + "poisson". If None, criterion is set to "gini" for classification tasks + and "friedman_mse" for regression tasks. + max_depth: maximum depth of the tree. If None, nodes are expanded until all + leaves are pure or contain fewer than min_samples_split samples. Default + is None. 
+        min_samples_split: minimum number of samples required to split an internal node.
+            Default is 2.
+        min_samples_leaf: minimum number of samples required to be at a leaf node.
+            Default is 1.
+        min_weight_fraction_leaf: minimum weighted fraction of the sum total of weights
+            required to be at a leaf node. Default is 0.0.
+        max_features: number of features to consider when looking for the best split.
+            Can be {"sqrt", "log2", None}, int or float. Default is "sqrt".
+        max_leaf_nodes: grow trees with max_leaf_nodes in best-first fashion.
+            If None, then unlimited number of leaf nodes. Default is None.
+        min_impurity_decrease: A node will be split if this split induces a decrease of
+            the impurity greater than or equal to this value. Default is 0.0.
+        bootstrap: whether bootstrap samples are used when building trees.
+            Default is True.
+        oob_score: whether to use out-of-bag samples to estimate the generalization
+            score. Default is False.
+        n_jobs: number of jobs to run in parallel. None means 1 unless in a joblib
+            parallel backend context. -1 uses all processors. Default is None.
+        random_state: controls randomness. Default is 42.
+        verbose: controls verbosity when fitting and predicting. Default is 0.
+        warm_start: reuse the solution of the previous call to fit and add more
+            estimators to the ensemble. Default is False.
+        class_weight: weights associated with classes. Can be {"balanced",
+            "balanced_subsample"}, dict or list of dicts.
+            Default is None.
+        ccp_alpha: complexity parameter used for Minimal Cost-Complexity Pruning.
+            Default is 0.0.
+        max_samples: If bootstrap is True, the number of samples to draw from X to
+            train each base estimator. Default is None.
+        monotonic_cst: Indicates the monotonicity constraint to enforce on each
+            feature. 1: monotonic increase, 0: no constraint, -1: monotonic decrease.
+            If monotonic_cst is None, no constraints are applied. Default is None.
+ **kwargs: Additional arguments passed to pass to a deeper layer + + Raises: + ValueError, TypeError: If invalid model parameters are found. + """ + # Criterion Constants + GINI = "gini" + FRIEDMAN_MSE = "friedman_mse" + CLASSIFICATION_CRITERIA = {"gini", "entropy", "log_loss"} + REGRESSION_CRITERIA = {"squared_error", "friedman_mse", "poisson", + "absolute_error"} + + # F1 Avg Constants + MACRO = "macro" + WEIGHTED = "weighted" + BINARY = "binary" + VALID_F1_AVG = {MACRO, WEIGHTED, BINARY} + + # Key Constants + KEY_Y_TRUE = "y_true" + KEY_Y_PROB = "y_prob" + KEY_LOGIT = "logit" + KEY_LOSS = "loss" + METRIC_ACCURACY = "accuracy" + METRIC_F1 = "f1" + METRIC_AUROC = "roc_auc" + METRIC_MSE = "mse" + METRIC_MAE = "mae" + + def __init__( + self, + dataset: SampleEHRDataset, + pad_batches: bool = True, + f1_average: str = "macro", + n_estimators: int = 100, + criterion: str = None, + max_depth: Optional[int] = None, + min_samples_split: int = 2, + min_samples_leaf: int = 1, + min_weight_fraction_leaf: float = 0.0, + max_features: Union[str, int, float, None] = "sqrt", + max_leaf_nodes: Optional[int] = None, + min_impurity_decrease: float = 0.0, + bootstrap: bool = True, + oob_score: bool = False, + n_jobs: Optional[int] = None, + random_state: Optional[int] = 42, + verbose: int = 0, + warm_start: bool = False, + class_weight: Optional[str] = None, + ccp_alpha: float = 0.0, + max_samples: Optional[int] = None, + **kwargs, + ): + # Call base constructor + super(RandomForest, self).__init__(dataset = dataset) + + # Save off inputs + self.pad_batches = pad_batches + self._is_fitted = False + self.feature_dim = None + self.f1_average = f1_average + + # Validate label keys + assert len(self.label_keys) == 1, "Only one label key is supported" + self.label_key = self.label_keys[0] + + # Verify f1_average is valid + if not isinstance(self.f1_average, str): + raise TypeError("Input f1_average must be a string.") + if self.f1_average.lower() not in self.VALID_F1_AVG: + raise 
ValueError("Input f1_average unsupported.") + + # If no criterion is given, set the criterion based on the task whether it is + # binary or regression + if criterion is None: + if self.mode != "regression": + criterion = self.GINI + else: + criterion = self.FRIEDMAN_MSE + + # Validate criterion + if self.mode != "regression": + valid_criterion = RandomForest.CLASSIFICATION_CRITERIA + else: + valid_criterion = RandomForest.REGRESSION_CRITERIA + + if criterion not in valid_criterion: + raise ValueError("Input criterion unsupported.") + + # Create a numpy converter since sklearn expects numpy arrays + self.converter = DataLoaderToNumpy( + feature_keys = self.feature_keys, + label_key = self.label_key, + pad_batches = pad_batches, + ) + + # Store parameters into a dictionary that we can pass through whether we + # instantiate a classifier or regressor sklearn model based on the mode + hyperparams = dict( + n_estimators = n_estimators, + criterion = criterion, + max_depth = max_depth, + min_samples_split = min_samples_split, + min_samples_leaf = min_samples_leaf, + min_weight_fraction_leaf = min_weight_fraction_leaf, + max_features = max_features, + max_leaf_nodes = max_leaf_nodes, + min_impurity_decrease = min_impurity_decrease, + bootstrap = bootstrap, + oob_score = oob_score, + n_jobs = n_jobs, + random_state = random_state, + verbose = verbose, + warm_start = warm_start, + ccp_alpha = ccp_alpha, + max_samples = max_samples, + ) + + # Create internal models + try: + if self.mode == "regression": + self.model = RandomForestRegressor(**hyperparams) + else: + # RandomForestClassifier also includes class weight unlike regressor + self.model = RandomForestClassifier( + class_weight = class_weight, + **hyperparams) + except TypeError: + raise TypeError("Invalid model parameters") + + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + """Forward propagation (inference). + + Args: + **kwargs: A variable number of keyword arguments representing input + features. 
Each keyword argument is a tensor or a tuple of tensors of + shape (batch_size, ...). + + Returns: + a dictionary containing keys loss, y_prob, y_true, logit + + Raises: + RuntimeError: if fit() has not been called. + """ + if not self._is_fitted: + raise RuntimeError("Model has not been fitted.") + + x, y_np = self.converter.transform([kwargs]) + + return self._generate_inference_output(self._predict_numpy(x), y_np) + + def fit(self, dataloader: DataLoader): + """Fit the Random Forest model to the dataloader + + Args: + dataloader: PyTorch DataLoader + + Returns: + self + """ + # Get numpy matrices for use with sklearn + x, y = self.converter.transform(dataloader) + + # Fit the model and flag that the model has been fitted + self.model.fit(x, y) + + # Track that the model is fitted + self._is_fitted = True + + return self + + def evaluate(self, dataloader: DataLoader) -> Dict[str, Optional[float]]: + """Evaluates the model and returns calculated metrics. + + Args: + dataloader: PyTorch DataLoader for evaluation (test dataset) + + Returns: + Dictionary of metrics with different keys depending on the mode: + - classification: 'accuracy', 'f1', 'auroc'. auroc may be None. + - regression: 'mse', 'mae' + Raises: + RuntimeError: if model has not been fitted. + ValueError: Error during metric computation. 
+ """ + if not self._is_fitted: + raise RuntimeError("Model has not been fitted.") + + # Get numpy matrices for use with sklearn + x, y_true = self.converter.transform(dataloader) + y_prob_np = self._predict_numpy(x) + + results = {} + try: + if self.mode != "regression": + # use pyhealth metric utils + results = binary_metrics_fn( + y_true = y_true, + y_prob = y_prob_np[:, 1], + metrics = ["roc_auc", "f1", "accuracy"], + ) + else: + # use pyhealth metric utils + results = regression_metrics_fn( + x = y_true.copy(), + x_rec = y_prob_np.view(-1).reshape(-1), + metrics = ["mse", "mae"] + ) + except ValueError: + logger.warning("Failed to compute metrics.") + + return results + + def get_params(self) -> Dict[str, Any]: + """Returns the model parameters. Wraps sklearn's get_params() and updates the + dictionary with additional parameters used by this wrapper class. This is + useful to verify the model has been initialized correctly. + + Returns: + Dictionary of model parameters + """ + params = self.model.get_params() + params.update({ + "pad_batches": self.pad_batches, + "f1_average": self.f1_average + }) + + return params + + def _calculate_loss(self, logit, y_true) -> torch.Tensor: + """Compute loss between predictions and labels. 
+ + Args: + logit: Tensor outputs + y_true: Tensor labels + + Returns: + : loss tensor + """ + + if self.mode != "regression": + loss_fn = nn.CrossEntropyLoss() + loss = loss_fn(logit, y_true.long().view(-1)) + else: + loss_fn = nn.MSELoss() + loss = loss_fn(logit, y_true.float().view(-1, 1)) + + return loss + + def _generate_inference_output( + self, + y_prob_np: np.ndarray, + y_np: np.ndarray, + ) -> Dict[str, torch.Tensor]: + """Generates the inference output in the format expected by PyHealth + + Args: + y_prob_np: numpy array of shape (n_samples, n_classes) + y_np: numpy array of shape (n_samples,) + + Returns: + Dictionary with keys loss, y_prob, y_true, logit + """ + y_prob = torch.from_numpy(y_prob_np) + y_true = torch.from_numpy(y_np) + + return { + self.KEY_LOSS: self._calculate_loss(y_prob, y_true), + self.KEY_Y_PROB: y_prob, + self.KEY_Y_TRUE: y_true, + self.KEY_LOGIT: y_prob, + } + + def _predict_numpy(self, x: np.ndarray) -> np.ndarray: + """Run prediction using the internal sklearn model. + + Args: + x: numpy array of shape (n_samples, n_features). + + Returns: + numpy array of shape (n_samples, n_classes) + """ + if isinstance(self.model, RandomForestClassifier): + prediction_results = self.model.predict_proba(x).astype(np.float32) + else: + prediction_results = self.model.predict(x).astype(np.float32).reshape(-1, 1) + + return prediction_results + + @staticmethod + def tune( + dataset: SampleEHRDataset, + train_loader: DataLoader, + val_loader: DataLoader, + param_grid: Dict[Any, Any], + task: str = "classification", + fixed_params: Optional[Dict[Any, Any]] = None, + metric: str = "roc_auc", + maximize: bool = True, + return_all: bool = False, + ) -> Union[ + tuple[Optional[dict[Any, Any]], Union[float, Any]], + tuple[Optional[dict[Any, Any]], Union[float, Any], list], + ]: + """ + Performs a "grid search" over the given dictionary of model parameters and + returns the best combination of parameters. 
Best meaning, the parameters that
+        either maximized or minimized (according to the maximize parameter) the given
+        metric.
+
+        Args:
+            dataset: PyHealth dataset object
+            train_loader: Training dataloader
+            val_loader: Validation dataloader
+            param_grid: Dict of hyperparameters to search
+            task: task type, either "classification" or "regression". Default is
+                "classification".
+            fixed_params: Params that stay constant
+            metric: Metric to optimize (e.g., 'roc_auc', 'mae');
+                must be a valid metric for the given task
+            maximize: True to maximize the metric, False otherwise
+            return_all: True to get back all parameter combinations and metric results
+                or False to receive only the best performing combination of parameters.
+                Defaults to False.
+
+        Returns:
+            best_params: Dictionary of the best combination of parameters
+            best_score: Best metric
+            results: list of all the parameter combinations and their metric result
+                if return_all is enabled
+        Raises:
+            TypeError: if task is not a string; ValueError: if task or metric is unsupported.
+ """ + # Validate that the task is supported + if not isinstance(task, str): + raise TypeError("Input task is expected to be a string.") + if task.lower() not in ["classification", "regression"]: + raise ValueError("Input task unsupported.") + + # Define what metrics are supported based on the given task + if task == "classification": + valid_metrics = [RandomForest.METRIC_F1, RandomForest.METRIC_ACCURACY, + RandomForest.METRIC_AUROC] + if not isinstance(metric, str) or metric not in valid_metrics: + raise ValueError("Input metric unsupported.") + else: + valid_metrics = [RandomForest.METRIC_MSE, RandomForest.METRIC_MAE] + if not isinstance(metric, str) or metric not in valid_metrics: + raise ValueError("Input metric unsupported.") + fixed_params = fixed_params or {} + + # Determine if we want to see the highest or lowest score form this metric + if maximize: + best_score = -float("inf") + else: + best_score = float("inf") + best_params = None + + keys = list(param_grid.keys()) + values = list(param_grid.values()) + + results = [] + for parameter_combination in product(*values): + params = dict(zip(keys, parameter_combination)) + + # Gather parameters + all_params = {**fixed_params, **params} + + # Create the model with the given params + model = RandomForest( + dataset = dataset, + **all_params + ) + + # Fit the model to the training data + model.fit(train_loader) + + # Calculate metrics + metrics = model.evaluate(val_loader) + + # Get the specific metric of interest + score = metrics[metric] + + results.append({ + "params": all_params, + "score": score + }) + + # Keep track of the best score and params noting that we may either me + # maximizing or minimizing + if (score is not None + and ((maximize and score > best_score) or + (not maximize and score < best_score))): + best_score = score + best_params = all_params + + if return_all: + tuning_results = best_params, best_score, results + else: + tuning_results = best_params, best_score + + return tuning_results 
diff --git a/pyhealth/models/utils.py b/pyhealth/models/utils.py index 67edc010e..29392c1a0 100644 --- a/pyhealth/models/utils.py +++ b/pyhealth/models/utils.py @@ -1,6 +1,11 @@ -from typing import List +""" +Provides utilities for PyHealth models +""" +from typing import Dict, List +import numpy as np import torch +from torch.utils.data import DataLoader def batch_to_multihot(label: List[List[int]], num_labels: int) -> torch.tensor: @@ -44,3 +49,175 @@ def get_last_visit(hidden_states, mask): last_hidden_states = torch.gather(hidden_states, 1, last_visit) last_hidden_state = last_hidden_states[:, 0, :] return last_hidden_state + + +class DataLoaderToNumpy: + """ + Converts a DataLoader to numpy arrays for sklearn models. + + Args: + feature_keys: list of feature keys + label_key: label key + + Examples: + >>> converter = DataLoaderToNumpy( + ... feature_keys=["conditions", "procedures"], + ... label_key="los" + ... ) + >>> X, y = converter.transform(dataloader) + """ + + def __init__(self, + feature_keys: List[str], + label_key: str, + pad_batches: bool = True): + """ + Initializes the dataloader to numpy converter. + + Args: + feature_keys: Keys for input features to include. + label_key: Key for the target label. + pad_batches: Whether to pad features to a consistent size + across batches. Defaults to True. + + Examples: + >>> converter = DataLoaderToNumpy( + ... feature_keys=["conditions", "procedures"], + ... label_key="los" + ... ) + """ + + self.feature_keys = feature_keys + self.label_key = label_key + + self.pad_batches = pad_batches + self._key_dims: Dict[str, int] = {} + self._fitted = False + + def transform(self, dataloader: DataLoader) -> tuple[np.ndarray, np.ndarray]: + """Converts a DataLoader to numpy arrays for sklearn models. + + Args: + dataloader: PyHealth DataLoader + + Returns: + X: numpy array of shape (n_samples, n_features). + y: numpy array of shape (n_samples,). 
+ + Examples: + >>> X, y = converter.transform(dataloader) + >>> X.shape + (num_samples, num_features) + >>> y.shape + (num_samples,) + """ + x_parts: List[np.ndarray] = [] + y_parts: List[np.ndarray] = [] + + for batch in dataloader: + x_parts.append(self._process_features(batch)) + y_parts.append(self._process_labels(batch)) + + self._fitted = True + + return np.vstack(x_parts), np.concatenate(y_parts) + + @staticmethod + def _to_numpy(value: torch.Tensor) -> np.ndarray: + """Converts a Tensor or list to a numpy array + + Args: + value: torch.Tensor, numpy array, or Python list from a DataLoader batch. + + Returns: + numpy array. + + Examples: + >>> import torch + >>> arr = DataLoaderToNumpy._to_numpy(torch.tensor([1, 2, 3])) + """ + if isinstance(value, torch.Tensor): + arr = value.detach().cpu().numpy() + elif isinstance(value, np.ndarray): + arr = value + else: + arr = np.array(value) + + return arr.astype(np.float32) + + def _flatten_feature(self, arr: np.ndarray, key: str) -> np.ndarray: + """Converts the feature to a two-dimensional array and pads if padding is + enabled and needed. + + Args: + arr: numpy array + key: feature key + + Returns: + numpy array of shape (batch_size, expected_width). + + Raises: + ValueError: if dimensions are not consistent and pad_batches is False. + + Examples: + >>> arr = np.array([[1, 2], [3, 4]]) + >>> converter._flatten_feature(arr, "conditions") + """ + if arr.ndim == 1: + arr = arr.reshape(-1, 1) + elif arr.ndim > 2: + arr = arr.reshape(arr.shape[0], -1) + + width = arr.shape[1] + + if not self._fitted: + self._key_dims[key] = width + elif width != self._key_dims[key]: + expected = self._key_dims[key] + + if not self.pad_batches: + raise ValueError( + f"Inconsistent batch sizes across features. Set pad_batches=True t" + f"o allow padding." 
+ ) + arr = ( + np.pad(arr, ((0, 0), (0, expected - width)), mode = "constant") + if width < expected + else arr[:, :expected] + ) + + return arr + + def _process_features(self, batch: dict) -> np.ndarray: + """Concatenate all features from one batch. + + Args: + batch: dictionary from PyHealth DataLoader containing feature keys + + Returns: + numpy array of shape (batch_size, total_feature_dim). + + Examples: + >>> X_batch = converter._process_features(batch) + """ + return np.concatenate( + [self._flatten_feature(self._to_numpy(batch[k]), k) for k in + self.feature_keys], + axis = 1, + ) + + def _process_labels(self, batch: dict) -> np.ndarray: + """Extract and flatten the label array from one batch. + + Args: + batch: dictionary from a PyHealth DataLoader containing label_key. + + Returns: + numpy array of shape (batch_size,). + + Examples: + >>> y_batch = converter._process_labels(batch) + >>> y_batch.shape + (batch_size,) + """ + return self._to_numpy(batch[self.label_key]).reshape(-1) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..0d5800771 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -28,6 +28,7 @@ LengthOfStayPredictionMIMIC3, LengthOfStayPredictionMIMIC4, LengthOfStayPredictionOMOP, + LengthOfStayThresholdPredictionMIMIC3, ) from .length_of_stay_stagenet_mimic4 import LengthOfStayStageNetMIMIC4 from .medical_coding import MIMIC3ICD9Coding diff --git a/pyhealth/tasks/length_of_stay_prediction.py b/pyhealth/tasks/length_of_stay_prediction.py index 25e0c3121..bed25ed3c 100644 --- a/pyhealth/tasks/length_of_stay_prediction.py +++ b/pyhealth/tasks/length_of_stay_prediction.py @@ -1,8 +1,7 @@ -from datetime import datetime, timedelta -from typing import Dict, List +from datetime import datetime +from typing import Any, Dict, List from pyhealth.data.data import Patient - from .base_task import BaseTask @@ -131,6 +130,175 @@ def __call__(self, patient: Patient) -> List[Dict]: return samples 
+class LengthOfStayThresholdPredictionMIMIC3(BaseTask):
+    """Task for predicting whether length of stay exceeded a certain number of days
+    using the MIMIC-III dataset.
+
+    Length of stay prediction aims at predicting the length of stay (in days) of the
+    current hospital visit based on the clinical information from the visit
+    (e.g., conditions and procedures).
+
+    Args:
+        days: Threshold in days; label is 1 when LOS > days
+
+    Raises:
+        TypeError: if days is not a number (int or float).
+        ValueError: if days is not positive.
+
+    Attributes:
+        task_name: The name of the task.
+        input_schema: The schema for input data, which includes:
+            - conditions: A list of condition codes.
+            - procedures: A list of procedure codes.
+            - drugs: A list of drug codes.
+        output_schema: The schema for output data, which includes:
+            - los: A binary class label for whether length of stay exceeded the given
+              number of days.
+
+    Examples:
+        >>> from pyhealth.datasets import MIMIC3Dataset
+        >>> from pyhealth.tasks import LengthOfStayThresholdPredictionMIMIC3
+        >>> dataset = MIMIC3Dataset(
+        ...     root="/srv/local/data/physionet.org/files/mimiciii/1.4",
+        ...     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
+        ...     code_mapping={"ICD9CM": "CCSCM"},
+        ... )
+        >>> task = LengthOfStayThresholdPredictionMIMIC3(3.0)
+        >>> mimic3_sample = dataset.set_task(task)
+    """
+    task_name: str = "LengthOfStayThresholdPredictionMIMIC3"
+
+    input_schema = {
+        "conditions": "sequence",
+        "procedures": "sequence",
+        "drugs": "sequence",
+    }
+
+    output_schema: Dict[str, str] = {"los": "binary"}
+
+    def __init__(self, days: float = 3, exclude_minors: bool = True):
+        """
+        Initializes the length-of-stay prediction task.
+
+        Args:
+            days: Threshold in days. LOS > days → label = 1.
+            exclude_minors: Whether to exclude minor patients whose age is less than
+                18. Defaults to True.
+ """ + if not isinstance(days, (int, float)): + raise TypeError("Days must be a number (int or float)") + if days <= 0: + raise ValueError("Days must be greater than 0") + + self.days = float(days) + self.exclude_minors = exclude_minors + + def __call__(self, patient: Any) -> List[Dict]: + """ + Generates binary length-of-stay (LOS) prediction samples for a single patient. + + Each admission is converted into one sample with a binary label indicating + whether the length of stay exceeds a specified threshold (``self.days``). + + Visits with no conditions OR no procedures OR no drugs are excluded from + the output. + + Args: + patient: A patient object (expected to implement get_events()) + + Returns: + List[Dict]: A list containing a dictionary for each valid admission with: + - 'visit_id': MIMIC3 hadm_id. + - 'patient_id': MIMIC3 subject_id. + - 'conditions': Diagnosis codes from the diagnoses_icd table. + - 'procedures': Procedure codes from the procedures_icd table. + - 'drugs': Drug codes from the prescriptions table. + - 'los': Binary label where 1 indicates LOS > ``self.days`` and 0 + otherwise. + + Raises: + ValueError: If date strings (e.g., date of birth or discharge time) + cannot be parsed into datetime objects. 
+ """ + samples = [] + + # Get all admissions + admissions = patient.get_events(event_type = "admissions") + if len(admissions) == 0: + return [] + + patients = patient.get_events(event_type = "patients") + assert len(patients) == 1 + + # check for minor (patients less than 18 years old) exclusion + if self.exclude_minors: + try: + dob = datetime.strptime(patients[0].dob, "%Y-%m-%d %H:%M:%S") + except ValueError: + dob = datetime.strptime(patients[0].dob, "%Y-%m-%d") + + # Process each admission + for admission in admissions: + if self.exclude_minors: + age = admission.timestamp.year - dob.year + if (admission.timestamp.month, admission.timestamp.day) < (dob.month, + dob.day): + # Patient's birthday has not yet occurred, adjust age + age -= 1 + if age < 18: + # Exclude minors + continue + + # Get diagnosis codes using hadm_id + diagnoses_events = patient.get_events( + event_type = "diagnoses_icd", + filters = [("hadm_id", "==", admission.hadm_id)], + ) + conditions = [event.icd9_code for event in diagnoses_events] + + # Get procedure codes using hadm_id + procedures_events = patient.get_events( + event_type = "procedures_icd", + filters = [("hadm_id", "==", admission.hadm_id)], + ) + procedures = [event.icd9_code for event in procedures_events] + + # Get prescriptions using hadm_id + prescriptions_events = patient.get_events( + event_type = "prescriptions", + filters = [("hadm_id", "==", admission.hadm_id)], + ) + drugs = [event.ndc for event in prescriptions_events] + + # Exclude visits without condition, procedure, or drug code + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + # Calculate length of stay + # admission.timestamp is the admit time (from the timestamp column) + # admission.dischtime is the discharge time (from attributes) + admit_time = admission.timestamp + discharge_time = datetime.strptime(admission.dischtime, + "%Y-%m-%d %H:%M:%S") + los_days = (discharge_time - admit_time).days + + # generate label + label = int(los_days > 
self.days) + + samples.append( + { + "visit_id": admission.hadm_id, + "patient_id": patient.patient_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "los": label, + } + ) + # no cohort selection + return samples + + class LengthOfStayPredictionMIMIC4(BaseTask): """Task for predicting length of stay using MIMIC-IV dataset. @@ -475,6 +643,11 @@ def __call__(self, patient: Patient) -> List[Dict]: sample_dataset.stats() print(sample_dataset.samples[0] if sample_dataset.samples else "No samples") + task = LengthOfStayThresholdPredictionMIMIC3(3) + sample_dataset = base_dataset.set_task(task) + sample_dataset.stats() + print(sample_dataset.samples[0] if sample_dataset.samples else "No samples") + from pyhealth.datasets import MIMIC4Dataset base_dataset = MIMIC4Dataset( diff --git a/tests/core/test_data_loader_to_numpy_util.py b/tests/core/test_data_loader_to_numpy_util.py new file mode 100644 index 000000000..578b60c97 --- /dev/null +++ b/tests/core/test_data_loader_to_numpy_util.py @@ -0,0 +1,217 @@ +""" +Unit tests for DataLoaderToNumpy. + +Tests cover: +- Tensor/list/ndarray conversion to numpy +- Feature flattening behavior (1D, 2D, >2D) +- Batch padding and truncation +- Label processing +- Transform integration across batches +- Edge cases (inconsistent shapes, no padding, etc.) 
+""" + +import unittest + +import numpy as np +import torch + +from pyhealth.models.utils import DataLoaderToNumpy + + +class MockDataLoader: + """Synthetic PyHealth DataLoader for testing.""" + + def __init__(self, batches): + """Initialize with batches.""" + self.batches = batches + + def __iter__(self): + """Return iterator over batches.""" + return iter(self.batches) + + +class TestDataLoaderToNumpyToNumpy(unittest.TestCase): + """Tests _to_numpy static method.""" + + def test_tensor_conversion(self): + """Test conversion from torch.Tensor to numpy.""" + arr = DataLoaderToNumpy._to_numpy(torch.tensor([1, 2, 3])) + + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.dtype, np.float32) + + def test_list_conversion(self): + """Test conversion from Python list to numpy.""" + arr = DataLoaderToNumpy._to_numpy([1, 2, 3]) + + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.dtype, np.float32) + + def test_numpy_passthrough(self): + """Test numpy input is preserved with correct dtype.""" + arr = DataLoaderToNumpy._to_numpy(np.array([1, 2, 3])) + + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.dtype, np.float32) + + +class TestDataLoaderToNumpyFlatten(unittest.TestCase): + """Tests for feature flattening behavior.""" + + def setUp(self): + """Initialize converter for flattening tests.""" + self.converter = DataLoaderToNumpy(["a"], "y") + + def test_1d_to_2d(self): + """Test that 1D arrays are reshaped to 2D.""" + arr = np.array([1, 2, 3]) + + result = self.converter._flatten_feature(arr, "a") + self.assertEqual(result.shape, (3, 1)) + + def test_2d_unchanged(self): + """Test that 2D arrays persist.""" + arr = np.array([[1, 2], [3, 4]]) + + result = self.converter._flatten_feature(arr, "a") + self.assertEqual(result.shape, (2, 2)) + + def test_3d_flatten(self): + """Test that > 2D arrays are flattened correctly.""" + arr = np.ones((2, 3, 4)) + + result = self.converter._flatten_feature(arr, "a") + self.assertEqual(result.shape, 
(2, 12)) + + +class TestDataLoaderToNumpyPadding(unittest.TestCase): + """Tests for padding and truncation across batches.""" + + def test_padding_enabled(self): + """Test padding when batch width is smaller than expected.""" + converter = DataLoaderToNumpy(["a"], "y", pad_batches=True) + + batch1 = {"a": np.array([[1, 2, 3]]), "y": [0]} + batch2 = {"a": np.array([[4, 5]]), "y": [1]} + + converter._process_features(batch1) + converter._fitted = True + + result = converter._process_features(batch2) + + self.assertEqual(result.shape[1], 3) + self.assertEqual(result[0, 2], 0.0) + + def test_truncation_enabled(self): + """Test truncation when batch width exceeds expected.""" + converter = DataLoaderToNumpy(["a"], "y", pad_batches=True) + + batch1 = {"a": np.array([[1, 2]]), "y": [0]} + batch2 = {"a": np.array([[3, 4, 5]]), "y": [1]} + + converter._process_features(batch1) + converter._fitted = True + + result = converter._process_features(batch2) + + self.assertEqual(result.shape[1], 2) + self.assertTrue(np.array_equal(result[0], [3, 4])) + + def test_padding_disabled_error(self): + """Test error when inconsistent widths and padding disabled.""" + converter = DataLoaderToNumpy(["a"], "y", pad_batches=False) + + batch1 = {"a": np.array([[1, 2]]), "y": [0]} + batch2 = {"a": np.array([[3, 4, 5]]), "y": [1]} + + converter._process_features(batch1) + converter._fitted = True + + with self.assertRaises(ValueError): + converter._process_features(batch2) + + +class TestDataLoaderToNumpyFeatures(unittest.TestCase): + """Tests feature concatenation.""" + + def test_multiple_features_concat(self): + """Test concatenation of multiple feature arrays.""" + converter = DataLoaderToNumpy(["a", "b"], "y") + + batch = { + "a": np.array([[1, 2]]), + "b": np.array([[3, 4, 5]]), + "y": [0], + } + + result = converter._process_features(batch) + + self.assertEqual(result.shape, (1, 5)) + self.assertTrue(np.array_equal(result[0], [1, 2, 3, 4, 5])) + + +class 
TestDataLoaderToNumpyLabels(unittest.TestCase): + """Tests label processing.""" + + def test_labels_flatten(self): + """Test that labels are flattened to 1D.""" + converter = DataLoaderToNumpy(["a"], "y") + + batch = {"a": [[1]], "y": [[1], [2], [3]]} + + result = converter._process_labels(batch) + + self.assertEqual(result.shape, (3,)) + self.assertTrue(np.array_equal(result, [1, 2, 3])) + + +class TestDataLoaderToNumpyTransform(unittest.TestCase): + """Tests for transform method.""" + + def test_basic_transform(self): + """Test end-to-end transform with consistent batches.""" + converter = DataLoaderToNumpy(["a", "b"], "y") + + dataloader = MockDataLoader([ + {"a": [[1, 2]], "b": [[3, 4]], "y": [0]}, + {"a": [[5, 6]], "b": [[7, 8]], "y": [1]}, + ]) + + x, y = converter.transform(dataloader) + + self.assertEqual(x.shape, (2, 4)) + self.assertEqual(y.shape, (2,)) + self.assertTrue(np.array_equal(y, [0, 1])) + + def test_padding_across_batches(self): + """Test padding across multiple batches.""" + converter = DataLoaderToNumpy(["a"], "y", pad_batches=True) + + dataloader = MockDataLoader([ + {"a": [[1, 2, 3]], "y": [0]}, + {"a": [[4, 5]], "y": [1]}, + ]) + + x, y = converter.transform(dataloader) + + self.assertEqual(x.shape, (2, 3)) + self.assertEqual(x[1, 2], 0.0) + + def test_multiple_batches_concat(self): + """Test concatenation of batches.""" + converter = DataLoaderToNumpy(["a"], "y") + + dataloader = MockDataLoader([ + {"a": [[1], [2]], "y": [0, 1]}, + {"a": [[3], [4]], "y": [2, 3]}, + ]) + + x, y = converter.transform(dataloader) + + self.assertEqual(x.shape, (4, 1)) + self.assertEqual(y.shape, (4,)) + self.assertTrue(np.array_equal(y, [0, 1, 2, 3])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_mimic3_threshold_los.py b/tests/core/test_mimic3_threshold_los.py new file mode 100644 index 000000000..593419143 --- /dev/null +++ b/tests/core/test_mimic3_threshold_los.py @@ -0,0 +1,208 @@ +""" +Unit tests for 
LengthOfStayThresholdPredictionMIMIC3 task. + +Tests: +- Sample generation +- Label generation +- Feature extraction (conditions, procedures, drugs) +- Edge cases (empty visits, minimal samples, minor patients) +""" + +import unittest +from datetime import datetime, timedelta +from typing import Any, List, Optional, Tuple + +from pyhealth.tasks import LengthOfStayThresholdPredictionMIMIC3 + + +class MockEvent: + """A simple container for event attributes for testing purposes. This is + useful to avoid using real MIMIC 3 data for testing. + + Example: + >>> e = MockEvent(hadm_id="v1", icd9_code="c1") + >>> e.hadm_id + 'v1' + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize a mock event. + + Args: + **kwargs: event attributes (hadm_id, icd9_code, ...). + + Example: + >>> e = MockEvent(hadm_id="v1") + >>> e.hadm_id + 'v1' + """ + self.__dict__.update(kwargs) + + +class MockPatient: + """A mock patient object for testing purposes. + + Example: + >>> patient = MockPatient() + >>> admissions = patient.get_events("admissions") + >>> len(admissions) > 0 + True + """ + + def __init__( + self, + los_days: float = 4.0, + include_features: bool = True, + minor: bool = False) -> None: + """Initializes a patient for testing purposes. + + Args: + los_days (float): Length of stay in days. Defaults to 4.0 days. + include_features (bool): Whether to include conditions/procedures/drugs. + Defaults to True. + minor (bool): True if the patient's age is below 18, + False otherwise. Defaults to False. 
+ """ + self.patient_id = "p0" + + admission_time = datetime(2020, 1, 1) + discharge_time = datetime(2020, 1, 1) + timedelta(days = los_days) + + date_of_birth_year = 2009 if minor else 2000 + + self.events = { + "admissions": [ + MockEvent( + hadm_id = "v0", + timestamp = admission_time, + dischtime = discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + ) + ], + "patients": [ + MockEvent(dob = f"{date_of_birth_year}-01-01 00:00:00") + ], + "diagnoses_icd": ( + [MockEvent(hadm_id = "v0", + icd9_code = "c1")] if include_features else [] + ), + "procedures_icd": ( + [MockEvent(hadm_id = "v0", + icd9_code = "p1")] if include_features else [] + ), + "prescriptions": ( + [MockEvent(hadm_id = "v0", ndc = "d1")] if include_features else [] + ), + } + + def get_events(self, + event_type: str, + filters: Optional[List[Tuple[str, str, Any]]] = None + ) -> List[MockEvent]: + """Return events by type with filters if defined. + + Args: + event_type (str): Type of event. + filters (list, optional): Filtering rules (will only return events with + matching these attributes). Defaults to None. + + Returns: + list: events + + Example: + >>> patient = MockPatient() + >>> patient.get_events("diagnoses_icd") + [MockEvent] + """ + events = self.events.get(event_type, []) + + # Filter, given filters have been provided (defaults to None) + if filters: + key, _, value = filters[0] + return [e for e in events if getattr(e, key) == value] + + return events + + +class TestLengthOfStayThresholdPrediction(unittest.TestCase): + """Unit tests for LengthOfStayThresholdPredictionMIMIC3 task. + + Tests: + - Sample generation + - Label generation + - Feature extraction (conditions, procedures, drugs) + - Edge cases (empty visits, minimal samples, minor patients) + """ + + def setUp(self): + """Sets up task for testing.""" + self.task = LengthOfStayThresholdPredictionMIMIC3(days=3) + + def test_generates_samples(self): + """Tests sample generation. 
+ + Example: + >>> samples = task(patient) + >>> len(samples) > 0 + True + """ + patient = MockPatient(los_days=5) + samples = self.task(patient) + + self.assertGreater(len(samples), 0) + + def test_label_thresholding(self): + """Test binary LOS label generation.""" + patient_below_threshold = MockPatient(los_days=2) + patient_beyond_threshold = MockPatient(los_days=5) + + below_threshold_label = self.task(patient_below_threshold)[0]["los"] + beyond_threshold_label = self.task(patient_beyond_threshold)[0]["los"] + + # Verify labels reflect whether the patients stay was beyond three days + self.assertEqual(below_threshold_label, 0) + self.assertEqual(beyond_threshold_label, 1) + + def test_feature_integrity(self): + """Tests feature extraction.""" + patient = MockPatient(los_days=4) + sample = self.task(patient)[0] + + self.assertIn("conditions", sample) + self.assertIn("procedures", sample) + self.assertIn("drugs", sample) + + self.assertGreater(len(sample["conditions"]), 0) + self.assertGreater(len(sample["procedures"]), 0) + self.assertGreater(len(sample["drugs"]), 0) + + self.assertEqual(sample["los"], 1) + + def test_empty_features(self): + """Test samples with missing features aren't excluded.""" + patient = MockPatient(include_features=False) + samples = self.task(patient) + + self.assertEqual(len(samples), 0) + + def test_exclude_minors(self): + """Test that minor patients are excluded if exclude minors flag is enabled.""" + task = LengthOfStayThresholdPredictionMIMIC3(days=3, exclude_minors=True) + + patient = MockPatient(minor=True) + samples = task(patient) + + # Verify there are no samples since the only sample was from a minor patient + self.assertEqual(len(samples), 0) + + def test_single_sample(self): + """Test task behavior with a single valid admission.""" + patient = MockPatient(los_days=10) + samples = self.task(patient) + + self.assertEqual(len(samples), 1) + self.assertIn("los", samples[0]) + self.assertEqual(samples[0]["los"], 1) + + +if 
__name__ == "__main__": + unittest.main() diff --git a/tests/core/test_random_forest.py b/tests/core/test_random_forest.py new file mode 100644 index 000000000..d5d7f9047 --- /dev/null +++ b/tests/core/test_random_forest.py @@ -0,0 +1,360 @@ +""" +Unit tests for the RandomForest model. + +Tests cover: +- Model initialization +- Classification and regression behavior +- Forward pass (fitted vs. unfitted) +- Model parameters +- Expected no backward/gradients +""" +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import RandomForest + + +class TestRandomForest(unittest.TestCase): + """Test cases for the Random Forest model.""" + + def setUp(self): + """Set up minimal synthetic test data.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"], + "procedures": [1.0, 2.0, 3.5, 4], + "label": 0, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-33", "cond-86", "cond-80"], + "procedures": [5.0, 2.0, 3.5, 4], + "label": 1, + }, + ] + + self.batch_size = len(self.samples) + self.num_samples = len(self.samples) + self.num_classes = len(set(sample["label"] for sample in self.samples)) + + self.input_schema = { + "conditions": "sequence", + "procedures": "tensor", + } + self.output_schema = {"label": "binary"} + self.regression_output_schema = {"label": "regression"} + + self.dataset = create_sample_dataset( + samples = self.samples, + input_schema = self.input_schema, + output_schema = self.output_schema, + dataset_name = "test", + ) + self.dataset_regression = create_sample_dataset( + samples = self.samples, + input_schema = self.input_schema, + output_schema = self.regression_output_schema, + dataset_name = "test", + ) + + # models are instantiated within individual test cases + + def test_model_initialization_classification(self): + """Test model initialization using default 
parameters""" + + model = RandomForest( + dataset = self.dataset + ) + + expected = { + "pad_batches": True, + "f1_average": "macro", + "n_estimators": 100, + "criterion": "gini", + "max_depth": None, + "min_samples_split": 2, + "min_samples_leaf": 1, + "min_weight_fraction_leaf": 0.0, + "max_features": "sqrt", + "max_leaf_nodes": None, + "min_impurity_decrease": 0.0, + "bootstrap": True, + "oob_score": False, + "n_jobs": None, + "random_state": 42, + "verbose": 0, + "warm_start": False, + "class_weight": None, + "ccp_alpha": 0.0, + "max_samples": None, + "monotonic_cst": None + } + + self.assertIsInstance(model, RandomForest) + with self.assertRaises(AttributeError): + # Random forests do not have embeddings + self.assertEqual(model.embedding_dim, None) + self.assertEqual(len(model.feature_keys), 2) + self.assertIn("conditions", model.feature_keys) + self.assertIn("procedures", model.feature_keys) + self.assertEqual(model.label_key, "label") + self.assertEqual(model.get_params(), expected) + + def test_model_initialization_regression(self): + """Test regression model initialization with default params""" + + model = RandomForest( + dataset = self.dataset_regression + ) + + expected = { + "pad_batches": True, + "f1_average": "macro", + "n_estimators": 100, + "criterion": "friedman_mse", + "max_depth": None, + "min_samples_split": 2, + "min_samples_leaf": 1, + "min_weight_fraction_leaf": 0.0, + "max_features": "sqrt", + "max_leaf_nodes": None, + "min_impurity_decrease": 0.0, + "bootstrap": True, + "oob_score": False, + "n_jobs": None, + "random_state": 42, + "verbose": 0, + "warm_start": False, + "ccp_alpha": 0.0, + "max_samples": None, + "monotonic_cst": None + } + + self.assertIsInstance(model, RandomForest) + with self.assertRaises(AttributeError): + # Random forests do not have embeddings + self.assertEqual(model.embedding_dim, None) + self.assertEqual(len(model.feature_keys), 2) + self.assertIn("conditions", model.feature_keys) + self.assertIn("procedures", 
model.feature_keys) + self.assertEqual(model.label_key, "label") + self.assertEqual(model.get_params(), expected) + + def test_model_initialization_non_defaults(self): + """Test regression model initialization with custom parameters""" + + model = RandomForest( + dataset = self.dataset_regression, + pad_batches = False, + f1_average = "weighted", + n_estimators = 50, + criterion = "poisson", + max_depth = 10, + min_samples_split = 5, + min_samples_leaf = 2, + min_weight_fraction_leaf = 0.1, + max_features = "log2", + max_leaf_nodes = 20, + min_impurity_decrease = 0.01, + bootstrap = False, + oob_score = True, + n_jobs = 1, + random_state = 123, + verbose = 1, + warm_start = True, + class_weight = "balanced", + ccp_alpha = 0.05, + max_samples = 2 + ) + + expected = { + "pad_batches": False, + "f1_average": "weighted", + "n_estimators": 50, + "criterion": "poisson", + "max_depth": 10, + "min_samples_split": 5, + "min_samples_leaf": 2, + "min_weight_fraction_leaf": 0.1, + "max_features": "log2", + "max_leaf_nodes": 20, + "min_impurity_decrease": 0.01, + "bootstrap": False, + "oob_score": True, + "n_jobs": 1, + "random_state": 123, + "verbose": 1, + "warm_start": True, + "ccp_alpha": 0.05, + "max_samples": 2, + "monotonic_cst": None + } + + self.assertIsInstance(model, RandomForest) + with self.assertRaises(AttributeError): + # Random forests do not have embeddings + self.assertEqual(model.embedding_dim, None) + self.assertEqual(len(model.feature_keys), 2) + self.assertIn("conditions", model.feature_keys) + self.assertIn("procedures", model.feature_keys) + self.assertEqual(model.label_key, "label") + self.assertEqual(model.get_params(), expected) + + def test_model_initialization_invalid_f1_average(self): + with self.assertRaises(ValueError): + RandomForest(dataset = self.dataset, f1_average = "any") + + def test_model_initialization_invalid_criterion(self): + with self.assertRaises(ValueError): + RandomForest(dataset = self.dataset, criterion = "invalid") + + def 
test_forward_unfitted_classification(self): + """ + Tests forward call on the Random Forest model tasked for classification without + having called fit, which is expected to elicit a runtime error + """ + model = RandomForest( + dataset = self.dataset) + + loader = get_dataloader(self.dataset, batch_size = self.batch_size, + shuffle = False) + batch = next(iter(loader)) + + # Calling forward without fitting should cause a RuntimeError to be raised + with self.assertRaises(RuntimeError): + model(**batch) + + def test_forward_unfitted_regression(self): + """ + Tests forward call on the Random Forest model tasked for regression without + having called fit, which is expected to elicit a runtime error + """ + model = RandomForest( + dataset = self.dataset_regression + ) + + loader = get_dataloader(self.dataset_regression, batch_size = self.batch_size, + shuffle = False) + batch = next(iter(loader)) + + # Calling forward without fitting should cause a RuntimeError to be raised + with self.assertRaises(RuntimeError): + model(**batch) + + def test_forward_fitted_classification(self): + """ + Tests forward call on the Random Forest model tasked for classification with + fit called + """ + model = RandomForest( + dataset = self.dataset) + + loader = get_dataloader( + self.dataset, batch_size = self.batch_size, shuffle = False) + + model.fit(loader) + + batch = next(iter(loader)) + out = model(**batch) + + # Verify all expected keys exist in forward output: loss, y_prob, y_true, logit + self.assertIn("loss", out) + self.assertIn("y_prob", out) + self.assertIn("y_true", out) + self.assertIn("logit", out) + + # Verify output shapes + self.assertEqual(out["logit"].shape, (self.batch_size, self.num_classes)) + self.assertEqual(out["y_prob"].shape, (self.batch_size, self.num_classes)) + self.assertEqual(out["y_true"].shape[0], self.batch_size) + + def test_forward_fitted_regression(self): + """ + Tests forward call on the Random Forest model tasked for regression with + fit called 
+ """ + model = RandomForest( + dataset = self.dataset_regression + ) + + loader = get_dataloader( + self.dataset_regression, batch_size = self.batch_size, shuffle = False) + + model.fit(loader) + + batch = next(iter(loader)) + out = model(**batch) + + # Verify all expected keys exist in forward output: loss, y_prob, y_true, logit + self.assertIn("loss", out) + self.assertIn("y_prob", out) + self.assertIn("y_true", out) + self.assertIn("logit", out) + + # Verify output shapes + self.assertEqual(out["logit"].shape, (self.batch_size, 1)) + self.assertEqual(out["y_prob"].shape, (self.batch_size, 1)) + self.assertEqual(out["y_true"].shape[0], self.batch_size) + + def test_backward(self): + """Test that the Random Forest model does not support a backward pass as + expected.""" + + model = RandomForest(dataset = self.dataset) + + # Create data loader + train_loader = get_dataloader( + self.dataset, batch_size = self.batch_size, shuffle = True) + data_batch = next(iter(train_loader)) + + # Fit and forward pass + model.fit(train_loader) + out = model(**data_batch) + + # Backward is not applicable for random forest models. 
Verify a runtime error + # is raised when a backward pass is attempted + with self.assertRaises(RuntimeError): + out["loss"].backward() + + # Check that no gradients are required + has_gradient = False + for param in model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertFalse(has_gradient) + + def test_fit_and_predict_pipeline(self): + """Test fitting and inference produce probability outputs.""" + model = RandomForest(dataset=self.dataset) + loader = get_dataloader(self.dataset, batch_size = self.batch_size, shuffle=False) + + model.fit(loader) + batch = next(iter(loader)) + + with torch.no_grad(): + out = model(**batch) + + self.assertIn("y_prob", out) + + def test_multiple_batches(self): + """Test model processes all batches and returns outputs per sample.""" + model = RandomForest(dataset=self.dataset) + loader = get_dataloader(self.dataset, batch_size = 1, shuffle=False) + + model.fit(loader) + + outputs = [] + for batch in loader: + outputs.append(model(**batch)["y_prob"]) + + self.assertEqual(len(outputs), self.num_samples) + + +if __name__ == "__main__": + unittest.main()