-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSQLParser.py
More file actions
1391 lines (1212 loc) · 64.7 KB
/
SQLParser.py
File metadata and controls
1391 lines (1212 loc) · 64.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import json
import psycopg2
import sqlparse
import os
import re
import subprocess
import ast
import time
from collections.abc import Iterable
from multiprocessing import Process, Manager, Semaphore, Lock
from functions import execute_sql_view
from config import DBConfig
import logging
# from datetime import datetime, timezone, timedelta
logger = logging.getLogger('log')
# Directory containing this module; used to resolve relative resource paths.
root_dir = os.path.dirname(os.path.abspath(__file__))
# Accumulated parsing time in seconds (mutated elsewhere as a module global).
parsing_time = 0
# SQL reserved words, lower-cased.  Used for membership tests only, so the
# list was deduplicated ('alter', 'table', 'view', 'insert', 'extract' and
# others appeared twice in the original) -- membership semantics unchanged.
sql_keywords = [
    'add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'autoincrement', 'between', 'boolean',
    'by', 'call', 'case', 'cast', 'char', 'column', 'commit', 'constraint', 'create', 'cross',
    'current_date', 'current_time', 'current_timestamp', 'database', 'date', 'default',
    'delete', 'desc', 'distinct', 'drop', 'else', 'end', 'exists', 'extract', 'false',
    'foreign', 'from', 'full', 'function', 'grant', 'group', 'having', 'if', 'in',
    'inner', 'insert', 'int', 'integer', 'intersect', 'into', 'is', 'join', 'key',
    'left', 'like', 'limit', 'not', 'null', 'on', 'or', 'order', 'outer', 'primary',
    'procedure', 'rename', 'right', 'rollback', 'row', 'select', 'set', 'show',
    'table', 'then', 'to', 'truncate', 'union', 'update', 'values', 'view', 'where',
    'with', 'true', 'unique', 'index', 'user', 'load',
    'replace', 'returning', 'group_concat', 'recursive', 'isnull'
]
# SQL function names (aggregate, scalar, window), lower-cased; deduplicated
# ('extract' appeared twice in the original).
sql_functions = [
    "sum", "count", "avg", "max", "min", "extract", "group_concat", "string_agg", "variance", "stddev",
    "median", "percentile_cont", "percentile_disc", "abs", "ceiling", "ceil", "floor", "round", "power",
    "exp", "log", "sqrt", "sin", "cos", "tan", "concat", "substring", "substr", "upper", "lower",
    "trim", "length", "replace", "lpad", "rpad", "now", "date_add", "date_sub",
    "date_part", "to_date", "to_char", "coalesce", "nullif", "case", "greatest", "least",
    "json_extract", "json_agg", "xmlagg", "st_distance", "st_intersects", "st_union",
    "row_number", "rank", "dense_rank", "lead", "lag", "ntile", "cast"
]
# SQL pseudo-constants that look like identifiers but carry a value.
sql_constants = ["current_date", "current_time", "current_timestamp", "localtime",
    "localtimestamp", "session_user", "current_user", "system_user", "user", "null"
]
# A bare SQL identifier, e.g. "my_table".
identifier_pattern = r'\b[a-zA-Z_][a-zA-Z0-9_]*\b'
# A qualified column reference, e.g. "my_table.my_col".
table_dot_column = r'\b[a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*\b'
# A SQL literal: single-quoted string (with '' escapes), integer, float,
# or the TRUE/FALSE/NULL keywords.
constant_pattern = r"('(?:''|[^'])*')|(\b\d+\b)|(\b\d+\.\d+\b)|(\bTRUE\b|\bFALSE\b|\bNULL\b)"
def extract_number(key):
    """Return the first integer embedded in *key*, or +inf when none exists.

    Handy as a sort key: entries without a number sort last.
    """
    found = re.search(r'\d+', key)
    if found is None:
        return float('inf')
    return int(found.group())
def execute_sql(sql, db_name, schema = 'public'):
    """Run *sql* against PostgreSQL database *db_name* and fetch all rows.

    Sets search_path to *schema* and a 6-minute statement timeout first.
    Returns a list of result tuples, or [] when the query raised a
    psycopg2 error.  The connection is always closed.

    NOTE(review): *schema* is interpolated into SQL unescaped -- safe only
    for trusted schema names; verify against callers.
    """
    conn = psycopg2.connect(
        host = DBConfig.host,
        user = DBConfig.user,
        password = DBConfig.password,
        port = DBConfig.port,
        dbname = db_name
    )
    cursor = conn.cursor()
    cursor.execute(f"set search_path to {schema};")
    conn.commit()
    cursor.close()
    results = None
    # BUG FIX: initialize cur before the try block -- in the original,
    # if conn.cursor() raised, the finally clause hit an unbound `cur`
    # and masked the real error with a NameError.
    cur = None
    try:
        cur = conn.cursor()
        cur.execute("SET statement_timeout TO 360000;")
        cur.execute(sql)
        results = cur.fetchall()
    except psycopg2.Error:
        # Best-effort: failed queries yield an empty result set.
        results = []
    finally:
        if cur is not None:
            cur.close()
        conn.close()
    return results
def get_table_names_from_db(db_name, schema = 'public'):
    """Return the names of all tables in *schema* of database *db_name*."""
    query = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema}';"
    rows = execute_sql(query, db_name, schema)
    if rows is None:
        return []
    return [row[0] for row in rows]
def get_table_columns_from_db(db_name, schema = 'public'):
    """Map each table in *schema* of *db_name* to its list of column names.

    Returns {} when the schema has no tables (or the lookup failed).
    """
    # CONSISTENCY: reuse get_table_names_from_db instead of duplicating
    # its information_schema query inline, as the original did.
    table_columns = {}
    for table in get_table_names_from_db(db_name, schema):
        # NOTE(review): this filter has no table_schema clause, so a table
        # of the same name in another schema would contribute columns too
        # -- TODO confirm intended.
        sql = f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table}';"
        results = execute_sql(sql, db_name, schema)
        table_columns[table] = [row[0] for row in results] if results is not None else []
    return table_columns
