2222HAS_PYRSISTENT = find_spec ("pyrsistent" ) is not None
2323HAS_TORCH = find_spec ("torch" ) is not None
2424HAS_JAX = find_spec ("jax" ) is not None
25-
26- try :
27- import xarray # type: ignore
28-
29- HAS_XARRAY = True
30- except ImportError :
31- HAS_XARRAY = False
25+ HAS_XARRAY = find_spec ("xarray" ) is not None
3226
3327
3428def comparator (orig : Any , new : Any , superset_obj = False ) -> bool : # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -83,24 +77,29 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
8377 orig_dict = {k : v for k , v in orig .__dict__ .items () if not k .startswith ("_" )}
8478 new_dict = {k : v for k , v in new .__dict__ .items () if not k .startswith ("_" )}
8579 return comparator (orig_dict , new_dict , superset_obj )
80+
8681 if HAS_JAX :
8782 import jax # type: ignore # noqa: PGH003
8883 import jax .numpy as jnp # type: ignore # noqa: PGH003
89- # Handle JAX arrays first to avoid boolean context errors in other conditions
90- if HAS_JAX and isinstance (orig , jax .Array ):
91- if orig .dtype != new .dtype :
92- return False
93- if orig .shape != new .shape :
94- return False
95- return bool (jnp .allclose (orig , new , equal_nan = True ))
84+
85+ # Handle JAX arrays first to avoid boolean context errors in other conditions
86+ if isinstance (orig , jax .Array ):
87+ if orig .dtype != new .dtype :
88+ return False
89+ if orig .shape != new .shape :
90+ return False
91+ return bool (jnp .allclose (orig , new , equal_nan = True ))
9692
9793 # Handle xarray objects before numpy to avoid boolean context errors
98- if HAS_XARRAY and isinstance (orig , (xarray .Dataset , xarray .DataArray )):
99- return orig .identical (new )
94+ if HAS_XARRAY :
95+ import xarray # type: ignore # noqa: PGH003
96+
97+ if isinstance (orig , (xarray .Dataset , xarray .DataArray )):
98+ return orig .identical (new )
10099
101100 if HAS_SQLALCHEMY :
102101 import sqlalchemy # type: ignore # noqa: PGH003
103- if HAS_SQLALCHEMY :
102+
104103 try :
105104 insp = sqlalchemy .inspection .inspect (orig )
106105 insp = sqlalchemy .inspection .inspect (new ) # noqa: F841
@@ -115,6 +114,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
115114
116115 except sqlalchemy .exc .NoInspectionAvailable :
117116 pass
117+
118118 if HAS_SCIPY :
119119 import scipy # type: ignore # noqa: PGH003
120120 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
@@ -132,27 +132,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
132132
133133 if HAS_NUMPY :
134134 import numpy as np # type: ignore # noqa: PGH003
135- if HAS_NUMPY and isinstance (orig , np .ndarray ):
136- if orig .dtype != new .dtype :
137- return False
138- if orig .shape != new .shape :
139- return False
140- try :
141- return np .allclose (orig , new , equal_nan = True )
142- except Exception :
143- # fails at "ufunc 'isfinite' not supported for the input types"
144- return np .all ([comparator (x , y , superset_obj ) for x , y in zip (orig , new )])
145135
146- if HAS_NUMPY and isinstance (orig , (np .floating , np .complex64 , np .complex128 )):
147- return np .isclose (orig , new )
136+ if isinstance (orig , np .ndarray ):
137+ if orig .dtype != new .dtype :
138+ return False
139+ if orig .shape != new .shape :
140+ return False
141+ try :
142+ return np .allclose (orig , new , equal_nan = True )
143+ except Exception :
144+ # fails at "ufunc 'isfinite' not supported for the input types"
145+ return np .all ([comparator (x , y , superset_obj ) for x , y in zip (orig , new )])
146+
147+ if isinstance (orig , (np .floating , np .complex64 , np .complex128 )):
148+ return np .isclose (orig , new )
148149
149- if HAS_NUMPY and isinstance (orig , (np .integer , np .bool_ , np .byte )):
150- return orig == new
150+ if isinstance (orig , (np .integer , np .bool_ , np .byte )):
151+ return orig == new
151152
152- if HAS_NUMPY and isinstance (orig , np .void ):
153- if orig .dtype != new .dtype :
154- return False
155- return all (comparator (orig [field ], new [field ], superset_obj ) for field in orig .dtype .fields )
153+ if isinstance (orig , np .void ):
154+ if orig .dtype != new .dtype :
155+ return False
156+ return all (comparator (orig [field ], new [field ], superset_obj ) for field in orig .dtype .fields )
156157
157158 if HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix ):
158159 if orig .dtype != new .dtype :
@@ -163,15 +164,16 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
163164
164165 if HAS_PANDAS :
165166 import pandas # type: ignore # noqa: ICN001, PGH003
166- if HAS_PANDAS and isinstance (
167- orig , (pandas .DataFrame , pandas .Series , pandas .Index , pandas .Categorical , pandas .arrays .SparseArray )
168- ):
169- return orig .equals (new )
170167
171- if HAS_PANDAS and isinstance (orig , (pandas .CategoricalDtype , pandas .Interval , pandas .Period )):
172- return orig == new
173- if HAS_PANDAS and pandas .isna (orig ) and pandas .isna (new ):
174- return True
168+ if isinstance (
169+ orig , (pandas .DataFrame , pandas .Series , pandas .Index , pandas .Categorical , pandas .arrays .SparseArray )
170+ ):
171+ return orig .equals (new )
172+
173+ if isinstance (orig , (pandas .CategoricalDtype , pandas .Interval , pandas .Period )):
174+ return orig == new
175+ if pandas .isna (orig ) and pandas .isna (new ):
176+ return True
175177
176178 if isinstance (orig , array .array ):
177179 if orig .typecode != new .typecode :
@@ -194,32 +196,35 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
194196
195197 if HAS_TORCH :
196198 import torch # type: ignore # noqa: PGH003
197- if HAS_TORCH and isinstance (orig , torch .Tensor ):
198- if orig .dtype != new .dtype :
199- return False
200- if orig .shape != new .shape :
201- return False
202- if orig .requires_grad != new .requires_grad :
203- return False
204- if orig .device != new .device :
205- return False
206- return torch .allclose (orig , new , equal_nan = True )
199+
200+ if isinstance (orig , torch .Tensor ):
201+ if orig .dtype != new .dtype :
202+ return False
203+ if orig .shape != new .shape :
204+ return False
205+ if orig .requires_grad != new .requires_grad :
206+ return False
207+ if orig .device != new .device :
208+ return False
209+ return torch .allclose (orig , new , equal_nan = True )
210+
207211 if HAS_PYRSISTENT :
208212 import pyrsistent # type: ignore # noqa: PGH003
209- if HAS_PYRSISTENT and isinstance (
210- orig ,
211- (
212- pyrsistent .PMap ,
213- pyrsistent .PVector ,
214- pyrsistent .PSet ,
215- pyrsistent .PRecord ,
216- pyrsistent .PClass ,
217- pyrsistent .PBag ,
218- pyrsistent .PList ,
219- pyrsistent .PDeque ,
220- ),
221- ):
222- return orig == new
213+
214+ if isinstance (
215+ orig ,
216+ (
217+ pyrsistent .PMap ,
218+ pyrsistent .PVector ,
219+ pyrsistent .PSet ,
220+ pyrsistent .PRecord ,
221+ pyrsistent .PClass ,
222+ pyrsistent .PBag ,
223+ pyrsistent .PList ,
224+ pyrsistent .PDeque ,
225+ ),
226+ ):
227+ return orig == new
223228
224229 if hasattr (orig , "__attrs_attrs__" ) and hasattr (new , "__attrs_attrs__" ):
225230 orig_dict = {}
0 commit comments