Skip to content

Commit 5b12e89

Browse files
committed
fix: performance fixes
1 parent bae5b35 commit 5b12e89

File tree

1 file changed

+41
-39
lines changed

1 file changed

+41
-39
lines changed

src/query_farm_sql_scan_planning/planner.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class SetFieldInfo(BaseFieldInfo):
4747
def _scalar_value_op(
4848
a: pa.Scalar, b: pa.Scalar, op: Callable[[Any, Any], bool]
4949
) -> bool:
50+
"""
51+
Perform a scalar value operation on two scalars.
52+
"""
5053
assert not pa.types.is_null(a.type), (
5154
f"Expected a non-null scalar value, got {a} of type {a.type}"
5255
)
@@ -77,23 +80,23 @@ def _scalar_value_op(
7780
return op(a.as_py(), b.as_py())
7881

7982

80-
def _scalar_value_lte(a: pa.Scalar, b: pa.Scalar) -> bool:
83+
def _sv_lte(a: pa.Scalar, b: pa.Scalar) -> bool:
8184
return _scalar_value_op(a, b, lambda x, y: x <= y)
8285

8386

84-
def _scalar_value_lt(a: pa.Scalar, b: pa.Scalar) -> bool:
87+
def _sv_lt(a: pa.Scalar, b: pa.Scalar) -> bool:
8588
return _scalar_value_op(a, b, lambda x, y: x < y)
8689

8790

88-
def _scalar_value_gt(a: pa.Scalar, b: pa.Scalar) -> bool:
91+
def _sv_gt(a: pa.Scalar, b: pa.Scalar) -> bool:
8992
return _scalar_value_op(a, b, lambda x, y: x > y)
9093

9194

92-
def _scalar_value_gte(a: pa.Scalar, b: pa.Scalar) -> bool:
95+
def _sv_gte(a: pa.Scalar, b: pa.Scalar) -> bool:
9396
return _scalar_value_op(a, b, lambda x, y: x >= y)
9497

9598

96-
def _scalar_value_eq(a: pa.Scalar, b: pa.Scalar) -> bool:
99+
def _sv_eq(a: pa.Scalar, b: pa.Scalar) -> bool:
97100
return _scalar_value_op(a, b, lambda x, y: x == y)
98101

99102

@@ -115,6 +118,7 @@ def __init__(self, files: list[tuple[str, FileFieldInfo]]):
115118
file_ranges: List of tuples containing (filename, min_val, max_val)
116119
"""
117120
self.files = files
121+
self.connection = duckdb.connect(":memory:")
118122

119123
def _eval_predicate(
120124
self,
@@ -146,22 +150,20 @@ def _eval_predicate(
146150

147151
# The thing on the right side should be something that can be evaluated against a range.
148152
# ideally, its going to be a
149-
if True: # isinstance(node.right, sqlglot.expressions.Cast):
150-
connection = duckdb.connect(":memory:")
151-
value_result = connection.execute(
152-
f"select {node.right.sql('duckdb')}"
153-
).arrow()
154-
assert value_result.num_rows == 1, (
155-
f"Expected a single row result from cast, got {value_result.num_rows} rows"
156-
)
157-
assert value_result.num_columns == 1, (
158-
f"Expected a single column result from cast, got {value_result.num_columns} columns"
159-
)
153+
value_result = self.connection.execute(
154+
f"select {node.right.sql('duckdb')}"
155+
).arrow()
156+
assert value_result.num_rows == 1, (
157+
f"Expected a single row result from cast, got {value_result.num_rows} rows"
158+
)
159+
assert value_result.num_columns == 1, (
160+
f"Expected a single column result from cast, got {value_result.num_columns} columns"
161+
)
160162

161-
right_val = value_result.column(0)[0]
162-
# This is an interesting behavior, null is returned with an int32 type.
163-
if type(right_val) is pa.Int32Scalar and right_val.as_py() is None:
164-
right_val = pa.scalar(None, type=pa.null())
163+
right_val = value_result.column(0)[0]
164+
# This is an interesting behavior, null is returned with an int32 type.
165+
if type(right_val) is pa.Int32Scalar and right_val.as_py() is None:
166+
right_val = pa.scalar(None, type=pa.null())
165167

166168
left_val = node.left
167169
assert isinstance(left_val, sqlglot.expressions.Column), (
@@ -205,8 +207,8 @@ def _eval_predicate(
205207
return field_info.has_non_nulls
206208

207209
return not (
208-
_scalar_value_eq(field_info.min_value, field_info.max_value)
209-
and _scalar_value_eq(field_info.min_value, right_val)
210+
_sv_eq(field_info.min_value, field_info.max_value)
211+
and _sv_eq(field_info.min_value, right_val)
210212
)
211213

212214
elif type(node) is sqlglot.expressions.NullSafeEQ:
@@ -215,9 +217,9 @@ def _eval_predicate(
215217
if field_info.min_value is None or field_info.max_value is None:
216218
return False
217219
assert not pa.types.is_null(right_val.type)
218-
return _scalar_value_lte(
219-
field_info.min_value, right_val
220-
) and _scalar_value_lte(right_val, field_info.max_value)
220+
return _sv_lte(field_info.min_value, right_val) and _sv_lte(
221+
right_val, field_info.max_value
222+
)
221223

222224
if field_info.min_value is None or field_info.max_value is None:
223225
return False
@@ -227,37 +229,37 @@ def _eval_predicate(
227229

228230
match type(node):
229231
case sqlglot.expressions.EQ:
230-
return _scalar_value_lte(
231-
field_info.min_value, right_val
232-
) and _scalar_value_lte(right_val, field_info.max_value)
232+
return _sv_lte(field_info.min_value, right_val) and _sv_lte(
233+
right_val, field_info.max_value
234+
)
233235
case sqlglot.expressions.NEQ:
234236
return not (
235-
_scalar_value_eq(field_info.min_value, field_info.max_value)
236-
and _scalar_value_eq(field_info.min_value, right_val)
237+
_sv_eq(field_info.min_value, field_info.max_value)
238+
and _sv_eq(field_info.min_value, right_val)
237239
)
238240
case sqlglot.expressions.LT:
239-
return _scalar_value_lt(field_info.min_value, right_val)
241+
return _sv_lt(field_info.min_value, right_val)
240242
case sqlglot.expressions.LTE:
241-
return _scalar_value_lte(field_info.min_value, right_val)
243+
return _sv_lte(field_info.min_value, right_val)
242244
case sqlglot.expressions.GT:
243-
return _scalar_value_gt(field_info.max_value, right_val)
245+
return _sv_gt(field_info.max_value, right_val)
244246
case sqlglot.expressions.GTE:
245-
return _scalar_value_gte(field_info.max_value, right_val)
247+
return _sv_gte(field_info.max_value, right_val)
246248
case sqlglot.expressions.NullSafeEQ:
247249
if pa.types.is_null(right_val.type) and field_info.has_non_nulls:
248250
return True
249-
return _scalar_value_lte(
250-
field_info.min_value, right_val
251-
) and _scalar_value_lte(right_val, field_info.max_value)
251+
return _sv_lte(field_info.min_value, right_val) and _sv_lte(
252+
right_val, field_info.max_value
253+
)
252254
case sqlglot.expressions.NullSafeNEQ:
253255
if (
254256
not pa.types.is_null(right_val.type)
255257
and field_info.has_non_nulls is False
256258
):
257259
return True
258260
return not (
259-
_scalar_value_eq(field_info.min_value, field_info.max_value)
260-
and _scalar_value_eq(field_info.min_value, right_val)
261+
_sv_eq(field_info.min_value, field_info.max_value)
262+
and _sv_eq(field_info.min_value, right_val)
261263
)
262264
case _:
263265
raise ValueError(f"Unsupported operator type: {type(node)}")

0 commit comments

Comments
 (0)