Skip to content

Commit 368c554

Browse files
committed
Types for passing generators
1 parent 13ae71f commit 368c554

File tree

1 file changed

+88
-1
lines changed

1 file changed

+88
-1
lines changed

person_story.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Story generators for the CC HIC OMOP schema."""
22
import datetime as dt
3-
from typing import Callable, Generator, Optional, Union, cast
3+
from typing import Callable, Generator, List, Optional, Union, cast
44

55
import numpy as np
66
from mimesis import Generic
@@ -10,6 +10,18 @@
1010
SqlRow = dict[str, SqlValue]
1111
SrcStatsResult = list[SqlRow]
1212
SrcStats = dict[str, SrcStatsResult]
13+
from typing import TypedDict, Callable, List, Optional, Dict, Union
14+
15+
class SingularMeasurement(TypedDict):
16+
values: list[float] | list[int]
17+
18+
class GroupedMeasurements(TypedDict):
19+
datetime: dt.datetime
20+
person_id: int
21+
visit_occurrence_id: int
22+
concepts: Dict[int, tuple[int, int]]
23+
values: Dict[int, List[SingularMeasurement]]
24+
generators: Dict[int, Callable[[int|float], int|float]]
1325

1426

1527
def random_normal(mean: float, std_dev: Optional[float] = None) -> float:
@@ -216,6 +228,81 @@ def populate_blood_pressure_values(
216228
events.append(("measurement", diastolic))
217229
return events
218230

231+
def populate_group_measurement(
232+
person: SqlRow,
233+
visit_occurrence: SqlRow,
234+
src_stats: SrcStats,
235+
) -> List[tuple[str, SqlRow]]:
236+
"""Generate events for a visit occurrence, at a given rate with a given generator.
237+
238+
This is a utility function for generating multiple rows for one of the "event"
239+
tables (measurements, observation, etc.).
240+
"""
241+
242+
Systolic_blood_pressure_by_Noninvasive = 21492239
243+
Diastolic_blood_pressure_by_Noninvasive = 21492240
244+
measurement_type_concept_id = 32817 # EHR measurement
245+
avg_systolic = 114.236842
246+
avg_diastolic = 74.447368
247+
avg_difference = avg_systolic - avg_diastolic
248+
unit_concept_id = 8876 # mmHg
249+
250+
def get_diastolic_from_systolic(systolic: List[float]) -> float:
251+
"""Estimate diastolic value from systolic value."""
252+
return [s - avg_difference for s in systolic]
253+
254+
def timeseries(length: int) -> float:
255+
"""Estimate diastolic value from systolic value."""
256+
return [0] * length
257+
258+
generators: dict[str, Callable[[int|float], int|float]] = {
259+
"timeseries": timeseries,
260+
"diastolic": get_diastolic_from_systolic
261+
}
262+
263+
m: GroupedMeasurements = {
264+
"concepts": {Systolic_blood_pressure_by_Noninvasive: (measurement_type_concept_id, unit_concept_id),
265+
Diastolic_blood_pressure_by_Noninvasive: (measurement_type_concept_id, unit_concept_id)},
266+
"values": {Systolic_blood_pressure_by_Noninvasive: [],
267+
Diastolic_blood_pressure_by_Noninvasive: []},
268+
"generators": {Systolic_blood_pressure_by_Noninvasive: generators["timeseries"],
269+
Diastolic_blood_pressure_by_Noninvasive: generators["diastolic"]},
270+
"datetime": dt.datetime.now(),
271+
"person_id": cast(int, person["person_id"]),
272+
"visit_occurrence_id": cast(int, visit_occurrence["visit_occurrence_id"]),
273+
}
274+
275+
m["values"][Systolic_blood_pressure_by_Noninvasive] = generators["timeseries"](10)
276+
m["values"][Diastolic_blood_pressure_by_Noninvasive] = generators["diastolic"](m["values"][Systolic_blood_pressure_by_Noninvasive])
277+
278+
def populate_values(
279+
event_datetime: dt.datetime,
280+
) -> dict[int, SqlRow]:
281+
282+
"""Generate two rows for the measurement table."""
283+
r: SqlRow = {
284+
"measurement_concept_id": m.concept_id,
285+
"person_id": m.person_id,
286+
"visit_occurrence_id": visit_occurrence_id,
287+
"measurement_datetime": event_datetime,
288+
"measurement_date": event_datetime.date(),
289+
"measurement_type_concept_id": m.type_concept_id,
290+
"unit_concept_id": m.unit_concept_id,
291+
"value_as_number": abs(random_normal(m.properties["average_value"], m.properties["stddev_value"])),
292+
}
293+
294+
return r
295+
296+
event_datetimes = random_event_times(10.0, visit_occurrence)
297+
events: list[tuple[str, SqlRow]] = []
298+
for event_datetime in sorted(event_datetimes):
299+
systolic, diastolic = populate_values(cast(int, person["person_id"]),
300+
cast(int, visit_occurrence["visit_occurrence_id"]),
301+
event_datetime)
302+
events.append(("measurement", systolic))
303+
events.append(("measurement", diastolic))
304+
return events
305+
219306
def generate(
220307
generic: Generic,
221308
src_stats: SrcStats,

0 commit comments

Comments
 (0)