Skip to content

Commit a25f417

Browse files
committed
it makes sense, i promise
1 parent 5fd7738 commit a25f417

File tree

1 file changed

+72
-67
lines changed

1 file changed

+72
-67
lines changed

codeflash/verification/comparator.py

Lines changed: 72 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,7 @@
2222
HAS_PYRSISTENT = find_spec("pyrsistent") is not None
2323
HAS_TORCH = find_spec("torch") is not None
2424
HAS_JAX = find_spec("jax") is not None
25-
26-
try:
27-
import xarray # type: ignore
28-
29-
HAS_XARRAY = True
30-
except ImportError:
31-
HAS_XARRAY = False
25+
HAS_XARRAY = find_spec("xarray") is not None
3226

3327

3428
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -83,24 +77,29 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
8377
orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
8478
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
8579
return comparator(orig_dict, new_dict, superset_obj)
80+
8681
if HAS_JAX:
8782
import jax # type: ignore # noqa: PGH003
8883
import jax.numpy as jnp # type: ignore # noqa: PGH003
89-
# Handle JAX arrays first to avoid boolean context errors in other conditions
90-
if HAS_JAX and isinstance(orig, jax.Array):
91-
if orig.dtype != new.dtype:
92-
return False
93-
if orig.shape != new.shape:
94-
return False
95-
return bool(jnp.allclose(orig, new, equal_nan=True))
84+
85+
# Handle JAX arrays first to avoid boolean context errors in other conditions
86+
if isinstance(orig, jax.Array):
87+
if orig.dtype != new.dtype:
88+
return False
89+
if orig.shape != new.shape:
90+
return False
91+
return bool(jnp.allclose(orig, new, equal_nan=True))
9692

9793
# Handle xarray objects before numpy to avoid boolean context errors
98-
if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
99-
return orig.identical(new)
94+
if HAS_XARRAY:
95+
import xarray # type: ignore # noqa: PGH003
96+
97+
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
98+
return orig.identical(new)
10099

101100
if HAS_SQLALCHEMY:
102101
import sqlalchemy # type: ignore # noqa: PGH003
103-
if HAS_SQLALCHEMY:
102+
104103
try:
105104
insp = sqlalchemy.inspection.inspect(orig)
106105
insp = sqlalchemy.inspection.inspect(new) # noqa: F841
@@ -115,6 +114,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
115114

116115
except sqlalchemy.exc.NoInspectionAvailable:
117116
pass
117+
118118
if HAS_SCIPY:
119119
import scipy # type: ignore # noqa: PGH003
120120
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
@@ -132,27 +132,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
132132

133133
if HAS_NUMPY:
134134
import numpy as np # type: ignore # noqa: PGH003
135-
if HAS_NUMPY and isinstance(orig, np.ndarray):
136-
if orig.dtype != new.dtype:
137-
return False
138-
if orig.shape != new.shape:
139-
return False
140-
try:
141-
return np.allclose(orig, new, equal_nan=True)
142-
except Exception:
143-
# fails at "ufunc 'isfinite' not supported for the input types"
144-
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
145135

146-
if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)):
147-
return np.isclose(orig, new)
136+
if isinstance(orig, np.ndarray):
137+
if orig.dtype != new.dtype:
138+
return False
139+
if orig.shape != new.shape:
140+
return False
141+
try:
142+
return np.allclose(orig, new, equal_nan=True)
143+
except Exception:
144+
# fails at "ufunc 'isfinite' not supported for the input types"
145+
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
146+
147+
if isinstance(orig, (np.floating, np.complex64, np.complex128)):
148+
return np.isclose(orig, new)
148149

149-
if HAS_NUMPY and isinstance(orig, (np.integer, np.bool_, np.byte)):
150-
return orig == new
150+
if isinstance(orig, (np.integer, np.bool_, np.byte)):
151+
return orig == new
151152

152-
if HAS_NUMPY and isinstance(orig, np.void):
153-
if orig.dtype != new.dtype:
154-
return False
155-
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
153+
if isinstance(orig, np.void):
154+
if orig.dtype != new.dtype:
155+
return False
156+
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
156157

157158
if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
158159
if orig.dtype != new.dtype:
@@ -163,15 +164,16 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
163164

164165
if HAS_PANDAS:
165166
import pandas # type: ignore # noqa: ICN001, PGH003
166-
if HAS_PANDAS and isinstance(
167-
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
168-
):
169-
return orig.equals(new)
170167

171-
if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
172-
return orig == new
173-
if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new):
174-
return True
168+
if isinstance(
169+
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
170+
):
171+
return orig.equals(new)
172+
173+
if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
174+
return orig == new
175+
if pandas.isna(orig) and pandas.isna(new):
176+
return True
175177

176178
if isinstance(orig, array.array):
177179
if orig.typecode != new.typecode:
@@ -194,32 +196,35 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
194196

195197
if HAS_TORCH:
196198
import torch # type: ignore # noqa: PGH003
197-
if HAS_TORCH and isinstance(orig, torch.Tensor):
198-
if orig.dtype != new.dtype:
199-
return False
200-
if orig.shape != new.shape:
201-
return False
202-
if orig.requires_grad != new.requires_grad:
203-
return False
204-
if orig.device != new.device:
205-
return False
206-
return torch.allclose(orig, new, equal_nan=True)
199+
200+
if isinstance(orig, torch.Tensor):
201+
if orig.dtype != new.dtype:
202+
return False
203+
if orig.shape != new.shape:
204+
return False
205+
if orig.requires_grad != new.requires_grad:
206+
return False
207+
if orig.device != new.device:
208+
return False
209+
return torch.allclose(orig, new, equal_nan=True)
210+
207211
if HAS_PYRSISTENT:
208212
import pyrsistent # type: ignore # noqa: PGH003
209-
if HAS_PYRSISTENT and isinstance(
210-
orig,
211-
(
212-
pyrsistent.PMap,
213-
pyrsistent.PVector,
214-
pyrsistent.PSet,
215-
pyrsistent.PRecord,
216-
pyrsistent.PClass,
217-
pyrsistent.PBag,
218-
pyrsistent.PList,
219-
pyrsistent.PDeque,
220-
),
221-
):
222-
return orig == new
213+
214+
if isinstance(
215+
orig,
216+
(
217+
pyrsistent.PMap,
218+
pyrsistent.PVector,
219+
pyrsistent.PSet,
220+
pyrsistent.PRecord,
221+
pyrsistent.PClass,
222+
pyrsistent.PBag,
223+
pyrsistent.PList,
224+
pyrsistent.PDeque,
225+
),
226+
):
227+
return orig == new
223228

224229
if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
225230
orig_dict = {}

0 commit comments

Comments
 (0)