def get_db_info(db_name, table_names, schema = 'public'):
    """Collect DDL for each table in *table_names* via ``pg_dump``.

    Returns four lists of formatted SQL statements:
    (CREATE TABLE, PRIMARY KEY constraints, FOREIGN KEY constraints,
    CREATE INDEX).  Keywords are upper-cased, identifiers lower-cased,
    and the "<schema>." qualifier is stripped from table references.
    """
    create_table_statements, primary_key_constraints, foreign_key_constraints, create_index_statements = [], [], [], []
    for table_name in table_names:
        os.environ['PGPASSWORD'] = DBConfig.password
        # SECURITY: pass pg_dump its arguments as a list (shell=False) so
        # table/schema names cannot be interpreted by a shell, unlike the
        # original shell=True format string.
        command = [
            'pg_dump', '-s',
            '-h', str(DBConfig.host),
            '-p', str(DBConfig.port),
            '-U', str(DBConfig.user),
            '-d', db_name,
            '-t', '{}.{}'.format(schema, table_name),
            '--schema-only',
        ]
        process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = process.communicate()
        # BUG FIX: the original never piped stderr, so `error` was always
        # None and the failure branch was unreachable; check the return
        # code instead.
        if process.returncode == 0:
            database_dump = output.decode()
            # Drop blank lines and pg_dump boilerplate (SET/--/SELECT/banner).
            filtered_lines = [line for line in database_dump.split("\n")
                              if not (line.strip() == "" or line.startswith("SET") or line.startswith("--") or line.startswith("SELECT") or line.startswith("pg_dump"))]
            database_dump = "\n".join(filtered_lines)
            statements = database_dump.split(";")
            for stmt in statements:
                if stmt.upper().startswith("CREATE TABLE"):
                    create_table_statements.append(stmt.strip().replace(f"{schema}.", "") + ";")
                elif "PRIMARY KEY" in stmt.upper():
                    primary_key_constraints.append(stmt.strip().replace(f"{schema}.", "") + ";")
                elif "FOREIGN KEY" in stmt.upper():
                    foreign_key_constraints.append(stmt.strip().replace(f"{schema}.", "") + ";")
        else:
            print(f"Error getting DDL for table {table_name}: {error.decode(errors='replace').strip()}")
        get_indexes_sql = "SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{}';".format(table_name)
        for res in execute_sql(get_indexes_sql, db_name, schema):
            # NOTE(review): always strips "public." even when *schema*
            # differs -- TODO confirm intended.
            create_index_statements.append(res[1].replace("public.", ""))
    create_table_statements = [sqlparse.format(sql, keyword_case = 'upper', identifier_case = 'lower') for sql in create_table_statements]
    primary_key_constraints = [sqlparse.format(sql, keyword_case = 'upper', identifier_case = 'lower') for sql in primary_key_constraints]
    foreign_key_constraints = [sqlparse.format(sql, keyword_case = 'upper', identifier_case = 'lower') for sql in foreign_key_constraints]
    create_index_statements = [sqlparse.format(sql, keyword_case = 'upper', identifier_case = 'lower') for sql in create_index_statements]
    return create_table_statements, primary_key_constraints, foreign_key_constraints, create_index_statements
