Skip to content

Commit a988687

Browse files
committed
Refactor SQL query handling in DbReplicator and MySQLApi for improved security and clarity
- Updated DbReplicator to pass raw primary key values to mysql_api, eliminating manual quote handling now that the queries are parameterized.
- Enhanced MySQLApi to use parameterized queries for pagination, preventing SQL injection and improving query safety.
- Added detailed logging of the executed query and its parameters to aid debugging and error handling.
1 parent 7632e2a commit a988687

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

mysql_ch_replicator/db_replicator_initial.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,9 @@ def perform_initial_replication_table(self, table_name):
167167

168168
while True:
169169

170+
# Pass raw primary key values to mysql_api - it will handle proper SQL parameterization
171+
# No need to manually add quotes - parameterized queries handle this safely
170172
query_start_values = max_primary_key
171-
if query_start_values is not None:
172-
for i in range(len(query_start_values)):
173-
key_type = primary_key_types[i]
174-
value = query_start_values[i]
175-
if 'int' not in key_type.lower():
176-
value = f"'{value}'"
177-
query_start_values[i] = value
178173

179174
records = self.replicator.mysql_api.get_records(
180175
table_name=table_name,

mysql_ch_replicator/mysql_api.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,14 @@ def get_records(
9292
order_by_str = ",".join(order_by_escaped)
9393

9494
where = ""
95+
query_params = []
96+
9597
if start_value is not None:
96-
# Build the start_value condition for pagination
97-
start_value_str = ",".join(map(str, start_value))
98-
where = f"WHERE ({order_by_str}) > ({start_value_str}) "
98+
# Build the start_value condition for pagination using parameterized query
99+
# This prevents SQL injection and handles special characters properly
100+
placeholders = ",".join(["%s"] * len(start_value))
101+
where = f"WHERE ({order_by_str}) > ({placeholders}) "
102+
query_params.extend(start_value)
99103

100104
# Add partitioning filter for parallel processing (e.g., sharded crawling)
101105
if (
@@ -116,10 +120,23 @@ def get_records(
116120
# Construct final query
117121
query = f"SELECT * FROM `{table_name}` {where}ORDER BY {order_by_str} LIMIT {limit}"
118122

123+
# Log query details for debugging
119124
logger.debug(f"Executing query: {query}")
125+
if query_params:
126+
logger.debug(f"Query parameters: {query_params}")
120127

121-
# Execute the query
122-
cursor.execute(query)
123-
res = cursor.fetchall()
124-
records = [x for x in res]
125-
return records
128+
# Execute the query with proper parameterization
129+
try:
130+
if query_params:
131+
cursor.execute(query, tuple(query_params))
132+
else:
133+
cursor.execute(query)
134+
res = cursor.fetchall()
135+
records = [x for x in res]
136+
return records
137+
except Exception as e:
138+
logger.error(f"Query execution failed: {query}")
139+
if query_params:
140+
logger.error(f"Query parameters: {query_params}")
141+
logger.error(f"Error details: {e}")
142+
raise

0 commit comments

Comments
 (0)