diff --git a/config_story.yaml b/config_story.yaml index 6693543..912bc65 100644 --- a/config_story.yaml +++ b/config_story.yaml @@ -103,6 +103,68 @@ src-stats: m.visit_occurrence_id, v.visit_start_datetime, v.visit_end_datetime) subquery; + - name: bp_sys_relative_change_stats + #For each patient + visit + BP type, sort measurements in time, + #compute |xₜ − xₜ₋₁| / xₜ₋₁ for each step, + #then average those within that visit. + #Finally, average that visit-level metric across the population. + query: > + WITH bp AS ( + SELECT + p.person_id, + p.gender_concept_id, + v.visit_occurrence_id, + EXTRACT(YEAR FROM v.visit_start_datetime) - p.year_of_birth AS age, + m.measurement_concept_id, + m.value_as_number, + m.measurement_datetime + FROM mimic.person p + JOIN mimic.visit_occurrence v + ON p.person_id = v.person_id + JOIN mimic.measurement m + ON m.visit_occurrence_id = v.visit_occurrence_id + WHERE m.measurement_concept_id IN (21492239) -- systolic + ), + bp_with_lag AS ( + SELECT + bp.*, + LAG(value_as_number) OVER ( + PARTITION BY person_id, visit_occurrence_id, measurement_concept_id + ORDER BY measurement_datetime + ) AS prev_value + FROM bp + ), + visit_level_rel_var AS ( + SELECT + person_id, + gender_concept_id, + age, + visit_occurrence_id, + measurement_concept_id, + -- average relative change within this visit + AVG( ABS(value_as_number - prev_value) + / NULLIF(prev_value, 0) ) AS avg_rel_var + FROM bp_with_lag + WHERE prev_value IS NOT NULL -- skip first measurement in each visit + GROUP BY + person_id, + gender_concept_id, + age, + visit_occurrence_id, + measurement_concept_id + ) + SELECT + gender_concept_id, + AVG(CASE WHEN measurement_concept_id = 21492239 AND age > 60 + THEN avg_rel_var END)::float4 AS avg_over_60_systolic_rel_var, + AVG(CASE WHEN measurement_concept_id = 21492239 AND age <= 60 + THEN avg_rel_var END)::float4 AS avg_under_60_systolic_rel_var, + STDDEV(CASE WHEN measurement_concept_id = 21492239 AND age > 60 + THEN avg_rel_var 
END)::float4 AS stddev_over_60_systolic_rel_var, + STDDEV(CASE WHEN measurement_concept_id = 21492239 AND age <= 60 + THEN avg_rel_var END)::float4 AS stddev_under_60_systolic_rel_var + FROM visit_level_rel_var + GROUP BY gender_concept_id; tables: person: num_rows_per_pass: 0 diff --git a/person_story.py b/person_story.py index e5afd6f..176f04d 100644 --- a/person_story.py +++ b/person_story.py @@ -1,7 +1,7 @@ """Story generators for the CC HIC OMOP schema.""" import datetime as dt from typing import Callable, Generator, Optional, Union, cast - +from sqlsynthgen.utils import generate_time_series import numpy as np from mimesis import Generic import random @@ -24,9 +24,10 @@ def random_normal(mean: float, std_dev: Optional[float] = None) -> float: def gen_death( - generic: Generic, person: SqlRow, src_stats: SrcStats + generic: Generic, person: SqlRow, src_stats: SrcStats ) -> Optional[tuple[str, SqlRow]]: """Generate a row for the death table.""" + def with_probability(p: float) -> bool: """Return True with probability p (0 ≤ p ≤ 1).""" return random.random() < p @@ -36,7 +37,8 @@ def with_probability(p: float) -> bool: else: avg_age_at_death_days = src_stats["age_at_death"][0]["average_age_years"] * 365 std_dev_age_at_death_days = src_stats["age_at_death"][0]["stddev_age_years"] * 365 - age_at_death_days = abs(random_normal(cast(float,avg_age_at_death_days), cast(float,std_dev_age_at_death_days))) + age_at_death_days = abs( + random_normal(cast(float, avg_age_at_death_days), cast(float, std_dev_age_at_death_days))) death_datetime = cast(dt.datetime, person["birth_datetime"]) + dt.timedelta( days=age_at_death_days) return "death", { @@ -45,27 +47,28 @@ def with_probability(p: float) -> bool: "death_date": death_datetime.date(), } + def gen_visit_occurrence( - person: SqlRow, death: Optional[SqlRow], src_stats: SrcStats + person: SqlRow, death: Optional[SqlRow], src_stats: SrcStats ) -> tuple[str, SqlRow]: """Generate a row for the visit_occurrence table.""" 
age_days_at_visit_start = abs( random_normal( - cast(float, 63*365), cast(float, 13*365) + cast(float, 63 * 365), cast(float, 13 * 365) ) ) if person["gender_concept_id"] == 8532: age_days_at_visit_start = abs( random_normal( - cast(float, src_stats["age_first_admission"][0]["average_age_years"]*365), - cast(float, src_stats["age_first_admission"][0]["stddev_age_years"]*365) + cast(float, src_stats["age_first_admission"][0]["average_age_years"] * 365), + cast(float, src_stats["age_first_admission"][0]["stddev_age_years"] * 365) ) ) if person["gender_concept_id"] == 8507: age_days_at_visit_start = abs( random_normal( - cast(float, src_stats["age_first_admission"][1]["average_age_years"]*365), - cast(float, src_stats["age_first_admission"][1]["stddev_age_years"]*365) + cast(float, src_stats["age_first_admission"][1]["average_age_years"] * 365), + cast(float, src_stats["age_first_admission"][1]["stddev_age_years"] * 365) ) ) visit_start_datetime = cast(dt.datetime, person["birth_datetime"]) + dt.timedelta( @@ -73,7 +76,7 @@ def gen_visit_occurrence( ) visit_length_hours = abs( random_normal( - cast(float, src_stats["visit_duration"][0]["average_hours"]), + cast(float, src_stats["visit_duration"][0]["average_hours"]), cast(float, src_stats["visit_duration"][0]["stddev_hours"]) # cast(float, 6), cast(float, 29*24) ) @@ -114,15 +117,15 @@ def random_event_times(avg_rate: float, visit_occurrence: SqlRow) -> list[dt.dat def gen_events( # pylint: disable=too-many-arguments - generic: Generic, - avg_rate: float, - visit_occurrence: SqlRow, - person: SqlRow, - generator_function: Callable[ - [Generic, int, int, dt.datetime, SrcStats], Optional[SqlRow] - ], - table_name: str, - src_stats: SrcStats, + generic: Generic, + avg_rate: float, + visit_occurrence: SqlRow, + person: SqlRow, + generator_function: Callable[ + [Generic, int, int, dt.datetime, SrcStats], Optional[SqlRow] + ], + table_name: str, + src_stats: SrcStats, ) -> list[tuple[str, SqlRow]]: """Generate events for a 
def gen_blood_pressure_events(  # pylint: disable=too-many-arguments
    avg_rate: float,
    visit_occurrence: SqlRow,
    person: SqlRow,
    src_stats: SrcStats,
) -> list[tuple[str, SqlRow]]:
    """Generate paired systolic/diastolic measurement rows for one visit.

    Event times are obtained from ``random_event_times(avg_rate,
    visit_occurrence)``.  One systolic series of that length is generated as a
    driftless random walk (``generate_time_series(..., "random_walk", ...)``)
    and rounded to whole numbers.  The diastolic series is the systolic series
    minus a single systolic-diastolic gap drawn once per call, so the gap is
    constant within the visit.

    Parameters
    ----------
    avg_rate : float
        Passed through to ``random_event_times``.
    visit_occurrence : SqlRow
        Must provide ``visit_occurrence_id`` and ``visit_start_datetime``.
    person : SqlRow
        Must provide ``person_id``, ``gender_concept_id`` and
        ``birth_datetime``.
    src_stats : SrcStats
        Must provide the ``bp_profile`` and ``bp_sys_relative_change_stats``
        result sets, each indexable by gender row and then by column name.

    Returns
    -------
    list[tuple[str, SqlRow]]
        ``("measurement", row)`` pairs in chronological order: for every event
        time the systolic row first, then the diastolic row.  Empty when no
        event times were produced.
    """
    event_datetimes = sorted(random_event_times(avg_rate, visit_occurrence))
    if not event_datetimes:
        return []

    # Concept identifiers are hard-coded; they match the id filtered on by the
    # bp_sys_relative_change_stats query.
    systolic_concept_id = 21492239
    diastolic_concept_id = 21492240
    measurement_type_concept_id = 32817  # EHR measurement
    unit_concept_id = 8876  # mmHg
    unit_source_value = "mmHg"

    person_id = cast(int, person["person_id"])
    visit_occurrence_id = cast(int, visit_occurrence["visit_occurrence_id"])

    visit_start = cast(dt.datetime, visit_occurrence["visit_start_datetime"])
    birth = cast(dt.datetime, person["birth_datetime"])
    age_years = (visit_start - birth).days / 365.25
    # NOTE(review): the bp_sys_relative_change_stats query puts age == 60 in
    # its "under_60" columns (age <= 60) and computes age from calendar years;
    # here the split is a strict "< 60" on fractional years.  Confirm which
    # boundary is intended.
    age_band = "under_60" if age_years < 60 else "over_60"

    # NOTE(review): row 0 is used for gender 8507 and row 1 for every other
    # value; this assumes the result sets are ordered with 8507 first.
    gender_index = 0 if cast(int, person["gender_concept_id"]) == 8507 else 1
    profile = src_stats["bp_profile"][gender_index]
    relative_change = src_stats["bp_sys_relative_change_stats"][gender_index]

    # The step scale is itself drawn from a normal distribution and can come
    # out negative.  It is used as a standard deviation downstream, where a
    # negative value makes NumPy raise ValueError, so take its magnitude.
    step_scale = abs(
        float(
            np.random.normal(
                relative_change[f"avg_{age_band}_systolic_rel_var"],
                relative_change[f"stddev_{age_band}_systolic_rel_var"],
            )
        )
    )

    systolic_values = np.round(
        generate_time_series(
            len(event_datetimes),
            "random_walk",
            {
                "mean": profile[f"average_{age_band}_systolic"],
                "std": profile[f"stddev_{age_band}_systolic"],
                "epsilon_std": step_scale,
                "drift": 0,
            },
        )
    )

    # One gap is drawn per call (spread: 10% of the average gap) and applied
    # to the whole series.
    average_gap = profile["average_systolic_diastolic_difference"]
    gap = random_normal(average_gap, average_gap * 0.1)
    diastolic_values = np.round(systolic_values - gap)

    def measurement_row(
        concept_id: int, event_datetime: dt.datetime, value: float
    ) -> SqlRow:
        """Build one measurement-table row for this person and visit."""
        return {
            "measurement_concept_id": concept_id,
            "person_id": person_id,
            "visit_occurrence_id": visit_occurrence_id,
            "measurement_datetime": event_datetime,
            "measurement_date": event_datetime.date(),
            "measurement_type_concept_id": measurement_type_concept_id,
            "unit_concept_id": unit_concept_id,
            "unit_source_value": unit_source_value,
            "value_as_number": value,
        }

    events: list[tuple[str, SqlRow]] = []
    for event_datetime, systolic, diastolic in zip(
        event_datetimes, systolic_values, diastolic_values
    ):
        # Store plain Python floats rather than NumPy scalars in the rows.
        events.append(
            (
                "measurement",
                measurement_row(systolic_concept_id, event_datetime, float(systolic)),
            )
        )
        events.append(
            (
                "measurement",
                measurement_row(diastolic_concept_id, event_datetime, float(diastolic)),
            )
        )
    return events
# Names required by the time-series helpers below.
from typing import Any, Dict, Literal, Optional, Union

import numpy as np


def generate_time_series(
    N: int,
    model_option: Literal["iid", "random_walk", "ar1"],
    model_params: Dict[str, Any],
    random_state: Optional[Union[int, np.random.Generator]] = None,
) -> np.ndarray:
    """
    Generate a synthetic time series using one of three simple models.

    Parameters
    ----------
    N : int
        Number of time steps.  ``0`` yields an empty array.
    model_option : {"iid", "random_walk", "ar1"}
        Which model to use.
    model_params : dict
        Dictionary of parameters. Expected keys:

        For all models:
            - "mean": float
            - "std": float
          ("iid" samples every point from N(mean, std^2); the other models
          sample only the initial value from it.)

        For random_walk:
            - "drift": float
            - "epsilon_std": float

        For ar1:
            - "mu": float
            - "phi": float
            - "epsilon_std": float

    random_state : int, numpy.random.Generator or None
        Optional seed (or ready-made generator) forwarded to
        ``numpy.random.default_rng``.  ``None`` gives a freshly seeded
        generator, which is the behaviour callers get when they omit it.

    Returns
    -------
    np.ndarray
        Synthetic time series of length N.

    Raises
    ------
    ValueError
        If ``model_option`` is not one of the supported models.
    KeyError
        If ``model_params`` lacks a key required by the chosen model.
    """
    required_by_model = {
        "iid": ("mean", "std"),
        "random_walk": ("mean", "std", "drift", "epsilon_std"),
        "ar1": ("mean", "std", "mu", "phi", "epsilon_std"),
    }

    # Validate before touching the generator so a typo in the model name is
    # reported as such rather than as a missing-parameter error.
    if model_option not in required_by_model:
        raise ValueError(f"Unknown model_option: {model_option!r}")
    for key in required_by_model[model_option]:
        if key not in model_params:
            raise KeyError(
                f"model_params must contain '{key}' for {model_option}"
            )

    rng = np.random.default_rng(random_state)

    if model_option == "iid":
        return sample_iid_gaussian(
            N=N,
            mu=model_params["mean"],
            sigma=model_params["std"],
            rng=rng,
        )

    # Both remaining models start from a value drawn from N(mean, std^2).
    x0: float = rng.normal(loc=model_params["mean"], scale=model_params["std"])

    if model_option == "random_walk":
        return random_walk_with_drift(
            N=N,
            x0=x0,
            drift=model_params["drift"],
            sigma_eps=model_params["epsilon_std"],
            rng=rng,
        )

    return ar1_process(
        N=N,
        x0=x0,
        mu=model_params["mu"],
        phi=model_params["phi"],
        sigma_eps=model_params["epsilon_std"],
        rng=rng,
    )


def sample_iid_gaussian(
    N: int,
    mu: float,
    sigma: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Generate an IID Gaussian time series.

    Parameters
    ----------
    N : int
        Length of the time series.
    mu : float
        Mean of the Gaussian.
    sigma : float
        Standard deviation of the Gaussian.
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Generated IID Gaussian time series of length N
    """
    return rng.normal(loc=mu, scale=sigma, size=N)


def random_walk_with_drift(
    N: int,
    x0: float,
    drift: float,
    sigma_eps: float,
    rng: np.random.Generator,
    noise_scale: float = 100.0,
) -> np.ndarray:
    """
    Generate a random walk time series with drift.

    Each step is ``x[t] = x[t-1] + drift + noise_scale * eps[t]`` with
    ``eps[t] ~ N(0, sigma_eps^2)``.

    Parameters
    ----------
    N : int
        Length of the time series.  ``0`` yields an empty array.
    x0 : float
        Initial value of the time series.
    drift : float
        Drift term added at each time step.
    sigma_eps : float
        Standard deviation of the white noise before scaling.
    rng : np.random.Generator
        Random number generator.
    noise_scale : float, default 100.0
        Multiplier applied to every noise draw.  The default keeps the factor
        of 100 this function has always applied; pass ``1.0`` for a walk whose
        step standard deviation is exactly ``sigma_eps``.

    Returns
    -------
    np.ndarray
        Generated random walk time series of length N
    """
    x = np.empty(N)
    if N == 0:
        # Nothing to fill; indexing x[0] would raise IndexError.
        return x
    x[0] = x0
    increments = drift + noise_scale * rng.normal(0.0, sigma_eps, size=N - 1)
    x[1:] = x0 + np.cumsum(increments)
    return x


def ar1_process(
    N: int,
    x0: float,
    mu: float,
    phi: float,
    sigma_eps: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Generate an AR(1) time series.

    An AR(1) process is defined by the equation:
        x[t] = mu + phi * (x[t-1] - mu) + eps[t]
    where eps[t] ~ N(0, sigma_eps^2)

    Parameters
    ----------
    N : int
        Length of the time series.  ``0`` yields an empty array.
    x0 : float
        Initial value of the time series.
    mu : float
        Mean of the AR(1) process.
    phi : float
        Autoregressive coefficient.
    sigma_eps : float
        Standard deviation of the white noise.
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Generated AR(1) time series of length N
    """
    x = np.empty(N)
    if N == 0:
        # Nothing to fill; indexing x[0] would raise IndexError.
        return x
    x[0] = x0
    # Draw all innovations up front; the recurrence itself stays sequential.
    noise = rng.normal(0.0, sigma_eps, size=N - 1)
    for t in range(1, N):
        x[t] = mu + phi * (x[t - 1] - mu) + noise[t - 1]
    return x