Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions config_story.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,68 @@ src-stats:
m.visit_occurrence_id,
v.visit_start_datetime,
v.visit_end_datetime) subquery;
- name: bp_sys_relative_change_stats
#For each patient + visit + BP type, sort measurements in time,
#compute |xₜ − xₜ₋₁| / xₜ₋₁ for each step,
#then average those within that visit.
#Finally, average that visit-level metric across the population.
query: >
WITH bp AS (
SELECT
p.person_id,
p.gender_concept_id,
v.visit_occurrence_id,
EXTRACT(YEAR FROM v.visit_start_datetime) - p.year_of_birth AS age,
m.measurement_concept_id,
m.value_as_number,
m.measurement_datetime
FROM mimic.person p
JOIN mimic.visit_occurrence v
ON p.person_id = v.person_id
JOIN mimic.measurement m
ON m.visit_occurrence_id = v.visit_occurrence_id
WHERE m.measurement_concept_id IN (21492239) -- systolic
),
bp_with_lag AS (
SELECT
bp.*,
LAG(value_as_number) OVER (
PARTITION BY person_id, visit_occurrence_id, measurement_concept_id
ORDER BY measurement_datetime
) AS prev_value
FROM bp
),
visit_level_rel_var AS (
SELECT
person_id,
gender_concept_id,
age,
visit_occurrence_id,
measurement_concept_id,
-- average relative change within this visit
AVG( ABS(value_as_number - prev_value)
/ NULLIF(prev_value, 0) ) AS avg_rel_var
FROM bp_with_lag
WHERE prev_value IS NOT NULL -- skip first measurement in each visit
GROUP BY
person_id,
gender_concept_id,
age,
visit_occurrence_id,
measurement_concept_id
)
SELECT
gender_concept_id,
AVG(CASE WHEN measurement_concept_id = 21492239 AND age > 60
THEN avg_rel_var END)::float4 AS avg_over_60_systolic_rel_var,
AVG(CASE WHEN measurement_concept_id = 21492239 AND age <= 60
THEN avg_rel_var END)::float4 AS avg_under_60_systolic_rel_var,
STDDEV(CASE WHEN measurement_concept_id = 21492239 AND age > 60
THEN avg_rel_var END)::float4 AS stddev_over_60_systolic_rel_var,
STDDEV(CASE WHEN measurement_concept_id = 21492239 AND age <= 60
THEN avg_rel_var END)::float4 AS stddev_under_60_systolic_rel_var
FROM visit_level_rel_var
GROUP BY gender_concept_id;
tables:
person:
num_rows_per_pass: 0
Expand Down
197 changes: 125 additions & 72 deletions person_story.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Story generators for the CC HIC OMOP schema."""
import datetime as dt
from typing import Callable, Generator, Optional, Union, cast

from sqlsynthgen.utils import generate_time_series
import numpy as np
from mimesis import Generic
import random
Expand All @@ -24,9 +24,10 @@ def random_normal(mean: float, std_dev: Optional[float] = None) -> float:


def gen_death(
generic: Generic, person: SqlRow, src_stats: SrcStats
generic: Generic, person: SqlRow, src_stats: SrcStats
) -> Optional[tuple[str, SqlRow]]:
"""Generate a row for the death table."""

def with_probability(p: float) -> bool:
"""Return True with probability p (0 ≤ p ≤ 1)."""
return random.random() < p
Expand All @@ -36,7 +37,8 @@ def with_probability(p: float) -> bool:
else:
avg_age_at_death_days = src_stats["age_at_death"][0]["average_age_years"] * 365
std_dev_age_at_death_days = src_stats["age_at_death"][0]["stddev_age_years"] * 365
age_at_death_days = abs(random_normal(cast(float,avg_age_at_death_days), cast(float,std_dev_age_at_death_days)))
age_at_death_days = abs(
random_normal(cast(float, avg_age_at_death_days), cast(float, std_dev_age_at_death_days)))
death_datetime = cast(dt.datetime, person["birth_datetime"]) + dt.timedelta(
days=age_at_death_days)
return "death", {
Expand All @@ -45,35 +47,36 @@ def with_probability(p: float) -> bool:
"death_date": death_datetime.date(),
}


def gen_visit_occurrence(
person: SqlRow, death: Optional[SqlRow], src_stats: SrcStats
person: SqlRow, death: Optional[SqlRow], src_stats: SrcStats
) -> tuple[str, SqlRow]:
"""Generate a row for the visit_occurrence table."""
age_days_at_visit_start = abs(
random_normal(
cast(float, 63*365), cast(float, 13*365)
cast(float, 63 * 365), cast(float, 13 * 365)
)
)
if person["gender_concept_id"] == 8532:
age_days_at_visit_start = abs(
random_normal(
cast(float, src_stats["age_first_admission"][0]["average_age_years"]*365),
cast(float, src_stats["age_first_admission"][0]["stddev_age_years"]*365)
cast(float, src_stats["age_first_admission"][0]["average_age_years"] * 365),
cast(float, src_stats["age_first_admission"][0]["stddev_age_years"] * 365)
)
)
if person["gender_concept_id"] == 8507:
age_days_at_visit_start = abs(
random_normal(
cast(float, src_stats["age_first_admission"][1]["average_age_years"]*365),
cast(float, src_stats["age_first_admission"][1]["stddev_age_years"]*365)
cast(float, src_stats["age_first_admission"][1]["average_age_years"] * 365),
cast(float, src_stats["age_first_admission"][1]["stddev_age_years"] * 365)
)
)
visit_start_datetime = cast(dt.datetime, person["birth_datetime"]) + dt.timedelta(
days=age_days_at_visit_start
)
visit_length_hours = abs(
random_normal(
cast(float, src_stats["visit_duration"][0]["average_hours"]),
cast(float, src_stats["visit_duration"][0]["average_hours"]),
cast(float, src_stats["visit_duration"][0]["stddev_hours"])
# cast(float, 6), cast(float, 29*24)
)
Expand Down Expand Up @@ -114,15 +117,15 @@ def random_event_times(avg_rate: float, visit_occurrence: SqlRow) -> list[dt.dat


def gen_events( # pylint: disable=too-many-arguments
generic: Generic,
avg_rate: float,
visit_occurrence: SqlRow,
person: SqlRow,
generator_function: Callable[
[Generic, int, int, dt.datetime, SrcStats], Optional[SqlRow]
],
table_name: str,
src_stats: SrcStats,
generic: Generic,
avg_rate: float,
visit_occurrence: SqlRow,
person: SqlRow,
generator_function: Callable[
[Generic, int, int, dt.datetime, SrcStats], Optional[SqlRow]
],
table_name: str,
src_stats: SrcStats,
) -> list[tuple[str, SqlRow]]:
"""Generate events for a visit occurrence, at a given rate with a given generator.

