Skip to content

Commit 5fd7738

Browse files
committed
Merge branch 'main' into import-time-optimization
2 parents fcd96fe + bb35a60 commit 5fd7738

File tree

4 files changed

+138
-4
lines changed

4 files changed

+138
-4
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from codeflash.code_utils import env_utils
1212
from codeflash.code_utils.code_utils import exit_with_message
1313
from codeflash.code_utils.config_parser import parse_config_file
14+
from codeflash.code_utils.git_utils import git_root_dir
1415
from codeflash.lsp.helpers import is_LSP_enabled
1516
from codeflash.version import __version__ as version
1617

@@ -222,18 +223,20 @@ def process_pyproject_config(args: Namespace) -> Namespace:
222223
args.module_root = Path(args.module_root).resolve()
223224
# If module-root is "." then all imports are relatives to it.
224225
# in this case, the ".." becomes outside project scope, causing issues with un-importable paths
225-
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path)
226+
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path, args.worktree)
226227
args.tests_root = Path(args.tests_root).resolve()
227228
if args.benchmarks_root:
228229
args.benchmarks_root = Path(args.benchmarks_root).resolve()
229-
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
230+
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path, args.worktree)
230231
if is_LSP_enabled():
231232
args.all = None
232233
return args
233234
return handle_optimize_all_arg_parsing(args)
234235

235236

236-
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
237+
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path, in_worktree: bool = False) -> Path: # noqa: FBT001, FBT002
238+
if in_worktree:
239+
return git_root_dir()
237240
if pyproject_file_path.parent == module_root:
238241
return module_root
239242
return module_root.parent.resolve()

codeflash/verification/comparator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323
HAS_TORCH = find_spec("torch") is not None
2424
HAS_JAX = find_spec("jax") is not None
2525

26+
try:
27+
import xarray # type: ignore
28+
29+
HAS_XARRAY = True
30+
except ImportError:
31+
HAS_XARRAY = False
32+
2633

2734
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
2835
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
@@ -87,6 +94,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
8794
return False
8895
return bool(jnp.allclose(orig, new, equal_nan=True))
8996

97+
# Handle xarray objects before numpy to avoid boolean context errors
98+
if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
99+
return orig.identical(new)
100+
90101
if HAS_SQLALCHEMY:
91102
import sqlalchemy # type: ignore # noqa: PGH003
92103
if HAS_SQLALCHEMY:

codeflash/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# These version placeholders will be replaced by uv-dynamic-versioning during build.
2-
__version__ = "0.17.1"
2+
__version__ = "0.17.2"

tests/test_comparator.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,126 @@ def test_jax():
787787
assert not comparator(aa, cc)
788788

789789

790+
def test_xarray():
791+
try:
792+
import xarray as xr
793+
import numpy as np
794+
except ImportError:
795+
pytest.skip()
796+
797+
# Test basic DataArray
798+
a = xr.DataArray([1, 2, 3], dims=['x'])
799+
b = xr.DataArray([1, 2, 3], dims=['x'])
800+
c = xr.DataArray([1, 2, 4], dims=['x'])
801+
assert comparator(a, b)
802+
assert not comparator(a, c)
803+
804+
# Test DataArray with coordinates
805+
d = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
806+
e = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
807+
f = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 3]}, dims=['x'])
808+
assert comparator(d, e)
809+
assert not comparator(d, f)
810+
811+
# Test DataArray with attributes
812+
g = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
813+
h = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
814+
i = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'feet'})
815+
assert comparator(g, h)
816+
assert not comparator(g, i)
817+
818+
# Test 2D DataArray
819+
j = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
820+
k = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
821+
l = xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=['x', 'y'])
822+
assert comparator(j, k)
823+
assert not comparator(j, l)
824+
825+
# Test DataArray with different dimensions
826+
m = xr.DataArray([1, 2, 3], dims=['x'])
827+
n = xr.DataArray([1, 2, 3], dims=['y'])
828+
assert not comparator(m, n)
829+
830+
# Test DataArray with NaN values
831+
o = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
832+
p = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
833+
q = xr.DataArray([1.0, 2.0, 3.0], dims=['x'])
834+
assert comparator(o, p)
835+
assert not comparator(o, q)
836+
837+
# Test Dataset
838+
r = xr.Dataset({
839+
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
840+
'pressure': (['x', 'y'], [[5, 6], [7, 8]])
841+
})
842+
s = xr.Dataset({
843+
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
844+
'pressure': (['x', 'y'], [[5, 6], [7, 8]])
845+
})
846+
t = xr.Dataset({
847+
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
848+
'pressure': (['x', 'y'], [[5, 6], [7, 9]])
849+
})
850+
assert comparator(r, s)
851+
assert not comparator(r, t)
852+
853+
# Test Dataset with coordinates
854+
u = xr.Dataset({
855+
'temp': (['x', 'y'], [[1, 2], [3, 4]])
856+
}, coords={'x': [0, 1], 'y': [0, 1]})
857+
v = xr.Dataset({
858+
'temp': (['x', 'y'], [[1, 2], [3, 4]])
859+
}, coords={'x': [0, 1], 'y': [0, 1]})
860+
w = xr.Dataset({
861+
'temp': (['x', 'y'], [[1, 2], [3, 4]])
862+
}, coords={'x': [0, 2], 'y': [0, 1]})
863+
assert comparator(u, v)
864+
assert not comparator(u, w)
865+
866+
# Test Dataset with attributes
867+
x = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
868+
y = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
869+
z = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'model'})
870+
assert comparator(x, y)
871+
assert not comparator(x, z)
872+
873+
# Test Dataset with different variables
874+
aa = xr.Dataset({'temp': (['x'], [1, 2, 3])})
875+
bb = xr.Dataset({'temp': (['x'], [1, 2, 3])})
876+
cc = xr.Dataset({'pressure': (['x'], [1, 2, 3])})
877+
assert comparator(aa, bb)
878+
assert not comparator(aa, cc)
879+
880+
# Test empty Dataset
881+
dd = xr.Dataset()
882+
ee = xr.Dataset()
883+
assert comparator(dd, ee)
884+
885+
# Test DataArray with different shapes
886+
ff = xr.DataArray([1, 2, 3], dims=['x'])
887+
gg = xr.DataArray([[1, 2, 3]], dims=['x', 'y'])
888+
assert not comparator(ff, gg)
889+
890+
# Test DataArray with different data types
891+
# Note: xarray.identical() considers int and float arrays with same values as identical
892+
hh = xr.DataArray(np.array([1, 2, 3], dtype='int32'), dims=['x'])
893+
ii = xr.DataArray(np.array([1, 2, 3], dtype='int64'), dims=['x'])
894+
# xarray is permissive with dtype comparisons, treats these as identical
895+
assert comparator(hh, ii)
896+
897+
# Test DataArray with infinity
898+
jj = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
899+
kk = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
900+
ll = xr.DataArray([1.0, -np.inf, 3.0], dims=['x'])
901+
assert comparator(jj, kk)
902+
assert not comparator(jj, ll)
903+
904+
# Test Dataset vs DataArray (different types)
905+
mm = xr.DataArray([1, 2, 3], dims=['x'])
906+
nn = xr.Dataset({'data': (['x'], [1, 2, 3])})
907+
assert not comparator(mm, nn)
908+
909+
790910
def test_returns():
791911
a = Success(5)
792912
b = Success(5)

0 commit comments

Comments
 (0)