Skip to content

Commit 9b5592a

Browse files
committed
Log row counts for create-data
1 parent 059cdeb commit 9b5592a

File tree

4 files changed

+55
-18
lines changed

4 files changed

+55
-18
lines changed

sqlsynthgen/create.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Functions and classes to create and populate the target database."""
2-
from typing import Any, Generator, Mapping, Sequence, Tuple
2+
from typing import Any, Generator, Mapping, Optional, Sequence, Tuple
33

44
from sqlalchemy import Connection, insert
55
from sqlalchemy.exc import IntegrityError
@@ -10,6 +10,7 @@
1010
from sqlsynthgen.utils import create_db_engine, get_sync_engine, logger
1111

1212
Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None]
13+
RowCounts = dict[str, int]
1314

1415

1516
def create_db_tables(metadata: MetaData) -> None:
@@ -57,7 +58,7 @@ def create_db_data(
5758
table_generator_dict: Mapping[str, TableGenerator],
5859
story_generator_list: Sequence[Mapping[str, Any]],
5960
num_passes: int,
60-
) -> None:
61+
) -> RowCounts:
6162
"""Connect to a database and populate it with data."""
6263
settings = get_settings()
6364
dst_dsn: str = settings.dst_dsn or ""
@@ -67,22 +68,26 @@ def create_db_data(
6768
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
6869
)
6970

71+
row_counts: RowCounts = {}
7072
with dst_engine.connect() as dst_conn:
7173
for _ in range(num_passes):
72-
populate(
74+
row_counts = populate(
7375
dst_conn,
7476
sorted_tables,
7577
table_generator_dict,
7678
story_generator_list,
79+
row_counts,
7780
)
81+
return row_counts
7882

7983

8084
def _populate_story(
8185
story: Story,
8286
table_dict: Mapping[str, Table],
8387
table_generator_dict: Mapping[str, TableGenerator],
8488
dst_conn: Connection,
85-
) -> None:
89+
row_counts: RowCounts,
90+
) -> RowCounts:
8691
"""Write to the database all the rows created by the given story."""
8792
# Loop over the rows generated by the story, insert them into their
8893
# respective tables. Ideally this would say
@@ -111,19 +116,24 @@ def _populate_story(
111116
else:
112117
return_values = {}
113118
final_values = {**insert_values, **return_values}
119+
row_counts[table_name] = row_counts.get(table_name, 0) + 1
114120
table_name, provided_values = story.send(final_values)
115121
except StopIteration:
116122
# The story has finished, it has no more rows to generate
117123
pass
124+
return row_counts
118125

119126

120127
def populate(
121128
dst_conn: Connection,
122129
tables: Sequence[Table],
123130
table_generator_dict: Mapping[str, TableGenerator],
124131
story_generator_list: Sequence[Mapping[str, Any]],
125-
) -> None:
132+
row_counts: Optional[RowCounts] = None,
133+
) -> RowCounts:
126134
"""Populate a database schema with synthetic data."""
135+
if row_counts is None:
136+
row_counts = {}
127137
table_dict = {table.name: table for table in tables}
128138
# Generate stories
129139
# Each story generator returns a python generator (an unfortunate naming clash with
@@ -143,7 +153,9 @@ def populate(
143153
# Run the inserts for each story within a transaction.
144154
logger.debug("Generating data for story %s", name)
145155
with dst_conn.begin():
146-
_populate_story(story, table_dict, table_generator_dict, dst_conn)
156+
row_counts = _populate_story(
157+
story, table_dict, table_generator_dict, dst_conn, row_counts
158+
)
147159

148160
# Generate individual rows, table by table.
149161
for table in tables:
@@ -160,3 +172,5 @@ def populate(
160172
for _ in range(table_generator.num_rows_per_pass):
161173
stmt = insert(table).values(table_generator(dst_conn))
162174
dst_conn.execute(stmt)
175+
row_counts[table.name] = row_counts.get(table.name, 0) + 1
176+
return row_counts

sqlsynthgen/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,15 @@ def create_data(
9696
orm_metadata = get_orm_metadata(orm_module, tables_config)
9797
table_generator_dict = ssg_module.table_generator_dict
9898
story_generator_list = ssg_module.story_generator_list
99-
create_db_data(
99+
row_counts = create_db_data(
100100
orm_metadata.sorted_tables,
101101
table_generator_dict,
102102
story_generator_list,
103103
num_passes,
104104
)
105105
logger.debug("Data created in %s passes.", num_passes)
106+
for table_name, row_count in row_counts.items():
107+
logger.debug("%s: %s rows created", table_name, row_count)
106108

107109

108110
@app.command()

tests/test_create.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ def test_create_db_data(
3434
) -> None:
3535
"""Test the generate function."""
3636
mock_get_settings.return_value = get_test_settings()
37+
mock_populate.return_value = {}
3738

3839
num_passes = 23
39-
create_db_data([], {}, [], num_passes)
40+
row_counts = create_db_data([], {}, [], num_passes)
4041

4142
self.assertEqual(len(mock_populate.call_args_list), num_passes)
43+
self.assertEqual(row_counts, {})
4244
mock_create_engine.assert_called()
4345

4446
@patch("sqlsynthgen.create.get_settings")
@@ -62,13 +64,15 @@ def test_populate(self) -> None:
6264

6365
def story() -> Generator[Tuple[str, dict], None, None]:
6466
"""Mock story."""
65-
yield "table_name", {}
67+
yield table_name, {}
6668

6769
def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
6870
"""A function that returns mock stories."""
6971
return story()
7072

71-
for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]):
73+
for num_stories_per_pass, num_rows_per_pass, num_initial_rows in itt.product(
74+
[0, 2], [0, 3], [0, 17]
75+
):
7276
with patch("sqlsynthgen.create.insert") as mock_insert:
7377
mock_values = mock_insert.return_value.values
7478
mock_dst_conn = MagicMock(spec=Connection)
@@ -78,9 +82,10 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
7882
mock_gen = MagicMock(spec=TableGenerator)
7983
mock_gen.num_rows_per_pass = num_rows_per_pass
8084
mock_gen.return_value = {}
85+
row_counts = (
86+
{table_name: num_initial_rows} if num_initial_rows > 0 else {}
87+
)
8188

82-
tables: list[Table] = [mock_table]
83-
row_generators: dict[str, TableGenerator] = {table_name: mock_gen}
8489
story_generators: list[dict[str, Any]] = (
8590
[
8691
{
@@ -92,13 +97,21 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
9297
if num_stories_per_pass > 0
9398
else []
9499
)
95-
populate(
100+
row_counts = populate(
96101
mock_dst_conn,
97-
tables,
98-
row_generators,
102+
[mock_table],
103+
{table_name: mock_gen},
99104
story_generators,
105+
row_counts,
100106
)
101107

108+
expected_row_count = (
109+
num_stories_per_pass + num_rows_per_pass + num_initial_rows
110+
)
111+
self.assertEqual(
112+
row_counts,
113+
{table_name: expected_row_count} if expected_row_count > 0 else {},
114+
)
102115
self.assertListEqual(
103116
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass),
104117
mock_gen.call_args_list,
@@ -135,7 +148,8 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
135148
"three": mock_gen_three,
136149
}
137150

138-
populate(mock_dst_conn, tables, row_generators, [])
151+
row_counts = populate(mock_dst_conn, tables, row_generators, [], {})
152+
self.assertEqual(row_counts, {"two": 1, "three": 1})
139153
self.assertListEqual(
140154
[call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list
141155
)
@@ -207,4 +221,4 @@ def my_story() -> Story:
207221

208222
with engine.connect() as conn:
209223
with conn.begin():
210-
_populate_story(my_story(), dict(self.metadata.tables), {}, conn)
224+
_populate_story(my_story(), dict(self.metadata.tables), {}, conn, {})

tests/test_functional.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,14 @@ def test_workflow_maximal_args(self) -> None:
362362
"Generating data for table unique_constraint_test2\n"
363363
"Generating data for table test_entity\n"
364364
"Generating data for table hospital_visit\n"
365-
"Data created in 2 passes.\n",
365+
"Data created in 2 passes.\n"
366+
f"person: {2*(3+1+2+2)} rows created\n"
367+
f"hospital_visit: {2*(2*2+3)} rows created\n"
368+
"data_type_test: 2 rows created\n"
369+
"no_pk_test: 2 rows created\n"
370+
"unique_constraint_test: 2 rows created\n"
371+
"unique_constraint_test2: 2 rows created\n"
372+
"test_entity: 2 rows created\n",
366373
completed_process.stdout.decode("utf-8"),
367374
)
368375

0 commit comments

Comments
 (0)