Expand All @@ -143,82 +146,127 @@ def gen_events( # pylint: disable=too-many-arguments
events.append((table_name, event))
return events


def gen_blood_pressure_events( # pylint: disable=too-many-arguments
avg_rate: float,
visit_occurrence: SqlRow,
person: SqlRow,
src_stats: SrcStats,
avg_rate: float,
visit_occurrence: SqlRow,
person: SqlRow,
src_stats: SrcStats,
) -> list[tuple[str, SqlRow]]:
"""Generate events for a visit occurrence, at a given rate with a given generator.

This is a utility function for generating multiple rows for one of the "event"
tables (measurements, observation, etc.).
"""

def populate_blood_pressure_values(
person_id: int,
visit_occurrence_id: int,
event_datetime: dt.datetime,
def generate_paired_measurement(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@myyong, I've done some abstraction in this function. I didn't go too far, but it may be one step toward #184.

person_id: int,
visit_occurrence_id: int,
event_datetime: dt.datetime,
values: tuple[float, float],
measurement_concept_id: tuple[int, int],
measurement_type_concept_ids: int,
unit_concept_id: int,
unit_source_value: str,
) -> tuple[SqlRow, SqlRow]:

Systolic_blood_pressure_by_Noninvasive = 21492239
Diastolic_blood_pressure_by_Noninvasive = 21492240
measurement_type_concept_id = 32817 # EHR measurement
avg_systolic = 114.236842
avg_diastolic = 74.447368
avg_difference = avg_systolic - avg_diastolic
unit_concept_id = 8876 # mmHg

gender = cast(int, person["gender_concept_id"])
if gender == 8507:
systolic_value = random_normal(src_stats["bp_profile"][0]["average_under_60_systolic"],src_stats["bp_profile"][0]["stddev_under_60_systolic"])
diastolic_value = src_stats["bp_profile"][0]["average_systolic_diastolic_difference"] + systolic_value
elif gender == 8532:
systolic_value = random_normal(src_stats["bp_profile"][1]["average_under_60_systolic"],src_stats["bp_profile"][1]["stddev_under_60_systolic"])
diastolic_value = src_stats["bp_profile"][1]["average_systolic_diastolic_difference"] + systolic_value
else:
systolic_value = avg_systolic
diastolic_value = avg_diastolic

### This can be abastracted to generate any number of set of measurements
"""Generate two rows for the measurement table."""
systolic: SqlRow = {
"measurement_concept_id": cast(int, Systolic_blood_pressure_by_Noninvasive),
measurement1: SqlRow = {
"measurement_concept_id": cast(int, measurement_concept_id[0]),
"person_id": person_id,
"visit_occurrence_id": visit_occurrence_id,
"measurement_datetime": event_datetime,
"measurement_date": event_datetime.date(),
"measurement_type_concept_id": measurement_type_concept_id,
"measurement_type_concept_id": measurement_type_concept_ids,
"unit_concept_id": unit_concept_id,
"unit_source_value": "mmHg",
"value_as_number": systolic_value,
"unit_source_value": unit_source_value,
"value_as_number": values[0],
}

diastolic: SqlRow = {
"measurement_concept_id": cast(int, Diastolic_blood_pressure_by_Noninvasive),
measurement2: SqlRow = {
"measurement_concept_id": cast(int, measurement_concept_id[1]),
"person_id": person_id,
"visit_occurrence_id": visit_occurrence_id,
"measurement_datetime": event_datetime,
"measurement_date": event_datetime.date(),
"measurement_type_concept_id": measurement_type_concept_id,
"measurement_type_concept_id": measurement_type_concept_ids,
"unit_concept_id": unit_concept_id,
"unit_source_value": "mmHg",
"value_as_number": diastolic_value,
"unit_source_value": unit_source_value,
"value_as_number": values[1],
}
return systolic, diastolic
return measurement1, measurement2

event_datetimes = random_event_times(avg_rate, visit_occurrence)

if len(event_datetimes) == 0:
return []

# can we get this from the data?
sys_bp_non_invasive_concept_id = 21492239
dias_bp_non_invasive_concept_id = 21492240
measurement_type_concept_id = 32817 # EHR measurement
unit_source_value = "mmHg"
unit_concept_id = 8876 # mmHg

gender = cast(int, person["gender_concept_id"])
age = (cast(dt.datetime, visit_occurrence["visit_start_datetime"]) - cast(dt.datetime,
person["birth_datetime"])).days / 365.25

main_key = 'bp_profile'
relative_change_key = 'bp_sys_relative_change_stats'
if age < 60:
key_mean = 'average_under_60_systolic'
key_std = 'stddev_under_60_systolic'

key_epsilon_mean = 'avg_under_60_systolic_rel_var'
key_epsilon_std = 'stddev_under_60_systolic_rel_var'


else:
key_mean = 'average_over_60_systolic'
key_std = 'stddev_over_60_systolic'

key_epsilon_mean = 'avg_over_60_systolic_rel_var'
key_epsilon_std = 'stddev_over_60_systolic_rel_var'

if gender == 8507:
index_gender = 0
else:
index_gender = 1

sample_epsilon = np.random.normal(src_stats[relative_change_key][index_gender][key_epsilon_mean],
src_stats[relative_change_key][index_gender][key_epsilon_std], 1)[0]

systolic_value = np.round(generate_time_series(len(event_datetimes), 'random_walk',
{'mean': src_stats[main_key][index_gender][key_mean],
'std': src_stats[main_key][index_gender][key_std],
'epsilon_std': sample_epsilon, 'drift': 0}))

# diastolic value is calculated based on systolic value plus the average difference extrated from data
# we add some variation to the difference between systolic and diastolic
diastolic_value = np.round(systolic_value - random_normal(src_stats[main_key][index_gender]['average_systolic_diastolic_difference'],
src_stats[main_key][index_gender][
"average_systolic_diastolic_difference"] * 0.1) )

events: list[tuple[str, SqlRow]] = []
for event_datetime in sorted(event_datetimes):
systolic, diastolic = populate_blood_pressure_values(cast(int, person["person_id"]),
cast(int, visit_occurrence["visit_occurrence_id"]),
event_datetime)
events.append(("measurement", systolic))
events.append(("measurement", diastolic))
for index, event_datetime in enumerate(sorted(event_datetimes)):
systolic_dict, diastolic_dict = generate_paired_measurement(cast(int, person["person_id"]),
cast(int, visit_occurrence["visit_occurrence_id"]),
event_datetime,
(systolic_value[index], diastolic_value[index]),
(sys_bp_non_invasive_concept_id,
dias_bp_non_invasive_concept_id),
measurement_type_concept_id, unit_concept_id,
unit_source_value)
events.append(("measurement", systolic_dict)),
events.append(("measurement", diastolic_dict))
return events


def generate(
generic: Generic,
src_stats: SrcStats,
generic: Generic,
src_stats: SrcStats,
) -> Generator[tuple[str, SqlRow], SqlRow, None]:
"""Yield all the data related to a single patient.

Expand All @@ -240,15 +288,20 @@ def generate(
death_row = (yield death) if death else None
visit_occurrence = yield gen_visit_occurrence(person, death_row, src_stats)

# abs to avoid negative rates due to random normal variation
# abs to avoid negative rates due to random normal variation
avg_rate = abs(random_normal(
src_stats["avg_measurements_per_visit_hour"][0]['avg_measurements_per_hour'],
src_stats["avg_measurements_per_visit_hour"][0]['stddev_measurements_per_hour'] ))
src_stats["avg_measurements_per_visit_hour"][0]['stddev_measurements_per_hour'])
)

print(f"Generating blood pressure events at an average rate of {avg_rate} per hour. Using IID sampling.")
for event in gen_blood_pressure_events(
avg_rate,
visit_occurrence,
person,
src_stats,
avg_rate,
visit_occurrence,
person,
src_stats,
):
yield event
# Yield each measurement event if is not empty dictionary
if len(event) > 0:
yield event
Loading