|
1 | 1 | """Story generators for the CC HIC OMOP schema.""" |
2 | 2 | import datetime as dt |
3 | | -from typing import Callable, Generator, Optional, Union, cast |
| 3 | +from typing import Callable, Generator, List, Optional, Union, cast |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | from mimesis import Generic |
|
10 | 10 | SqlRow = dict[str, SqlValue] |
11 | 11 | SrcStatsResult = list[SqlRow] |
12 | 12 | SrcStats = dict[str, SrcStatsResult] |
| 13 | +from typing import TypedDict, Callable, List, Optional, Dict, Union |
| 14 | + |
| 15 | +class SingularMeasurement(TypedDict): |
| 16 | + values: list[float] | list[int] |
| 17 | + |
| 18 | +class GroupedMeasurements(TypedDict): |
| 19 | + datetime: dt.datetime |
| 20 | + person_id: int |
| 21 | + visit_occurrence_id: int |
| 22 | + concepts: Dict[int, tuple[int, int]] |
| 23 | + values: Dict[int, List[SingularMeasurement]] |
| 24 | + generators: Dict[int, Callable[[int|float], int|float]] |
13 | 25 |
|
14 | 26 |
|
15 | 27 | def random_normal(mean: float, std_dev: Optional[float] = None) -> float: |
@@ -216,6 +228,81 @@ def populate_blood_pressure_values( |
216 | 228 | events.append(("measurement", diastolic)) |
217 | 229 | return events |
218 | 230 |
|
| 231 | +def populate_group_measurement( |
| 232 | + person: SqlRow, |
| 233 | + visit_occurrence: SqlRow, |
| 234 | + src_stats: SrcStats, |
| 235 | +) -> List[tuple[str, SqlRow]]: |
| 236 | + """Generate events for a visit occurrence, at a given rate with a given generator. |
| 237 | +
|
| 238 | + This is a utility function for generating multiple rows for one of the "event" |
| 239 | + tables (measurements, observation, etc.). |
| 240 | + """ |
| 241 | + |
| 242 | + Systolic_blood_pressure_by_Noninvasive = 21492239 |
| 243 | + Diastolic_blood_pressure_by_Noninvasive = 21492240 |
| 244 | + measurement_type_concept_id = 32817 # EHR measurement |
| 245 | + avg_systolic = 114.236842 |
| 246 | + avg_diastolic = 74.447368 |
| 247 | + avg_difference = avg_systolic - avg_diastolic |
| 248 | + unit_concept_id = 8876 # mmHg |
| 249 | + |
| 250 | + def get_diastolic_from_systolic(systolic: List[float]) -> float: |
| 251 | + """Estimate diastolic value from systolic value.""" |
| 252 | + return [s - avg_difference for s in systolic] |
| 253 | + |
| 254 | + def timeseries(length: int) -> float: |
| 255 | + """Estimate diastolic value from systolic value.""" |
| 256 | + return [0] * length |
| 257 | + |
| 258 | + generators: dict[str, Callable[[int|float], int|float]] = { |
| 259 | + "timeseries": timeseries, |
| 260 | + "diastolic": get_diastolic_from_systolic |
| 261 | + } |
| 262 | + |
| 263 | + m: GroupedMeasurements = { |
| 264 | + "concepts": {Systolic_blood_pressure_by_Noninvasive: (measurement_type_concept_id, unit_concept_id), |
| 265 | + Diastolic_blood_pressure_by_Noninvasive: (measurement_type_concept_id, unit_concept_id)}, |
| 266 | + "values": {Systolic_blood_pressure_by_Noninvasive: [], |
| 267 | + Diastolic_blood_pressure_by_Noninvasive: []}, |
| 268 | + "generators": {Systolic_blood_pressure_by_Noninvasive: generators["timeseries"], |
| 269 | + Diastolic_blood_pressure_by_Noninvasive: generators["diastolic"]}, |
| 270 | + "datetime": dt.datetime.now(), |
| 271 | + "person_id": cast(int, person["person_id"]), |
| 272 | + "visit_occurrence_id": cast(int, visit_occurrence["visit_occurrence_id"]), |
| 273 | + } |
| 274 | + |
| 275 | + m["values"][Systolic_blood_pressure_by_Noninvasive] = generators["timeseries"](10) |
| 276 | + m["values"][Diastolic_blood_pressure_by_Noninvasive] = generators["diastolic"](m["values"][Systolic_blood_pressure_by_Noninvasive]) |
| 277 | + |
| 278 | + def populate_values( |
| 279 | + event_datetime: dt.datetime, |
| 280 | + ) -> dict[int, SqlRow]: |
| 281 | + |
| 282 | + """Generate two rows for the measurement table.""" |
| 283 | + r: SqlRow = { |
| 284 | + "measurement_concept_id": m.concept_id, |
| 285 | + "person_id": m.person_id, |
| 286 | + "visit_occurrence_id": visit_occurrence_id, |
| 287 | + "measurement_datetime": event_datetime, |
| 288 | + "measurement_date": event_datetime.date(), |
| 289 | + "measurement_type_concept_id": m.type_concept_id, |
| 290 | + "unit_concept_id": m.unit_concept_id, |
| 291 | + "value_as_number": abs(random_normal(m.properties["average_value"], m.properties["stddev_value"])), |
| 292 | + } |
| 293 | + |
| 294 | + return r |
| 295 | + |
| 296 | + event_datetimes = random_event_times(10.0, visit_occurrence) |
| 297 | + events: list[tuple[str, SqlRow]] = [] |
| 298 | + for event_datetime in sorted(event_datetimes): |
| 299 | + systolic, diastolic = populate_values(cast(int, person["person_id"]), |
| 300 | + cast(int, visit_occurrence["visit_occurrence_id"]), |
| 301 | + event_datetime) |
| 302 | + events.append(("measurement", systolic)) |
| 303 | + events.append(("measurement", diastolic)) |
| 304 | + return events |
| 305 | + |
219 | 306 | def generate( |
220 | 307 | generic: Generic, |
221 | 308 | src_stats: SrcStats, |
|
0 commit comments