Skip to content

Commit 2dfdacf

Browse files
committed
Refactor function which chooses column providers
1 parent c598b19 commit 2dfdacf

File tree

2 files changed

+82
-29
lines changed

2 files changed

+82
-29
lines changed

sqlsynthgen/make.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _get_default_generator(tables_module: ModuleType, column: Any) -> RowGenerat
185185
variable_names,
186186
generator_function,
187187
generator_arguments,
188-
) = _get_mimesis_function_for_colum(column)
188+
) = _get_provider_for_column(column)
189189

190190
return RowGenerator(
191191
primary_key=column.primary_key,
@@ -196,7 +196,7 @@ def _get_default_generator(tables_module: ModuleType, column: Any) -> RowGenerat
196196
)
197197

198198

199-
def _get_mimesis_function_for_colum(column: Any) -> Tuple[List[str], str, List[str]]:
199+
def _get_provider_for_column(column: Any) -> Tuple[List[str], str, List[str]]:
200200
"""
201201
Get a default Mimesis provider and its arguments for a SQL column type.
202202
@@ -209,37 +209,35 @@ def _get_mimesis_function_for_colum(column: Any) -> Tuple[List[str], str, List[s
209209
"""
210210
variable_names: List[str] = [column.name]
211211
generator_arguments: List[str] = []
212-
generator_function: str = ""
213212

214213
column_type = type(column.type)
215214
column_size: Optional[int] = getattr(column.type, "length", None)
216215

217-
# ToDo Add tests and then add issubclass for all of these
218-
# sqlalchemy.dialects.mysql.types.INTEGER
219-
if column_type == sqltypes.BigInteger:
220-
generator_function = "generic.numeric.integer_number"
221-
if column_type == sqltypes.Integer or issubclass(column_type, sqltypes.Integer):
222-
generator_function = "generic.numeric.integer_number"
223-
elif column_type == sqltypes.Boolean:
224-
generator_function = "generic.development.boolean"
225-
elif column_type == sqltypes.Date:
226-
generator_function = "generic.datetime.date"
227-
elif column_type == sqltypes.DateTime:
228-
generator_function = "generic.datetime.datetime"
229-
elif column_type in {sqltypes.Float, sqltypes.Numeric}:
230-
generator_function = "generic.numeric.float_number"
231-
elif column_type == sqltypes.Integer:
232-
generator_function = "generic.numeric.integer_number"
233-
elif column_type == sqltypes.LargeBinary:
234-
generator_function = "generic.bytes_provider.bytes"
235-
elif column_type in {sqltypes.String, sqltypes.Text} and column_size is None:
236-
generator_function = "generic.text.color"
237-
elif column_type in {sqltypes.String, sqltypes.Text} and column_size is not None:
238-
generator_function = "generic.person.password"
239-
generator_arguments.append(str(column_size))
240-
else:
216+
mapping = {
217+
(sqltypes.Integer, False): "generic.numeric.integer_number",
218+
(sqltypes.Boolean, False): "generic.development.boolean",
219+
(sqltypes.Date, False): "generic.datetime.date",
220+
(sqltypes.DateTime, False): "generic.datetime.datetime",
221+
(sqltypes.Numeric, False): "generic.numeric.float_number",
222+
(sqltypes.LargeBinary, False): "generic.bytes_provider.bytes",
223+
(sqltypes.String, False): "generic.text.color",
224+
(sqltypes.String, True): "generic.person.password",
225+
}
226+
227+
generator_function = mapping.get((column_type, column_size is not None), None)
228+
229+
if not generator_function:
230+
for key, value in mapping.items():
231+
if issubclass(column_type, key[0]) and key[1] == (column_size is not None):
232+
generator_function = value
233+
break
234+
235+
if not generator_function:
241236
raise ValueError(f"Unsupported SQLAlchemy type: {column_type}")
242237

238+
if column_size:
239+
generator_arguments.append(str(column_size))
240+
243241
return variable_names, generator_function, generator_arguments
244242

245243

tests/test_make.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,15 @@
99
import yaml
1010
from pydantic import PostgresDsn
1111
from pydantic.tools import parse_obj_as
12-
13-
from sqlsynthgen.make import make_src_stats, make_table_generators, make_tables_file
12+
from sqlalchemy import BigInteger, Column, String
13+
from sqlalchemy.dialects.mysql.types import INTEGER
14+
15+
from sqlsynthgen.make import (
16+
_get_provider_for_column,
17+
make_src_stats,
18+
make_table_generators,
19+
make_tables_file,
20+
)
1421
from tests.examples import example_orm
1522
from tests.utils import RequiresDBTestCase, SSGTestCase, get_test_settings
1623

@@ -119,6 +126,54 @@ def test_make_generators_force_overwrite(
119126

120127
self.assertEqual(expected, actual)
121128

129+
def test__get_provider_for_column(self) -> None:
130+
"""Test the _get_provider_for_column function."""
131+
132+
# Simple case
133+
(
134+
variable_name,
135+
generator_function,
136+
generator_arguments,
137+
) = _get_provider_for_column(Column("myint", BigInteger))
138+
self.assertListEqual(
139+
variable_name,
140+
["myint"],
141+
)
142+
self.assertEqual(
143+
generator_function,
144+
"generic.numeric.integer_number",
145+
)
146+
self.assertEqual(
147+
generator_arguments,
148+
[],
149+
)
150+
151+
# Column type from another dialect
152+
_, generator_function, __ = _get_provider_for_column(Column("myint", INTEGER))
153+
self.assertEqual(
154+
generator_function,
155+
"generic.numeric.integer_number",
156+
)
157+
158+
# Text value with length
159+
(
160+
variable_name,
161+
generator_function,
162+
generator_arguments,
163+
) = _get_provider_for_column(Column("mystring", String(100)))
164+
self.assertEqual(
165+
variable_name,
166+
["mystring"],
167+
)
168+
self.assertEqual(
169+
generator_function,
170+
"generic.person.password",
171+
)
172+
self.assertEqual(
173+
generator_arguments,
174+
["100"],
175+
)
176+
122177

123178
class TestMakeTables(SSGTestCase):
124179
"""Test the make_tables function."""

0 commit comments

Comments
 (0)