Skip to content

Commit 92f69cc

Browse files
committed
Add mariadb support
1 parent 2fb9462 commit 92f69cc

File tree

11 files changed

+94
-261
lines changed

11 files changed

+94
-261
lines changed

poetry.lock

Lines changed: 18 additions & 85 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jsonschema = "^4.17.3"
2424
sqlacodegen = "3.0.0rc1"
2525
asyncpg = "^0.27.0"
2626
greenlet = "^2.0.2"
27+
pymysql = "^1.1.0"
2728

2829
[tool.poetry.group.dev.dependencies]
2930
isort = "^5.10.1"

sqlsynthgen/create.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def create_db_tables(metadata: Any) -> Any:
1616
"""Create tables described by the sqlalchemy metadata object."""
1717
settings = get_settings()
1818

19-
engine = create_db_engine(settings.dst_postgres_dsn) # type: ignore
19+
engine = create_db_engine(settings.dst_dsn) # type: ignore
2020

2121
# Create schema, if necessary.
2222
if settings.dst_schema:
@@ -26,7 +26,7 @@ def create_db_tables(metadata: Any) -> Any:
2626

2727
# Recreate the engine, this time with a schema specified
2828
engine = create_db_engine(
29-
settings.dst_postgres_dsn, schema_name=schema_name # type: ignore
29+
settings.dst_dsn, schema_name=schema_name # type: ignore
3030
)
3131

3232
metadata.create_all(engine)
@@ -37,7 +37,7 @@ def create_db_vocab(vocab_dict: Dict[str, Any]) -> None:
3737
settings = get_settings()
3838

3939
dst_engine = create_db_engine(
40-
settings.dst_postgres_dsn, schema_name=settings.dst_schema # type: ignore
40+
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
4141
)
4242