def get_db_schema(db_name, path, schema = 'public'):
    """Dump a JSON description of *db_name*'s schema to *path*.

    The JSON object contains "table_info" (per table: name, row count,
    columns with upper-cased types, CREATE TABLE DDL) plus the primary
    key / foreign key / index statement lists from get_db_info().
    Parent directories of *path* are created as needed.
    """
    table_names = get_table_names_from_db(db_name, schema)
    create_table_statements, primary_key_statements, foreign_key_statements, create_index_statements = get_db_info(db_name, table_names, schema)
    db_schema = dict()
    data_types = []
    table_info_list = []
    for table_name in table_names:
        table_info = dict()
        table_info["table"] = table_name
        # BUG FIX: execute_sql returns [] on failure; the original indexed
        # [0][0] unconditionally and crashed with IndexError.
        count_result = execute_sql(f"SELECT COUNT(*) FROM {table_name};", db_name, schema)
        table_info["rows"] = count_result[0][0] if count_result else 0
        table_info["columns"] = []
        columns_and_types = execute_sql(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';", db_name, schema)
        for (column, data_type) in columns_and_types:
            data_type = data_type.upper()
            table_info["columns"].append(
                {
                    "name": column,
                    "type": data_type
                }
            )
            data_types.append(data_type)
        # Find the CREATE TABLE statement that belongs to this table.
        table_ddl = None
        for ddl in create_table_statements:
            if "CREATE TABLE {} (".format(table_name).lower() in ddl.lower():
                table_ddl = ddl
                break
        if table_ddl is None:
            logger.warning(f"Can not find ddl for table {table_name}")
        table_info["ddl"] = table_ddl
        table_info_list.append(table_info)
    db_schema["table_info"] = table_info_list
    db_schema["primary_key_statements"] = primary_key_statements
    db_schema["foreign_key_statements"] = foreign_key_statements
    db_schema["create_index_statements"] = create_index_statements
    abs_path = os.path.abspath(path)
    # exist_ok avoids the check-then-create race of the original
    # `if not exists: makedirs` sequence.
    os.makedirs(os.path.dirname(abs_path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(db_schema, indent=2, ensure_ascii=False))
def schema_pruning(db_name, sql, path):
    """Prune the cached schema JSON at *path* down to what *sql* references.

    The SQL is normalized, keyword-upper/identifier-lower formatted, and
    split on whitespace.  A table is "used" when its name equals a token;
    a column is "used" when its name is a substring of any token.
    Returns (used_column_info, used_column_info_rows, used_table_ddls,
    pk, fk, indexes); the last three are newline-joined statement strings,
    or the literal string "None" when empty.
    """
    with open(path, 'r') as file :
        schema = json.load(file)
    sql = normalize_sql(sql)
    sql = sqlparse.format(sql, keyword_case = 'upper', identifier_case = 'lower')
    sql_tokens = sql.split()
    sql_tokens = [token.replace('\"', '').lower() for token in sql_tokens]
    # print(sql_tokens)
    used_table_names = []
    used_column_info = []
    used_table_ddls = []
    used_column_info_rows = []
    for table in schema["table_info"]:
        if table["table"] in sql_tokens:
            used_table_names.append(table["table"])
            used_table_ddls.append(table["ddl"])
            for column in table["columns"]:
                # Substring match: "id" also matches "t.id" (or "paid") --
                # deliberately over-approximates the used columns.
                if any(column["name"] in sql_token for sql_token in sql_tokens):
                    used_column_info.append(table["table"] + "." + column["name"] + ": " + column["type"])
                    used_column_info_rows.append((table["table"] + "." + column["name"] + ": " + column["type"], table['rows']))
    used_pk_stmts, used_fk_stmts, used_index_stmts = [], [], []
    # Keep only PRIMARY KEY statements that target a used table.
    for primary_key_statement in schema["primary_key_statements"]:
        for table_name in used_table_names:
            if "ALTER TABLE ONLY {}\n".format(table_name).lower() in primary_key_statement.lower() and primary_key_statement not in used_pk_stmts:
                used_pk_stmts.append(primary_key_statement)
    # Keep only FOREIGN KEYs whose source AND target tables are both used.
    for foreign_key_statement in schema["foreign_key_statements"]:
        fk_source_table_name = foreign_key_statement.replace("ALTER TABLE ONLY ", "").split("\n")[0].strip()
        fk_target_table_name = foreign_key_statement.split("REFERENCES")[1].split("(")[0].strip()
        if fk_source_table_name.lower() in used_table_names and fk_target_table_name.lower() in used_table_names and foreign_key_statement not in used_fk_stmts:
            used_fk_stmts.append(foreign_key_statement)
    # Keep only indexes built on a used table.
    for create_index_statement in schema["create_index_statements"]:
        for table_name in used_table_names:
            if "ON {} USING".format(table_name).lower() in create_index_statement.lower() and create_index_statement not in used_index_stmts:
                used_index_stmts.append(create_index_statement)
    # execution_results = execute_sql("EXPLAIN {}".format(sql), db_name)
    # query_plan = "\n".join([row[0] for row in execution_results])
    # print(type(query_plan))
    pk = "\n".join(used_pk_stmts) if len(used_pk_stmts) != 0 else "None"
    fk = "\n".join(used_fk_stmts) if len(used_fk_stmts) != 0 else "None"
    indexes = "\n".join(used_index_stmts) if len(used_index_stmts) != 0 else "None"
    return used_column_info, used_column_info_rows, used_table_ddls, pk, fk, indexes
def get_tables_columns_names(schema_path) :
    """Load the schema JSON at *schema_path* and map table name -> column names."""
    with open(schema_path, 'r') as handle :
        schema = json.load(handle)
    return {
        entry['table']: [col_info['name'] for col_info in entry['columns']]
        for entry in schema['table_info']
    }
def get_tables_names(schema_path) :
    """Return the list of table names recorded in the schema JSON at *schema_path*."""
    with open(schema_path, 'r') as handle :
        schema = json.load(handle)
    return [entry["table"] for entry in schema['table_info']]
def find_view_info(db_name, schema = 'public') :
    """Return view information for *schema* of database *db_name*.

    Returns a pair: ({view_name: [column, ...]}, flat list of all view
    columns across every view).
    """
    list_views_sql = f"""
    SELECT viewname
    FROM pg_views
    WHERE schemaname = '{schema}';
    """
    view_rows = execute_sql(list_views_sql, db_name, schema)
    views2columns = {}
    views_columns = []
    for view_row in view_rows :
        view_name = view_row[0]
        columns_sql = f"""
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = '{schema}'
        AND table_name = '{view_name}';
        """
        cols = [row[0] for row in execute_sql(columns_sql, db_name, schema)]
        views_columns.extend(cols)
        views2columns[view_name] = cols
    return views2columns, views_columns
## ndv of all tables
def get_ndvs_all(conn, ndv_path, schema = 'public') :
    """Compute the NDV ratio (n_distinct / row count) for every column.

    Runs ANALYZE, then reads pg_stats for each table in *schema*.
    Positive n_distinct values are normalized by the table's row count;
    negative values are already ratios and are sign-flipped.  Columns of
    empty tables get 0.  The {"table.column": ratio} map is written to
    *ndv_path* as JSON and returned.

    NOTE(review): if conn.cursor() itself raises, `cur` is unbound in the
    finally clause and a NameError masks the original error -- TODO fix.
    Any query failure prints the error and calls exit(), terminating the
    whole process.
    """
    ndvs = {}
    try :
        cur = conn.cursor()
        # Refresh pg_stats so n_distinct estimates are current.
        cur.execute("analyze;")
        # get tables
        sql = f"""SELECT tablename
        FROM pg_tables
        WHERE schemaname = '{schema}';
        """
        cur.execute(sql)
        tables = cur.fetchall()
        # print(tables)
        for table in tables :
            table = table[0]
            sql = f"""
            WITH total_rows AS (
                SELECT COUNT(*) AS total FROM {table}
            )
            SELECT
                attname,
                CASE
                    WHEN n_distinct >= 0 THEN
                        (n_distinct / NULLIF(total_rows.total, 0)) -- 避免除以0
                    ELSE
                        (n_distinct * -1) -- 保留负数部分,表示百分比
                    END AS n_distinct_percentage
            FROM
                pg_stats,
                total_rows
            WHERE
                tablename = '{table}'
                AND
                schemaname = '{schema}';
            """
            cur.execute(sql)
            res = cur.fetchall()
            for r in res :
                ndvs[f"{table}.{r[0]}"] = r[1]
            sql = f"select count(*) from {schema}.{table};"
            cur.execute(sql)
            total_rows = cur.fetchone()[0]
            # Empty tables have no pg_stats rows: record 0 for every column.
            if total_rows == 0 :
                sql = f"select column_name from information_schema.columns where table_name = '{table}';"
                cur.execute(sql)
                results = cur.fetchall()
                for row in results :
                    ndvs[f"{table}.{row[0]}"] = 0
    except Exception as e :
        print(f"Error: {e} --> {sql}")
        exit()
    finally :
        if cur :
            cur.close()
    abs_path = os.path.abspath(ndv_path)
    if not os.path.exists(os.path.dirname(abs_path)) : os.makedirs(os.path.dirname(abs_path))
    with open(ndv_path, 'w') as f:
        json.dump(ndvs, f, indent=4)
    return ndvs
## selectivity
def normalize_tokenlist(tokenlist) :
    """Recursively flatten a sqlparse token list into a space-joined string.

    Leaf tokens (ttype set) are emitted with a trailing space, except that
    a '-' operator is expanded to ' - ' so subtraction stays tokenizable.
    Grouped tokens (ttype None) are flattened recursively.
    """
    normalized_tl = ""
    for token in tokenlist :
        if token.ttype is not None :
            if '-' in token.value :
                if token.ttype == sqlparse.tokens.Token.Operator :
                    # Split e.g. "a-b" into "a - b" (no trailing space added here).
                    tmp_values = token.value.split('-')
                    normalized_tl += ' - '.join(tmp for tmp in tmp_values)
                elif token.ttype == sqlparse.tokens.Token.Literal.String.Single :
                    # String literals keep their '-' characters untouched.
                    normalized_tl += token.value + " "
                else : normalized_tl += token.value + " "
            else : normalized_tl += token.value + " "
        else :
            # NOTE(review): `t is not None` is true for every child token,
            # so this is effectively "token has any children" -- possibly
            # meant `t.ttype is not None`; TODO confirm.
            if any(t is not None for t in token) :
                normalized_tl += normalize_tokenlist(token) + " "
            else :
                normalized_tl += token.value + " "
    return normalized_tl
def normalize_sql(sql) :
    """Flatten *sql* into a single whitespace-normalized line.

    Grouped tokens are expanded via normalize_tokenlist(); runs of
    whitespace collapse to single spaces and " . " is re-joined to ".".
    """
    parts = []
    for statement in sqlparse.parse(sql) :
        for tok in statement :
            if tok.ttype is not None :
                parts.append(tok.value)
            else :
                parts.append(normalize_tokenlist(tok))
    # .replace(" - ", "-") --> there might be error in the subtraction operation
    joined = " ".join(parts)
    return re.sub(r'\s+', ' ', joined).strip().replace(" . ", ".")
def is_subquery(token_list):
    """Return True when *token_list* directly contains a SELECT DML token."""
    if not isinstance(token_list, sqlparse.sql.TokenList):
        return False
    return any(
        tok.ttype is sqlparse.tokens.DML and tok.value.upper() == 'SELECT'
        for tok in token_list.tokens
    )
def extract_subqueries(token_list):
    """Recursively collect every token list that forms a SELECT subquery."""
    found = [token_list] if is_subquery(token_list) else []
    if isinstance(token_list, Iterable):
        for child in token_list.tokens:
            if isinstance(child, sqlparse.sql.TokenList):
                found += extract_subqueries(child)
    return found
def is_literal_list(string):
    """Heuristically decide whether *string* denotes a list of literals.

    First tries ast.literal_eval; when that fails, falls back to: contains
    a comma and no SQL keyword/function appears as a substring.
    """
    try:
        return isinstance(ast.literal_eval(string), list)
    except (ValueError, SyntaxError):
        pass
    if ',' not in string:
        return False
    has_keyword = any(kw in string for kw in sql_keywords)
    has_function = any(fn in string for fn in sql_functions)
    return not (has_keyword or has_function)
def extract_alias(sql, conn, tables_columns) :
    """Extract alias definitions from *sql*.

    Walks the sqlparse token stream of the (single) statement, collecting
    ``alias -> aliased thing`` pairs from SELECT lists, FROM clauses, WITH
    clauses and subqueries (recursing into subqueries).  Each aliased
    thing is then resolved against *tables_columns*:

    - a bare table name becomes ``(table,)``
    - a bare/qualified column becomes ``(table, column)``
    - anything else (expression, subquery) stays a rewritten string.

    Returns (simple_alias, complex_alias): tuple-valued resolutions vs.
    string-valued ones.  Exits the process if *sql* parses to more than
    one statement.
    """
    alias = {} ## {alias_name: [(table, column)]}
    simple_alias = {}
    complex_alias = {}
    ## get table names
    table_names = list(tables_columns.keys())
    # print(table_names)
    column_names = []
    # NOTE(review): iterating the dict yields its KEYS (table-name strings),
    # so extend() adds the individual characters of each table name here --
    # likely meant tables_columns.values(); TODO confirm.
    for columns in tables_columns :
        column_names.extend(columns)
    parsed_sql = sqlparse.parse(normalize_sql(sql))
    if len(parsed_sql) == 1 :
        parsed_sql = parsed_sql[0]
        aggregation = ""        # pending function name to prefix the next alias target
        with_subquery = False   # inside "WITH <name> AS (subquery)" handling
        keep_identifier = False # saw a bare table/column; alias may follow as next token
        ## extract alias pairs from sql [table, columns, expressions, subquery]
        for token in parsed_sql.tokens :
            # print(token, token.__class__, token.ttype)
            if isinstance(token, sqlparse.sql.IdentifierList) :
                if with_subquery :
                    # WITH clause defining several "(subquery) AS name" entries.
                    for identifier in token.get_identifiers() :
                        if " as " in identifier.value.lower() :
                            if " as " in identifier.value :
                                al = token.value.split(" as ")[0]
                                rl = token.value.replace(al + " as ", '').strip()
                            else :
                                al = token.value.split(" AS ")[0]
                                rl = token.value.replace(al + " AS ", '').strip()
                            pair = (rl, al)
                            # extract_alias for subquery
                            subsql = rl[1:-1].strip()
                            s_sa, c_sa = extract_alias(subsql, conn, tables_columns)
                            simple_alias.update(s_sa)
                            complex_alias.update(c_sa)
                            if pair != [] and pair[1].lower() not in sql_keywords: alias[pair[1]] = pair[0]
                    with_subquery = False
                else :
                    for identifier in token.get_identifiers() :
                        # print(identifier, identifier.__class__, identifier.ttype)
                        if isinstance(identifier, sqlparse.sql.Identifier) :
                            if identifier.get_alias() : # table/column as alias | table/column alias
                                al = identifier.value.split()[-1]
                                rl = ' '.join(identifier.value.split()[:-1])
                                if rl[-2:].lower() == 'as' : rl = rl[:-2].strip()
                            elif identifier.value in table_names or identifier.value in column_names :
                                # Bare known name: remember it; an alias token may follow.
                                keep_identifier = True
                                rl = identifier.value
                                continue
                            else : continue
                            pair = (rl, al)
                            if rl.startswith('(') and rl.endswith(')') and 'select' in rl.lower().split() : # subquery
                                subsql = rl[1:-1].strip()
                                s_sa, c_sa = extract_alias(subsql, conn, tables_columns)
                                simple_alias.update(s_sa)
                                complex_alias.update(c_sa)
                            if aggregation != "" :
                                # Prefix the pending function, e.g. "count" + "(x)".
                                pair = (aggregation + pair[0], pair[1])
                                aggregation = ""
                            if pair != [] and pair[1].lower() not in sql_keywords: # case that desc as pair[0]
                                alias[pair[1]] = pair[0]
                        elif isinstance(identifier, sqlparse.sql.Token) : # function
                            if keep_identifier and identifier.value.strip() != "" :
                                # Treat this token as the alias of the remembered
                                # identifier, if the lookahead says it is not a call.
                                if parsed_sql.tokens.index(token) + 1 > len(parsed_sql.tokens) or (parsed_sql.tokens.index(token) + 2 < len(parsed_sql.tokens) and parsed_sql.tokens[parsed_sql.tokens.index(token) + 1].value != '(' and parsed_sql.tokens[parsed_sql.tokens.index(token) + 2].value.lower() in sql_keywords) :
                                    al = identifier.value
                                    pair = (rl, al)
                                    if aggregation != "" :
                                        pair = (aggregation + pair[0], pair[1])
                                        aggregation = ""
                                    if pair != [] and pair[1].lower() not in sql_keywords: # case that desc as pair[0]
                                        alias[pair[1]] = pair[0]
                                elif identifier.value.lower() in sql_functions :
                                    aggregation = identifier.value
                                keep_identifier = False
                            elif identifier.value.lower() in sql_functions :
                                aggregation = identifier.value
            elif isinstance(token, sqlparse.sql.Identifier) :
                if with_subquery : # with (subquery) as alias
                    if " as " in token.value.lower() :
                        if " as " in token.value :
                            al = token.value.split(" as ")[0]
                            rl = token.value.replace(al + " as ", '').strip()
                        else :
                            al = token.value.split(" AS ")[0]
                            rl = token.value.replace(al + " AS ", '').strip()
                    with_subquery = False
                elif token.get_alias() : # table/column as alias | table/column alias
                    al = token.value.split()[-1]
                    rl = ' '.join(token.value.split()[:-1])
                    if rl[-2:].lower() == 'as' : rl = rl[:-2].strip()
                elif token.value in table_names or token.value in column_names :
                    keep_identifier = True
                    rl = token.value
                    continue
                else : continue
                pair = (rl, al)
                if rl.startswith('(') and rl.endswith(')') and 'select' in rl.lower().split() : # subquery
                    subsql = rl[1:-1].strip()
                    s_sa, c_sa = extract_alias(subsql, conn, tables_columns)
                    simple_alias.update(s_sa)
                    complex_alias.update(c_sa)
                if aggregation != "" :
                    pair = (aggregation + pair[0], pair[1])
                    aggregation = ""
                if pair != [] and pair[1].lower() not in sql_keywords: # case that desc as pair[0]
                    alias[pair[1]] = pair[0]
            elif isinstance(token, sqlparse.sql.Token) :
                if token.value.lower() == "with" : # with subquery
                    with_subquery = True
                elif token.value.lower() in sql_functions :
                    aggregation = token.value
                elif aggregation != "" and token.value.lower() in sql_keywords and token.value.lower() not in sql_functions:
                    aggregation = ""
            # NOTE(review): sqlparse.sql.Where subclasses Token, so the
            # generic Token branch above appears to shadow this one and the
            # final else -- TODO confirm whether these ever execute.
            elif isinstance(token, sqlparse.sql.Where) : # extract_alias for subquery in where
                if '(' in token.value and ')' in token.value :
                    for t in token :
                        if t.value.startswith('(') and t.value.endswith(')') and 'select' in t.value.lower().split() : # subquery
                            subsql = t.value[1:-1].strip()
                            s_sa, c_sa = extract_alias(subsql, conn, tables_columns)
                            simple_alias.update(s_sa)
                            complex_alias.update(c_sa)
            else :
                if keep_identifier and token.value.strip() != '' :
                    if parsed_sql.tokens.index(token) + 1 > len(parsed_sql.tokens) or (parsed_sql.tokens.index(token) + 2 < len(parsed_sql.tokens) and parsed_sql.tokens[parsed_sql.tokens.index(token) + 2].ttype is sqlparse.tokens.Keyword ) :
                        al = token.value
                        pair = (rl, al)
                        if aggregation != "" :
                            pair = (aggregation + pair[0], pair[1])
                            aggregation = ""
                        if pair != [] and pair[1].lower() not in sql_keywords: # case that desc as pair[0]
                            alias[pair[1]] = pair[0]
                    keep_identifier = False
        ## get single table or column alias [alias_name, real_name]
        single_alias = {}
        for alias_ in alias.items() :
            if re.fullmatch(identifier_pattern, alias_[1]) : # table or column
                single_alias[alias_[0]] = alias_[1]
        ## replace alias in table.column [without considering subquery or expression replacement][table + column / expression + subquery]
        for alias_ in alias.items() :
            if alias_[0] not in single_alias.keys() :
                if re.fullmatch(table_dot_column, alias_[1]) : # table.column format
                    table = alias_[1].split(".")[0]
                    column = alias_[1].split(".")[1]
                    if table in single_alias.keys() : table = single_alias[table]
                    alias[alias_[0]] = (table, column)
                else : # other format with expression or subquery
                    # Rewrite every identifier inside the expression through
                    # the single-name alias map.
                    tokens = alias_[1].split()
                    expression = ""
                    for token in tokens :
                        t = ''
                        if '.' in token : # table.column
                            table_name = token.split('.')[0]
                            if table_name in single_alias.keys() :
                                table_name = single_alias[table_name]
                            col_name = token.split('.')[1]
                            if col_name in single_alias.keys() :
                                col_name = single_alias[col_name]
                            t = f'{table_name}.{col_name}'
                        else : # table/column
                            if token in single_alias.keys() :
                                t = single_alias[token]
                            else : t = token
                        expression += t + " "
                    alias[alias_[0]] = expression
            else : # table or column
                if alias_[1] in table_names : alias[alias_[0]] = (alias_[1],)
                else : # column name
                    # Attribute the bare column to the first table containing it.
                    find_column = False
                    for table in table_names :
                        columns = tables_columns[table]
                        if alias_[1] in columns :
                            alias[alias_[0]] = (table, alias_[1])
                            find_column = True
                            break
                    if not find_column :
                        logger.debug(f"Warning [{alias_}] : Name {alias_[1]} not found in the tables")
                        alias[alias_[0]] = ('', alias_[1])
            # Tuples are fully-resolved (simple); strings remain complex.
            if type(alias[alias_[0]]) is tuple :
                simple_alias[alias_[0]] = alias[alias_[0]]
            else :
                complex_alias[alias_[0]] = alias[alias_[0]]
    else :
        print("Error: SQL statement has more than one statement")
        exit()
    return simple_alias, complex_alias
def judge_identifier(tokens, db_name, table_columns, alias, schema = 'public') :
    """Decide whether *tokens* references any known identifier.

    Known identifiers are the columns and tables of *table_columns*, the
    alias names, and the database's view names/columns.  A qualified
    "table.column" word or a "( * )" wildcard also counts.
    """
    view_info, view_cols = find_view_info(db_name, schema)
    known = set(table_columns.keys())
    for cols in table_columns.values() :
        known.update(cols)
    known.update(alias.keys())
    known.update(view_cols)
    known.update(view_info.keys())
    for word in tokens.value.split() :
        if re.fullmatch(identifier_pattern, word) :
            if word in known : return True
        elif re.fullmatch(table_dot_column, word) :
            return True
    return "( * )" in tokens.value
def get_predicate_str(tokens, aggregation_func, simple_alias, db_name, used_tables, table_columns, alias, schema = 'public') :
    """Build a rewritten predicate string from a comparison token list.

    Returns (predicate, where_condition): the predicate with aliases
    resolved and *aggregation_func* prefixed, plus a flag that is True
    when exactly one side is a known identifier (a WHERE-style filter)
    and False when both sides are identifiers (e.g. a join condition).
    Predicates containing a SELECT, or with no identifier on either side,
    are returned unrewritten with the flag False.
    """
    operator = ''
    predicate = ''
    where_condition = True
    cnt = 0  # number of sides that are known identifiers
    ## left, operator, right
    for t in tokens :
        if t.ttype == sqlparse.tokens.Token.Operator.Comparison or (t.ttype == sqlparse.tokens.Token.Keyword and t.value.lower() in ["like", "in", "exists", "between"]):
            operator = t
            break
    # print(tokens.left, tokens.right)
    if 'select' in tokens.left.value.lower() or 'select' in tokens.right.value.lower() : return "", False
    if judge_identifier(tokens.left, db_name, table_columns, alias, schema) : cnt += 1
    if judge_identifier(tokens.right, db_name, table_columns, alias, schema): cnt += 1
    if cnt == 2 : where_condition = False # other predicates
    elif cnt == 1 :where_condition = True
    else :
        return tokens.value, False
    left = rewrite_expression(tokens.left, simple_alias, used_tables, table_columns)
    right = rewrite_expression(tokens.right, simple_alias, used_tables, table_columns)
    ## left, operator, right [aggregation]
    # NOTE(review): if no comparison token was found above, `operator` is
    # still the string '' and `.value` raises AttributeError -- TODO confirm
    # callers only pass comparisons.
    predicate = aggregation_func + left + " " + operator.value + " " + right
    return predicate, where_condition
def get_preidcate(original_predicate, simple_alias, used_tables, table_columns) :
    """Rewrite *original_predicate* with aliases resolved to real names.

    (Function name is a historical typo of "get_predicate"; kept because
    callers use it.)  Each whitespace-separated token is mapped through
    *simple_alias* ({alias: (table,) or (table, column)}); bare columns
    are qualified with the first matching table from *used_tables*, then
    from all of *table_columns*.  Returns the rewritten string with a
    trailing space per token.
    """
    rewrite_predicate = ""
    variables = original_predicate.split()
    for variable in variables :
        tmp = variable
        if '.' in variable : # table.column
            tab = variable.split('.')[0]
            col = variable.split('.')[1]
            if col in simple_alias.keys() :
                # Column part is itself an alias: use its full resolution.
                if len(simple_alias[col]) == 2 : tmp = f"{simple_alias[col][0]}.{simple_alias[col][1]}"
                else :
                    if tab in simple_alias.keys() :
                        tmp = f'{simple_alias[tab][0]}.{col}'
            elif tab in simple_alias.keys() :
                # Only the table part is aliased: substitute the real table.
                tmp = f'{simple_alias[tab][0]}.{col}'
        else :
            if variable in simple_alias.keys() :
                if len(simple_alias[variable]) == 1 : tmp = simple_alias[variable][0]
                else :
                    tmp = simple_alias[variable][0] + '.' + simple_alias[variable][1]
            else : # column name
                # Qualify a bare column with its owning table, preferring
                # the tables already used by the query.
                find = False
                for table in used_tables :
                    if variable in table_columns[table] :
                        tmp = table + '.' + variable
                        find = True
                        break
                if not find :
                    for table in table_columns.keys() :
                        if variable in table_columns[table] :
                            tmp = table + '.' + variable
                            find = True
                            break
        rewrite_predicate += tmp + " "
    return rewrite_predicate
def rewrite_expression(tokenlist, simple_alias, used_tables, table_columns) :
    """Rewrite a sqlparse token list's text with aliases resolved.

    Thin wrapper that delegates the string rewriting to get_preidcate().
    """
    return get_preidcate(tokenlist.value, simple_alias, used_tables, table_columns)
def find_table_info(db_name, schema_path, predicate, schema = 'public') :
    """Return the tables (or views) referenced by *predicate*.

    Splits *predicate* on spaces; "table.column" items contribute their
    table, bare identifiers are looked up first as columns of the schema's
    tables, then as columns of the database's views.
    """
    tables = []
    table_columns = get_tables_columns_names(schema_path)
    table_names = list(table_columns.keys())
    # print(schema) # table --> columns --> name
    items = predicate.split(' ')
    for item in items :
        if re.fullmatch(table_dot_column, item) : # table.column
            table = item.split('.')[0]
            if table in table_names and table not in tables : tables.append(table)
        elif re.fullmatch(identifier_pattern, item) : # column
            find = False
            for table, cols in table_columns.items() :
                for col in cols :
                    if col == item and table not in tables:
                        tables.append(table)
                        find = True
                        # NOTE(review): breaks only the inner loop, so the
                        # scan continues over the remaining tables and a
                        # column shared by several tables adds them all --
                        # TODO confirm intended.
                        break
            if not find :
                # Fall back to view columns (requires a live DB connection).
                view2columns, _ = find_view_info(db_name, schema)
                for v2c in view2columns.items():
                    if item in v2c[1] :
                        tables.append(v2c[0])
                        find = True
                        break
    return tables
def parse_sql(sql, conn, db_name, schema_path, table_columns, schema = 'public') :
parsed_sql = sqlparse.parse(normalize_sql(sql))
where_predicates = []
multi_where_predicates = []
other_predicates = []
group_order_columns = []
aggregation_func = ""
aggregation_bool = False
multi_where_predicate = ""
is_and = False
is_or = False
is_from = False
is_between = False
group_order = False
column_names = []
for columns in table_columns.values() :
column_names.extend(columns)
# print(column_names)
used_tables = []
## get alias from sql ##
simple_alias, complex_alias = extract_alias(sql, conn, table_columns)
alias = {**simple_alias, **complex_alias}
# print(simple_alias, alias)
## parse sql ##
if len(parsed_sql) == 1 :
parsed_sql = parsed_sql[0]
if 'intersect' in parsed_sql.value.lower().split() :
if 'intersect' in parsed_sql.value : sqls = parsed_sql.value.split(' intersect ')
else : sqls = parsed_sql.value.split(' INTERSECT ')
for sql in sqls :
info_, _, _ = parse_sql(sql, conn, db_name, schema_path, table_columns, schema)
where_predicates.extend(info_['where_predicates'])
other_predicates.extend(info_['other_predicates'])
multi_where_predicates.extend(info_['multi_where_predicates'])
group_order_columns.extend(info_['group_order_columns'])
else :
## extract predicates
for token in parsed_sql.tokens :
# print(token, token.__class__, token.ttype)
predicate = ""
subquery = extract_subqueries(token)
if subquery != [] : # parse sql for subquery
for subq in subquery :
info_, sa_, ca_ = parse_sql(subq.value[1:-1].strip(), conn, db_name, schema_path, table_columns, schema)
where_predicates.extend(info_['where_predicates'])
other_predicates.extend(info_['other_predicates'])
multi_where_predicates.extend(info_['multi_where_predicates'])
group_order_columns.extend(info_['group_order_columns'])
simple_alias.update(sa_)
complex_alias.update(ca_)
alias = {**simple_alias, **complex_alias}
# print(token.value)
# continue
if token.ttype is sqlparse.tokens.Keyword : # functions or group/order by
if token.value.lower() in sql_functions :
aggregation_func = token.value
aggregation_bool = True
elif token.value.lower() in ['group by', 'order by'] :
group_order = True
elif token.value.lower() == 'from' :
is_from = True
continue
elif isinstance(token, sqlparse.sql.Comparison) : # join condition / filter subquery
predicate, where = get_predicate_str(token, aggregation_func, simple_alias, db_name, used_tables, table_columns, alias, schema)
if aggregation_func != "" : aggregation_func = ""
if predicate != "" :
if where : where_predicates.append(predicate)
else : other_predicates.append(predicate)
elif isinstance(token, sqlparse.sql.Where) : # where condition
aggregation_func = ""
where_predicate_str = ""
predicate = ""
for t in token :
# print(t, t.__class__, t.ttype, where_predicate_str)
if t.ttype is sqlparse.tokens.Keyword and t.value.lower() in sql_functions:
aggregation_func = t.value
continue
elif isinstance(t, sqlparse.sql.Comparison) :
predicate, where_condition = get_predicate_str(t, aggregation_func, simple_alias, db_name, used_tables, table_columns, alias, schema)
if aggregation_func != "" : aggregation_func = ""
elif isinstance(t, sqlparse.sql.Token) :
if t.value.lower() == 'where' :
continue
elif t.value.lower() in ["and", "or"] or t == token[-1] :
# print(f"-- {t.value} --")
if t.value.lower() in ["and", "or"] and ('between' not in where_predicate_str.lower() or is_between):
if t.value.lower() == "and" : is_and = True
elif t.value.lower() == "or" : is_or = True
else :
if 'between' in where_predicate_str.lower() :
is_between = True
if not (t.value.lower() == "and" and 'and' in where_predicate_str.lower()) : where_predicate_str += t.value
if t != token[-1] : continue
if t == token[-1] and t.value == ';' : where_predicate_str = where_predicate_str
if where_predicate_str.strip()!= "" :
if where_predicate_str.lower().strip().startswith('and' or 'or') :
where_predicate_str = where_predicate_str[3:]
where_predicate_str = get_preidcate(where_predicate_str, simple_alias, used_tables, table_columns)
if where_predicate_str != "" :
# print(where_predicate_str)
where_predicates.append(where_predicate_str)
if multi_where_predicate == "" :
multi_where_predicate = where_predicate_str # start of current multi predicates
elif is_and and where_predicate_str not in multi_where_predicate : multi_where_predicate += ' and ' + where_predicate_str
elif is_or and where_predicate_str not in multi_where_predicate : multi_where_predicate += ' or ' + where_predicate_str
is_or = False
is_and = False
if 'between' in where_predicate_str.lower() : is_between = False
where_predicate_str = ""
continue
elif isinstance(t, sqlparse.sql.Parenthesis) :
if t.value.strip().startswith('(') and t.value.strip().endswith(')') and 'select' not in t.value.lower() :
if aggregation_func == "" : predicates_ = t.value[1:-1].strip()
else : predicates_ = aggregation_func + t.value
if is_literal_list('[' + predicates_ + ']') or any(i not in predicates_.lower() for i in ['and', 'or', 'select']) :
if is_literal_list('[' + predicates_ + ']') : where_predicate_str += t.value
else : where_predicate_str += predicates_
else :
fake_sql = f"SELECT * FROM table WHERE {predicates_}"
sub_info, _, _ = parse_sql(fake_sql, conn, db_name, schema_path, table_columns, schema)
where_predicates = list(set(where_predicates).union(set(sub_info['where_predicates'])))
other_predicates = list(set(other_predicates).union(set(sub_info['other_predicates'])))
multi_where_predicates = list(set(multi_where_predicates).union(set(sub_info['multi_where_predicates'])))
group_order_columns = list(set(group_order_columns).union(set(sub_info['group_order_columns'])))
aggregation_func = ""
elif t.value.strip().startswith('(') and t.value.strip().endswith(')') and 'select' in t.value.lower() :
where_predicate_str = ""
continue
else :
where_predicate_str += t.value
else :
if where_predicate_str.strip() == "" and t.value.lower() in ['and', 'or'] : continue
elif isinstance(t, sqlparse.sql.Operation) :
if aggregation_func != "" :
where_predicate_str += aggregation_func
aggregation_func = ""
where_predicate_str += t.value
continue
elif isinstance(t, sqlparse.sql.Identifier) and ')' in t.value :
where_predicate_str += t.value.split(')')[0]
continue
where_predicate_str += t.value
if predicate != "" :
if where_condition :
where_predicates.append(predicate)
if multi_where_predicate == "" :
multi_where_predicate = predicate # start of current multi predicates
continue
elif is_and and predicate not in multi_where_predicate : multi_where_predicate += ' and ' + predicate
elif is_or and predicate not in multi_where_predicate : multi_where_predicate += ' or ' + predicate
is_or = False
is_and = False
else : other_predicates.append(predicate)
predicate = ""
elif isinstance(token, sqlparse.sql.Identifier) :
if is_from :
for table in table_columns.keys() :
if table in token.value : used_tables.append(table)
elif group_order : # add colunmn name to group_order_cols
# print(token.value)
tv = token.value
if 'desc' in token.value.lower() :
tv = tv.replace('desc', ' ').replace('DESC', ' ').strip()
elif 'asc' in token.value.lower() :
tv = tv.replace('asc', ' ').replace('ASC', ' ').strip()
original_value = ""
if tv in column_names :
original_value = tv
elif len(token.value.strip().split()) == 1 and '.' in token.value :
original_value = token.value.split('.')[-1]
else : # single column name with alias
if tv in alias.keys() :
temp = alias[tv]
if isinstance(temp, tuple) :
if len(temp) == 2 : original_value = temp[1]
elif isinstance(temp, str) : # find all column_name in this expression
for cn in column_names :
if cn in temp : original_value = cn
else :
logger.debug(f"** Error type of {temp}, {type(temp)}")
# else :
# logger.error(f"** Error alias of {tv}")
if original_value != "" :
for table_name, cols in table_columns.items() :
if original_value in cols :
group_order_columns.append(f"{table_name}.{original_value}")
group_order = False
elif token.value.lower().startswith('case when ') :
tmp_predicate = token.value.lower().split('case when ')[-1].split('then')[0].strip()
fake_sql = f"SELECT * FROM table WHERE {tmp_predicate} ;"
# print(fake_sql)
sub_info, _, _ = parse_sql(fake_sql, conn, db_name, schema_path, table_columns, schema)
where_predicates = list(set(where_predicates).union(set(sub_info['where_predicates'])))
other_predicates = list(set(other_predicates).union(set(sub_info['other_predicates'])))
multi_where_predicates = list(set(multi_where_predicates).union(set(sub_info['multi_where_predicates'])))
group_order_columns = list(set(group_order_columns).union(set(sub_info['group_order_columns'])))
elif isinstance(token, sqlparse.sql.IdentifierList) :
if is_from :
for table in table_columns.keys() :
if table in token.value : used_tables.append(table)
if group_order : # add colunmn name to group_order_cols
for t in token :
if re.fullmatch(identifier_pattern, t.value) or re.fullmatch(table_dot_column, t.value) :
original_value = ""
if t.value.lower() in sql_functions :
aggregation_func = t.value
continue
elif t.value in column_names :
original_value = t.value
elif len(t.value.strip().split()) == 1 and '.' in t.value :
original_value = t.value.split('.')[-1]
else : # single column name with alias
tv = aggregation_func + " " + t.value
tv = tv.strip()
if tv in alias.keys() :
temp = alias[tv]
if isinstance(temp, tuple) :