Skip to content

Commit 79011c7

Browse files
committed
Swap Dict with Counter for returning RowCounts
1 parent d55c0a8 commit 79011c7

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

sqlsynthgen/create.py

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Functions and classes to create and populate the target database."""
2-
from typing import Any, Generator, Mapping, Optional, Sequence, Tuple
2+
from collections import Counter
3+
from typing import Any, Generator, Mapping, Sequence, Tuple
34

45
from sqlalchemy import Connection, insert
56
from sqlalchemy.exc import IntegrityError
@@ -10,7 +11,7 @@
1011
from sqlsynthgen.utils import create_db_engine, get_sync_engine, logger
1112

1213
Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None]
13-
RowCounts = dict[str, int]
14+
RowCounts = Counter[str]
1415

1516

1617
def create_db_tables(metadata: MetaData) -> None:
@@ -68,15 +69,14 @@ def create_db_data(
6869
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
6970
)
7071

71-
row_counts: RowCounts = {}
72+
row_counts: Counter[str] = Counter()
7273
with dst_engine.connect() as dst_conn:
7374
for _ in range(num_passes):
74-
row_counts = populate(
75+
row_counts += populate(
7576
dst_conn,
7677
sorted_tables,
7778
table_generator_dict,
7879
story_generator_list,
79-
row_counts,
8080
)
8181
return row_counts
8282

@@ -86,13 +86,13 @@ def _populate_story(
8686
table_dict: Mapping[str, Table],
8787
table_generator_dict: Mapping[str, TableGenerator],
8888
dst_conn: Connection,
89-
row_counts: RowCounts,
9089
) -> RowCounts:
9190
"""Write to the database all the rows created by the given story."""
9291
# Loop over the rows generated by the story, insert them into their
9392
# respective tables. Ideally this would say
9493
# `for table_name, provided_values in story:`
9594
# but we have to loop more manually to be able to use the `send` function.
95+
row_counts: Counter[str] = Counter()
9696
try:
9797
table_name, provided_values = next(story)
9898
while True:
@@ -129,11 +129,9 @@ def populate(
129129
tables: Sequence[Table],
130130
table_generator_dict: Mapping[str, TableGenerator],
131131
story_generator_list: Sequence[Mapping[str, Any]],
132-
row_counts: Optional[RowCounts] = None,
133132
) -> RowCounts:
134133
"""Populate a database schema with synthetic data."""
135-
if row_counts is None:
136-
row_counts = {}
134+
row_counts: Counter[str] = Counter()
137135
table_dict = {table.name: table for table in tables}
138136
# Generate stories
139137
# Each story generator returns a python generator (an unfortunate naming clash with
@@ -153,8 +151,8 @@ def populate(
153151
# Run the inserts for each story within a transaction.
154152
logger.debug('Generating data for story "%s".', name)
155153
with dst_conn.begin():
156-
row_counts = _populate_story(
157-
story, table_dict, table_generator_dict, dst_conn, row_counts
154+
row_counts += _populate_story(
155+
story, table_dict, table_generator_dict, dst_conn
158156
)
159157

160158
# Generate individual rows, table by table.

tests/test_create.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Tests for the create module."""
22
import itertools as itt
3+
from collections import Counter
34
from pathlib import Path
45
from typing import Any, Generator, Tuple
56
from unittest.mock import MagicMock, call, patch
@@ -82,7 +83,7 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
8283
mock_gen = MagicMock(spec=TableGenerator)
8384
mock_gen.num_rows_per_pass = num_rows_per_pass
8485
mock_gen.return_value = {}
85-
row_counts = (
86+
row_counts = Counter(
8687
{table_name: num_initial_rows} if num_initial_rows > 0 else {}
8788
)
8889

@@ -97,20 +98,23 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
9798
if num_stories_per_pass > 0
9899
else []
99100
)
100-
row_counts = populate(
101+
row_counts += populate(
101102
mock_dst_conn,
102103
[mock_table],
103104
{table_name: mock_gen},
104105
story_generators,
105-
row_counts,
106106
)
107107

108108
expected_row_count = (
109109
num_stories_per_pass + num_rows_per_pass + num_initial_rows
110110
)
111111
self.assertEqual(
112+
Counter(
113+
{table_name: expected_row_count}
114+
if expected_row_count > 0
115+
else {}
116+
),
112117
row_counts,
113-
{table_name: expected_row_count} if expected_row_count > 0 else {},
114118
)
115119
self.assertListEqual(
116120
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass),
@@ -148,7 +152,7 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
148152
"three": mock_gen_three,
149153
}
150154

151-
row_counts = populate(mock_dst_conn, tables, row_generators, [], {})
155+
row_counts = populate(mock_dst_conn, tables, row_generators, [])
152156
self.assertEqual(row_counts, {"two": 1, "three": 1})
153157
self.assertListEqual(
154158
[call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list
@@ -221,4 +225,4 @@ def my_story() -> Story:
221225

222226
with engine.connect() as conn:
223227
with conn.begin():
224-
_populate_story(my_story(), dict(self.metadata.tables), {}, conn, {})
228+
_populate_story(my_story(), dict(self.metadata.tables), {}, conn)

tests/test_main.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ def test_create_data(
188188
[
189189
call("Creating data."),
190190
call("Data created in %s %s.", 1, "pass"),
191-
call("%s: %s %s created", "a", "row.", 1),
191+
call("%s: %s %s created.", "a", 1, "row"),
192192
]
193193
)
194194

0 commit comments

Comments
 (0)