Skip to content

Commit 33742df

Browse files
committed
fix: performance
1 parent 5b12e89 commit 33742df

File tree

1 file changed

+50
-34
lines changed

1 file changed

+50
-34
lines changed

src/query_farm_sql_scan_planning/planner.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,26 @@ def _evaluate_node_connector(
272272
273273
Returns True, False, or None if the expression cannot be evaluated.
274274
"""
275-
op_map: dict[
276-
type[sqlglot.expressions.Connector], Callable[[bool, bool], bool]
277-
] = {
278-
sqlglot.expressions.And: lambda left, right: left and right,
279-
sqlglot.expressions.Or: lambda left, right: left or right,
280-
sqlglot.expressions.Xor: lambda left, right: left ^ right,
281-
}
282-
283-
for expr_type, op in op_map.items():
284-
if isinstance(node, expr_type):
285-
left_result = self._evaluate_sql_node(node.left, file_info)
286-
right_result = self._evaluate_sql_node(node.right, file_info)
287-
if left_result is None or right_result is None:
288-
return None
289-
return op(left_result, right_result)
275+
assert isinstance(node, sqlglot.expressions.Connector), (
276+
f"Expected a connector node, got {node} of type {type(node)}"
277+
)
278+
match type(node):
279+
case sqlglot.expressions.And:
280+
return self._evaluate_sql_node(
281+
node.left, file_info
282+
) and self._evaluate_sql_node(node.right, file_info)
283+
case sqlglot.expressions.Or:
284+
return self._evaluate_sql_node(
285+
node.left, file_info
286+
) or self._evaluate_sql_node(node.right, file_info)
287+
case sqlglot.expressions.Xor:
288+
raise ValueError("Unsupported XOR operation in SQL expression.")
289+
# return self._evaluate_sql_node(
290+
# node.left, file_info
291+
# ) ^ self._evaluate_sql_node(node.right, file_info)
292+
case _:
293+
# If we reach here, it means the node is not a recognized connector type.
294+
assert False, f"Unexpected connector type: {type(node)}"
290295

291296
raise ValueError(f"Unsupported connector type: {type(node)}")
292297

@@ -439,26 +444,37 @@ def _evaluate_sql_node(
439444
Evaluate a SQL node against a file's field info.
440445
Returns True, False, or None if the expression cannot be evaluated.
441446
"""
442-
if isinstance(node, sqlglot.expressions.Connector):
443-
return self._evaluate_node_connector(node, file_info)
444-
elif isinstance(node, sqlglot.expressions.Predicate):
445-
return self._evaluate_node_predicate(node, file_info)
446-
elif isinstance(node, sqlglot.expressions.Not):
447-
if isinstance(node.this, sqlglot.expressions.In):
448-
# Handle 'not in' operations
449-
return self._evaluate_node_not_in(node.this, file_info)
450-
# Handle 'not' operations
451-
return not self._evaluate_sql_node(node.this, file_info)
452-
elif isinstance(node, sqlglot.expressions.Boolean):
453-
return node.to_py()
454-
elif isinstance(node, sqlglot.expressions.Case):
455-
return self._evaluate_node_case(node, file_info)
456-
elif isinstance(node, sqlglot.expressions.Null):
457-
return False
458-
else:
459-
raise ValueError(f"Unsupported node type: {type(node)}")
447+
match node:
448+
case sqlglot.expressions.Connector():
449+
return self._evaluate_node_connector(node, file_info)
450+
451+
case sqlglot.expressions.Predicate():
452+
return self._evaluate_node_predicate(node, file_info)
453+
454+
case sqlglot.expressions.Not():
455+
match node.this:
456+
case sqlglot.expressions.In():
457+
# Handle 'not in' operations
458+
return self._evaluate_node_not_in(node.this, file_info)
459+
case _:
460+
# Handle general 'not' operations
461+
inner_result = self._evaluate_sql_node(node.this, file_info)
462+
return None if inner_result is None else not inner_result
463+
464+
case sqlglot.expressions.Boolean():
465+
return node.to_py()
466+
467+
case sqlglot.expressions.Case():
468+
return self._evaluate_node_case(node, file_info)
469+
470+
case sqlglot.expressions.Null():
471+
return False
460472

461-
return False
473+
case _:
474+
raise ValueError(
475+
f"Unsupported node type: {type(node).__name__}. "
476+
f"Supported types: Connector, Predicate, Not, Boolean, Case, Null"
477+
)
462478

463479
def get_matching_files(self, exp: sqlglot.expressions.Expression | str) -> set[str]:
464480
"""

0 commit comments

Comments
 (0)