4343
with dst_engine.connect() as dst_conn:
@@ -60,7 +60,7 @@ def create_db_data(
6060
settings = get_settings()
6161

6262
dst_engine = create_db_engine(
63-
settings.dst_postgres_dsn, schema_name=settings.dst_schema # type: ignore
63+
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
6464
)
6565

6666
with dst_engine.connect() as dst_conn:

sqlsynthgen/main.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import yaml
1111
from jsonschema.exceptions import ValidationError
1212
from jsonschema.validators import validate
13-
from pydantic import PostgresDsn
1413

1514
from sqlsynthgen.create import create_db_data, create_db_tables, create_db_vocab
1615
from sqlsynthgen.make import make_src_stats, make_table_generators, make_tables_file
@@ -35,12 +34,12 @@ def _check_file_non_existence(file_path: Path) -> None:
3534
sys.exit(1)
3635

3736

38-
def _get_src_postgres_dsn(settings: Settings) -> PostgresDsn:
37+
def _get_src_postgres_dsn(settings: Settings) -> str:
3938
"""Return the source DB Postgres DSN.
4039
4140
Check that source db details have been set. Exit with error message if not.
4241
"""
43-
if (src_dsn := settings.src_postgres_dsn) is None:
42+
if (src_dsn := settings.src_dsn) is None:
4443
typer.echo("Missing source database connection details.", err=True)
4544
sys.exit(1)
4645
return src_dsn
@@ -180,7 +179,7 @@ def make_stats(
180179
config = read_yaml_file(config_file) if config_file is not None else {}
181180

182181
settings = get_settings()
183-
src_dsn: PostgresDsn = _get_src_postgres_dsn(settings)
182+
src_dsn: str = _get_src_postgres_dsn(settings)
184183

185184
src_stats = asyncio.get_event_loop().run_until_complete(
186185
make_src_stats(src_dsn, config, settings.src_schema)
@@ -211,7 +210,7 @@ def make_tables(
211210
_check_file_non_existence(orm_file_path)
212211

213212
settings = get_settings()
214-
src_dsn: PostgresDsn = _get_src_postgres_dsn(settings)
213+
src_dsn: str = _get_src_postgres_dsn(settings)
215214

216215
content = make_tables_file(src_dsn, settings.src_schema)
217216
orm_file_path.write_text(content, encoding="utf-8")

sqlsynthgen/make.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from black import FileMode, format_str
1414
from jinja2 import Environment, FileSystemLoader, Template
1515
from mimesis.providers.base import BaseProvider
16-
from pydantic import PostgresDsn
1716
from sqlacodegen.generators import DeclarativeGenerator
1817
from sqlalchemy import MetaData, UniqueConstraint, text
1918
from sqlalchemy.sql import sqltypes
@@ -351,7 +350,7 @@ def make_table_generators(
351350

352351
settings = get_settings()
353352
engine = create_db_engine(
354-
settings.src_postgres_dsn, schema_name=settings.src_schema # type: ignore
353+
settings.src_dsn, schema_name=settings.src_schema # type: ignore
355354
)
356355

357356
tables: List[TableGenerator] = []
@@ -430,7 +429,7 @@ def _get_generator_for_vocabulary_table(
430429
)
431430

432431

433-
def make_tables_file(db_dsn: PostgresDsn, schema_name: Optional[str]) -> str:
432+
def make_tables_file(db_dsn: str, schema_name: Optional[str]) -> str:
434433
"""Write a file with the SQLAlchemy ORM classes.
435434
436435
Exists with an error if sqlacodegen is unsuccessful.
@@ -455,7 +454,7 @@ def make_tables_file(db_dsn: PostgresDsn, schema_name: Optional[str]) -> str:
455454

456455

457456
async def make_src_stats(
458-
dsn: PostgresDsn, config: dict, schema_name: Optional[str] = None
457+
dsn: str, config: dict, schema_name: Optional[str] = None
459458
) -> dict:
460459
"""Run the src-stats queries specified by the configuration.
461460

sqlsynthgen/remove.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@ def remove_db_data(orm_module: ModuleType, ssg_module: ModuleType) -> None:
1111
"""Truncate the synthetic data tables but not the vocabularies."""
1212
settings = get_settings()
1313

14-
assert settings.dst_postgres_dsn, "Missing destination database settings"
15-
dst_engine = create_db_engine(
16-
settings.dst_postgres_dsn, schema_name=settings.dst_schema
17-
)
14+
assert settings.dst_dsn, "Missing destination database settings"
15+
dst_engine = create_db_engine(settings.dst_dsn, schema_name=settings.dst_schema)
1816

1917
with dst_engine.connect() as dst_conn:
2018
for table in reversed(orm_module.Base.metadata.sorted_tables):
@@ -27,10 +25,8 @@ def remove_db_vocab(orm_module: ModuleType, ssg_module: ModuleType) -> None:
2725
"""Truncate the vocabulary tables."""
2826
settings = get_settings()
2927

30-
assert settings.dst_postgres_dsn, "Missing destination database settings"
31-
dst_engine = create_db_engine(
32-
settings.dst_postgres_dsn, schema_name=settings.dst_schema
33-
)
28+
assert settings.dst_dsn, "Missing destination database settings"
29+
dst_engine = create_db_engine(settings.dst_dsn, schema_name=settings.dst_schema)
3430

3531
with dst_engine.connect() as dst_conn:
3632
for table in reversed(orm_module.Base.metadata.sorted_tables):
@@ -43,10 +39,8 @@ def remove_db_tables(orm_module: ModuleType) -> None:
4339
"""Drop the tables in the destination schema."""
4440
settings = get_settings()
4541

46-
assert settings.dst_postgres_dsn, "Missing destination database settings"
47-
dst_engine = create_db_engine(
48-
settings.dst_postgres_dsn, schema_name=settings.dst_schema
49-
)
42+
assert settings.dst_dsn, "Missing destination database settings"
43+
dst_engine = create_db_engine(settings.dst_dsn, schema_name=settings.dst_schema)
5044

5145
metadata = orm_module.Base.metadata
5246
metadata.drop_all(dst_engine)

sqlsynthgen/settings.py

Lines changed: 30 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any, Optional
2020

2121
# pylint: disable=no-self-argument
22-
from pydantic import BaseSettings, PostgresDsn, validator
22+
from pydantic import BaseSettings, validator
2323

2424

2525
class Settings(BaseSettings):
@@ -31,96 +31,40 @@ class Settings(BaseSettings):
3131
and synthetic values inserted.
3232
3333
Attributes:
34-
src_host_name (str):
35-
An element (host-name) of connection parameter
36-
src_port (int):
37-
Connection port eg. 5432
38-
src_user_name (str) :
39-
Connection username e.g. `postgres` or `myuser@mydb`
40-
src_password (str) :
41-
Connection password
42-
src_db_name (str) :
43-
Connection database e.g. "postgres"
44-
src_ssl_required (bool) :
45-
Flag `True` if db requires SSL
46-
47-
dst_host_name (str):
48-
Connection host-name to destination db
49-
dst_port (int) :
50-
Connection port eg. 5432
51-
dst_user_name (str) :
52-
Connection username e.g. `postgres` or `myuser@mydb`
53-
dst_password (str) :
54-
Connection password
55-
dst_db_name (str) :
56-
Connection database e.g. `postgres`
57-
dst_ssl_required (bool) :
58-
Flag `True` if db requires SSL
34+
src_dsn (str) :
35+
A DSN for connecting to the source database.
36+
37+
src_schema (str) :
38+
The source database schema to use, if applicable.
39+
40+
dst_dsn (str) :
41+
A DSN for connecting to the destination database.
42+
43+
dst_schema (str) :
44+
The destination database schema to use, if applicable.
5945
"""
6046

61-
# Connection parameters for the source PostgreSQL database. See also
47+
# Connection parameters for the source database. See also
6248
# https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS
63-
src_host_name: Optional[str] # e.g. "mydb.mydomain.com" or "0.0.0.0"
64-
src_port: int = 5432
65-
src_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb"
66-
src_password: Optional[str]
67-
src_db_name: Optional[str]
68-
src_ssl_required: bool = False # whether the db requires SSL
69-
src_schema: Optional[str]
49+
src_dsn: Optional[str]
50+
dst_dsn: Optional[str]
7051

71-
# Connection parameters for the destination PostgreSQL database.
72-
dst_host_name: Optional[
73-
str
74-
] # Connection parameter e.g. "mydb.mydomain.com" or "0.0.0.0"
75-
dst_port: int = 5432
76-
dst_user_name: Optional[str] # e.g. "postgres" or "myuser@mydb"
77-
dst_password: Optional[str]
78-
dst_db_name: Optional[str]
52+
src_schema: Optional[str]
7953
dst_schema: Optional[str]
80-
dst_ssl_required: bool = False # whether the db requires SSL
81-
82-
# These are calculated so do not provide them explicitly
83-
src_postgres_dsn: Optional[PostgresDsn]
84-
dst_postgres_dsn: Optional[PostgresDsn]
85-
86-
@validator("src_postgres_dsn", pre=True)
87-
def validate_src_postgres_dsn(
88-
cls, _: Optional[PostgresDsn], values: Any
89-
) -> Optional[str]:
90-
"""Create and validate the source db data source name."""
91-
return cls.check_postgres_dsn(_, values, "src")
92-
93-
@validator("dst_postgres_dsn", pre=True)
94-
def validate_dst_postgres_dsn(
95-
cls, _: Optional[PostgresDsn], values: Any
96-
) -> Optional[str]:
97-
"""Create and validate the destination db data source name."""
98-
return cls.check_postgres_dsn(_, values, "dst")
99-
100-
@staticmethod
101-
def check_postgres_dsn(
102-
_: Optional[PostgresDsn], values: Any, prefix: str
103-
) -> Optional[str]:
104-
"""Build a DSN string from the host, db name, port, username and password."""
105-
# We want to build the Data Source Name ourselves so none should be provided
106-
if _:
107-
raise ValueError("postgres_dsn should not be provided")
108-
109-
user = values[f"{prefix}_user_name"]
110-
password = values[f"{prefix}_password"]
111-
host = values[f"{prefix}_host_name"]
112-
port = values[f"{prefix}_port"]
113-
db_name = values[f"{prefix}_db_name"]
114-
115-
if user and password and host and port and db_name:
116-
dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
117-
118-
if values[f"{prefix}_ssl_required"]:
119-
return dsn + "?sslmode=require"
120-
121-
return dsn
122-
123-
return None
54+
55+
@validator("src_dsn")
56+
def validate_src_dsn(cls, dsn: Optional[str], values: Any) -> Optional[str]:
57+
"""Create and validate the source DB DSN."""
58+
if dsn and dsn.startswith("mariadb"):
59+
assert values.get("src_schema") is None
60+
return dsn
61+
62+
@validator("dst_dsn")
63+
def validate_dst_dsn(cls, dsn: Optional[str], values: Any) -> Optional[str]:
64+
"""Create and validate the destination DB DSN."""
65+
if dsn and dsn.startswith("mariadb"):
66+
assert values.get("dst_schema") is None
67+
return dsn
12468

12569
@dataclass
12670
class Config:

sqlsynthgen/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Any, Optional
88

99
import yaml
10-
from pydantic import PostgresDsn
1110
from sqlalchemy import create_engine, event, select
1211
from sqlalchemy.ext.asyncio import create_async_engine
1312

@@ -53,7 +52,7 @@ def download_table(table: Any, engine: Any, yaml_file_name: str) -> None:
5352

5453

5554
def create_db_engine(
56-
postgres_dsn: PostgresDsn,
55+
postgres_dsn: str,
5756
schema_name: Optional[str] = None,
5857
use_asyncio: bool = False,
5958
**kwargs: dict,

tests/test_create.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_create_db_tables(
4444

4545
create_db_tables(mock_meta)
4646
mock_create_engine.assert_called_once_with(
47-
mock_get_settings.return_value.dst_postgres_dsn
47+
mock_get_settings.return_value.dst_dsn
4848
)
4949
mock_meta.create_all.assert_called_once_with(mock_create_engine.return_value)
5050

@@ -127,7 +127,7 @@ def test_create_db_vocab(
127127
mock_create_engine.return_value.connect.return_value.__enter__.return_value
128128
)
129129
mock_create_engine.assert_called_once_with(
130-
mock_get_settings.return_value.dst_postgres_dsn
130+
mock_get_settings.return_value.dst_dsn
131131
)
132132
# Running the same insert twice should be fine.
133133
create_db_vocab(vocab_list)

tests/test_main.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def test_make_tables_with_force_enabled(
274274
result: Result = runner.invoke(app, ["make-tables", force_option])
275275

276276
mock_make_tables.assert_called_once_with(
277-
test_settings.src_postgres_dsn, test_settings.src_schema
277+
test_settings.src_dsn, test_settings.src_schema
278278
)
279279
mock_path.return_value.write_text.assert_called_once_with(
280280
mock_tables_output, encoding="utf-8"
@@ -308,9 +308,7 @@ def test_make_stats(
308308
self.assertSuccess(result)
309309
with open(example_conf_path, "r", encoding="utf8") as f:
310310
config = yaml.safe_load(f)
311-
mock_make.assert_called_once_with(
312-
get_test_settings().src_postgres_dsn, config, None
313-
)
311+
mock_make.assert_called_once_with(get_test_settings().src_dsn, config, None)
314312
mock_path.return_value.write_text.assert_called_once_with(
315313
"a: 1\n", encoding="utf-8"
316314
)
@@ -391,7 +389,7 @@ def test_make_stats_with_force_enabled(
391389
)
392390

393391
mock_make.assert_called_once_with(
394-
test_settings.src_postgres_dsn, config_file_content, None
392+
test_settings.src_dsn, config_file_content, None
395393
)
396394
mock_path.return_value.write_text.assert_called_once_with(
397395
"some_stat: 0\n", encoding="utf-8"

0 commit comments

Comments
 (0)