1- # ruff: noqa: PGH003
21import array
32import ast
43import datetime
87import re
98import types
109from collections import ChainMap , OrderedDict , deque
10+ from importlib .util import find_spec
1111from typing import Any
1212
1313import sentry_sdk
1414
1515from codeflash .cli_cmds .console import logger
1616from codeflash .picklepatch .pickle_placeholder import PicklePlaceholderAccessError
1717
18- try :
19- import numpy as np
20-
21- HAS_NUMPY = True
22- except ImportError :
23- HAS_NUMPY = False
24- try :
25- import sqlalchemy # type: ignore
26-
27- HAS_SQLALCHEMY = True
28- except ImportError :
29- HAS_SQLALCHEMY = False
30- try :
31- import scipy # type: ignore
32-
33- HAS_SCIPY = True
34- except ImportError :
35- HAS_SCIPY = False
36-
37- try :
38- import pandas # type: ignore # noqa: ICN001
39-
40- HAS_PANDAS = True
41- except ImportError :
42- HAS_PANDAS = False
43-
44- try :
45- import pyrsistent # type: ignore
46-
47- HAS_PYRSISTENT = True
48- except ImportError :
49- HAS_PYRSISTENT = False
50- try :
51- import torch # type: ignore
52-
53- HAS_TORCH = True
54- except ImportError :
55- HAS_TORCH = False
56- try :
57- import jax # type: ignore
58- import jax .numpy as jnp # type: ignore
59-
60- HAS_JAX = True
61- except ImportError :
62- HAS_JAX = False
63-
64- try :
65- import xarray # type: ignore
66-
67- HAS_XARRAY = True
68- except ImportError :
69- HAS_XARRAY = False
18+ HAS_NUMPY = find_spec ("numpy" ) is not None
19+ HAS_SQLALCHEMY = find_spec ("sqlalchemy" ) is not None
20+ HAS_SCIPY = find_spec ("scipy" ) is not None
21+ HAS_PANDAS = find_spec ("pandas" ) is not None
22+ HAS_PYRSISTENT = find_spec ("pyrsistent" ) is not None
23+ HAS_TORCH = find_spec ("torch" ) is not None
24+ HAS_JAX = find_spec ("jax" ) is not None
25+ HAS_XARRAY = find_spec ("xarray" ) is not None
7026
7127
7228def comparator (orig : Any , new : Any , superset_obj = False ) -> bool : # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -122,19 +78,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
12278 new_dict = {k : v for k , v in new .__dict__ .items () if not k .startswith ("_" )}
12379 return comparator (orig_dict , new_dict , superset_obj )
12480
125- # Handle JAX arrays first to avoid boolean context errors in other conditions
126- if HAS_JAX and isinstance (orig , jax .Array ):
127- if orig .dtype != new .dtype :
128- return False
129- if orig .shape != new .shape :
130- return False
131- return bool (jnp .allclose (orig , new , equal_nan = True ))
81+ if HAS_JAX :
82+ import jax # type: ignore # noqa: PGH003
83+ import jax .numpy as jnp # type: ignore # noqa: PGH003
84+
85+ # Handle JAX arrays first to avoid boolean context errors in other conditions
86+ if isinstance (orig , jax .Array ):
87+ if orig .dtype != new .dtype :
88+ return False
89+ if orig .shape != new .shape :
90+ return False
91+ return bool (jnp .allclose (orig , new , equal_nan = True ))
13292
13393 # Handle xarray objects before numpy to avoid boolean context errors
134- if HAS_XARRAY and isinstance (orig , (xarray .Dataset , xarray .DataArray )):
135- return orig .identical (new )
94+ if HAS_XARRAY :
95+ import xarray # type: ignore # noqa: PGH003
96+
97+ if isinstance (orig , (xarray .Dataset , xarray .DataArray )):
98+ return orig .identical (new )
13699
137100 if HAS_SQLALCHEMY :
101+ import sqlalchemy # type: ignore # noqa: PGH003
102+
138103 try :
139104 insp = sqlalchemy .inspection .inspect (orig )
140105 insp = sqlalchemy .inspection .inspect (new ) # noqa: F841
@@ -149,6 +114,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
149114
150115 except sqlalchemy .exc .NoInspectionAvailable :
151116 pass
117+
118+ if HAS_SCIPY :
119+ import scipy # type: ignore # noqa: PGH003
152120 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
153121 if isinstance (orig , dict ) and not (HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix )):
154122 if superset_obj :
@@ -162,27 +130,30 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
162130 return False
163131 return True
164132
165- if HAS_NUMPY and isinstance (orig , np .ndarray ):
166- if orig .dtype != new .dtype :
167- return False
168- if orig .shape != new .shape :
169- return False
170- try :
171- return np .allclose (orig , new , equal_nan = True )
172- except Exception :
173- # fails at "ufunc 'isfinite' not supported for the input types"
174- return np .all ([comparator (x , y , superset_obj ) for x , y in zip (orig , new )])
133+ if HAS_NUMPY :
134+ import numpy as np # type: ignore # noqa: PGH003
175135
176- if HAS_NUMPY and isinstance (orig , (np .floating , np .complex64 , np .complex128 )):
177- return np .isclose (orig , new )
136+ if isinstance (orig , np .ndarray ):
137+ if orig .dtype != new .dtype :
138+ return False
139+ if orig .shape != new .shape :
140+ return False
141+ try :
142+ return np .allclose (orig , new , equal_nan = True )
143+ except Exception :
144+ # fails at "ufunc 'isfinite' not supported for the input types"
145+ return np .all ([comparator (x , y , superset_obj ) for x , y in zip (orig , new )])
178146
179- if HAS_NUMPY and isinstance (orig , (np .integer , np .bool_ , np .byte )):
180- return orig == new
147+ if isinstance (orig , (np .floating , np .complex64 , np .complex128 )):
148+ return np . isclose ( orig , new )
181149
182- if HAS_NUMPY and isinstance (orig , np .void ):
183- if orig .dtype != new .dtype :
184- return False
185- return all (comparator (orig [field ], new [field ], superset_obj ) for field in orig .dtype .fields )
150+ if isinstance (orig , (np .integer , np .bool_ , np .byte )):
151+ return orig == new
152+
153+ if isinstance (orig , np .void ):
154+ if orig .dtype != new .dtype :
155+ return False
156+ return all (comparator (orig [field ], new [field ], superset_obj ) for field in orig .dtype .fields )
186157
187158 if HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix ):
188159 if orig .dtype != new .dtype :
@@ -191,15 +162,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
191162 return False
192163 return (orig != new ).nnz == 0
193164
194- if HAS_PANDAS and isinstance (
195- orig , (pandas .DataFrame , pandas .Series , pandas .Index , pandas .Categorical , pandas .arrays .SparseArray )
196- ):
197- return orig .equals (new )
165+ if HAS_PANDAS :
166+ import pandas # type: ignore # noqa: ICN001, PGH003
198167
199- if HAS_PANDAS and isinstance (orig , (pandas .CategoricalDtype , pandas .Interval , pandas .Period )):
200- return orig == new
201- if HAS_PANDAS and pandas .isna (orig ) and pandas .isna (new ):
202- return True
168+ if isinstance (
169+ orig , (pandas .DataFrame , pandas .Series , pandas .Index , pandas .Categorical , pandas .arrays .SparseArray )
170+ ):
171+ return orig .equals (new )
172+
173+ if isinstance (orig , (pandas .CategoricalDtype , pandas .Interval , pandas .Period )):
174+ return orig == new
175+ if pandas .isna (orig ) and pandas .isna (new ):
176+ return True
203177
204178 if isinstance (orig , array .array ):
205179 if orig .typecode != new .typecode :
@@ -220,31 +194,37 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
220194 except Exception : # noqa: S110
221195 pass
222196
223- if HAS_TORCH and isinstance (orig , torch .Tensor ):
224- if orig .dtype != new .dtype :
225- return False
226- if orig .shape != new .shape :
227- return False
228- if orig .requires_grad != new .requires_grad :
229- return False
230- if orig .device != new .device :
231- return False
232- return torch .allclose (orig , new , equal_nan = True )
197+ if HAS_TORCH :
198+ import torch # type: ignore # noqa: PGH003
233199
234- if HAS_PYRSISTENT and isinstance (
235- orig ,
236- (
237- pyrsistent .PMap ,
238- pyrsistent .PVector ,
239- pyrsistent .PSet ,
240- pyrsistent .PRecord ,
241- pyrsistent .PClass ,
242- pyrsistent .PBag ,
243- pyrsistent .PList ,
244- pyrsistent .PDeque ,
245- ),
246- ):
247- return orig == new
200+ if isinstance (orig , torch .Tensor ):
201+ if orig .dtype != new .dtype :
202+ return False
203+ if orig .shape != new .shape :
204+ return False
205+ if orig .requires_grad != new .requires_grad :
206+ return False
207+ if orig .device != new .device :
208+ return False
209+ return torch .allclose (orig , new , equal_nan = True )
210+
211+ if HAS_PYRSISTENT :
212+ import pyrsistent # type: ignore # noqa: PGH003
213+
214+ if isinstance (
215+ orig ,
216+ (
217+ pyrsistent .PMap ,
218+ pyrsistent .PVector ,
219+ pyrsistent .PSet ,
220+ pyrsistent .PRecord ,
221+ pyrsistent .PClass ,
222+ pyrsistent .PBag ,
223+ pyrsistent .PList ,
224+ pyrsistent .PDeque ,
225+ ),
226+ ):
227+ return orig == new
248228
249229 if hasattr (orig , "__attrs_attrs__" ) and hasattr (new , "__attrs_attrs__" ):
250230 orig_dict = {}
0 commit comments