11"""Functions and classes to create and populate the target database."""
2- from typing import Any , Generator , Mapping , Optional , Sequence , Tuple
2+ from collections import Counter
3+ from typing import Any , Generator , Mapping , Sequence , Tuple
34
45from sqlalchemy import Connection , insert
56from sqlalchemy .exc import IntegrityError
1011from sqlsynthgen .utils import create_db_engine , get_sync_engine , logger
1112
1213Story = Generator [Tuple [str , dict [str , Any ]], dict [str , Any ], None ]
13- RowCounts = dict [str , int ]
14+ RowCounts = Counter [str ]
1415
1516
1617def create_db_tables (metadata : MetaData ) -> None :
@@ -68,15 +69,14 @@ def create_db_data(
6869 create_db_engine (dst_dsn , schema_name = settings .dst_schema )
6970 )
7071
71- row_counts : RowCounts = {}
72+ row_counts : Counter [ str ] = Counter ()
7273 with dst_engine .connect () as dst_conn :
7374 for _ in range (num_passes ):
74- row_counts = populate (
75+ row_counts + = populate (
7576 dst_conn ,
7677 sorted_tables ,
7778 table_generator_dict ,
7879 story_generator_list ,
79- row_counts ,
8080 )
8181 return row_counts
8282
@@ -86,13 +86,13 @@ def _populate_story(
8686 table_dict : Mapping [str , Table ],
8787 table_generator_dict : Mapping [str , TableGenerator ],
8888 dst_conn : Connection ,
89- row_counts : RowCounts ,
9089) -> RowCounts :
9190 """Write to the database all the rows created by the given story."""
9291 # Loop over the rows generated by the story, insert them into their
9392 # respective tables. Ideally this would say
9493 # `for table_name, provided_values in story:`
9594 # but we have to loop more manually to be able to use the `send` function.
95+ row_counts : Counter [str ] = Counter ()
9696 try :
9797 table_name , provided_values = next (story )
9898 while True :
@@ -129,11 +129,9 @@ def populate(
129129 tables : Sequence [Table ],
130130 table_generator_dict : Mapping [str , TableGenerator ],
131131 story_generator_list : Sequence [Mapping [str , Any ]],
132- row_counts : Optional [RowCounts ] = None ,
133132) -> RowCounts :
134133 """Populate a database schema with synthetic data."""
135- if row_counts is None :
136- row_counts = {}
134+ row_counts : Counter [str ] = Counter ()
137135 table_dict = {table .name : table for table in tables }
138136 # Generate stories
139137 # Each story generator returns a python generator (an unfortunate naming clash with
@@ -153,8 +151,8 @@ def populate(
153151 # Run the inserts for each story within a transaction.
154152 logger .debug ('Generating data for story "%s".' , name )
155153 with dst_conn .begin ():
156- row_counts = _populate_story (
157- story , table_dict , table_generator_dict , dst_conn , row_counts
154+ row_counts + = _populate_story (
155+ story , table_dict , table_generator_dict , dst_conn
158156 )
159157
160158 # Generate individual rows, table by table.
0 commit comments