Skip to content

Commit 3ba8be6

Browse files
authored
Merge pull request #787 from codeflash-ai/import-time-optimization
small Import time optimization
2 parents aa398f7 + d58134f commit 3ba8be6

File tree

3 files changed

+113
-117
lines changed

3 files changed

+113
-117
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
import enum
56
import hashlib
67
import os
78
import pickle
@@ -11,12 +12,11 @@
1112
import unittest
1213
from collections import defaultdict
1314
from pathlib import Path
14-
from typing import TYPE_CHECKING, Callable, Optional
15+
from typing import TYPE_CHECKING, Callable, Optional, final
1516

1617
if TYPE_CHECKING:
1718
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1819

19-
import pytest
2020
from pydantic.dataclasses import dataclass
2121
from rich.panel import Panel
2222
from rich.text import Text
@@ -35,6 +35,22 @@
3535
from codeflash.verification.verification_utils import TestConfig
3636

3737

38+
@final
39+
class PytestExitCode(enum.IntEnum): # don't need to import entire pytest just for this
40+
#: Tests passed.
41+
OK = 0
42+
#: Tests failed.
43+
TESTS_FAILED = 1
44+
#: pytest was interrupted.
45+
INTERRUPTED = 2
46+
#: An internal error got in the way.
47+
INTERNAL_ERROR = 3
48+
#: pytest was misused.
49+
USAGE_ERROR = 4
50+
#: pytest couldn't find tests.
51+
NO_TESTS_COLLECTED = 5
52+
53+
3854
@dataclass(frozen=True)
3955
class TestFunction:
4056
function_name: str
@@ -412,15 +428,15 @@ def discover_tests_pytest(
412428
error_section = match.group(1) if match else result.stdout
413429

414430
logger.warning(
415-
f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}"
431+
f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}"
416432
)
417433
if "ModuleNotFoundError" in result.stdout:
418434
match = ImportErrorPattern.search(result.stdout).group()
419435
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
420436
console.print(panel)
421437

422438
elif 0 <= exitcode <= 5:
423-
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}")
439+
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}")
424440
else:
425441
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
426442
console.rule()

codeflash/verification/comparator.py

Lines changed: 92 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# ruff: noqa: PGH003
21
import array
32
import ast
43
import datetime
@@ -8,65 +7,22 @@
87
import re
98
import types
109
from collections import ChainMap, OrderedDict, deque
10+
from importlib.util import find_spec
1111
from typing import Any
1212

1313
import sentry_sdk
1414

1515
from codeflash.cli_cmds.console import logger
1616
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
1717

18-
try:
19-
import numpy as np
20-
21-
HAS_NUMPY = True
22-
except ImportError:
23-
HAS_NUMPY = False
24-
try:
25-
import sqlalchemy # type: ignore
26-
27-
HAS_SQLALCHEMY = True
28-
except ImportError:
29-
HAS_SQLALCHEMY = False
30-
try:
31-
import scipy # type: ignore
32-
33-
HAS_SCIPY = True
34-
except ImportError:
35-
HAS_SCIPY = False
36-
37-
try:
38-
import pandas # type: ignore # noqa: ICN001
39-
40-
HAS_PANDAS = True
41-
except ImportError:
42-
HAS_PANDAS = False
43-
44-
try:
45-
import pyrsistent # type: ignore
46-
47-
HAS_PYRSISTENT = True
48-
except ImportError:
49-
HAS_PYRSISTENT = False
50-
try:
51-
import torch # type: ignore
52-
53-
HAS_TORCH = True
54-
except ImportError:
55-
HAS_TORCH = False
56-
try:
57-
import jax # type: ignore
58-
import jax.numpy as jnp # type: ignore
59-
60-
HAS_JAX = True
61-
except ImportError:
62-
HAS_JAX = False
63-
64-
try:
65-
import xarray # type: ignore
66-
67-
HAS_XARRAY = True
68-
except ImportError:
69-
HAS_XARRAY = False
18+
HAS_NUMPY = find_spec("numpy") is not None
19+
HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None
20+
HAS_SCIPY = find_spec("scipy") is not None
21+
HAS_PANDAS = find_spec("pandas") is not None
22+
HAS_PYRSISTENT = find_spec("pyrsistent") is not None
23+
HAS_TORCH = find_spec("torch") is not None
24+
HAS_JAX = find_spec("jax") is not None
25+
HAS_XARRAY = find_spec("xarray") is not None
7026

7127

7228
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -122,19 +78,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
12278
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
12379
return comparator(orig_dict, new_dict, superset_obj)
12480

125-
# Handle JAX arrays first to avoid boolean context errors in other conditions
126-
if HAS_JAX and isinstance(orig, jax.Array):
127-
if orig.dtype != new.dtype:
128-
return False
129-
if orig.shape != new.shape:
130-
return False
131-
return bool(jnp.allclose(orig, new, equal_nan=True))
81+
if HAS_JAX:
82+
import jax # type: ignore # noqa: PGH003
83+
import jax.numpy as jnp # type: ignore # noqa: PGH003
84+
85+
# Handle JAX arrays first to avoid boolean context errors in other conditions
86+
if isinstance(orig, jax.Array):
87+
if orig.dtype != new.dtype:
88+
return False
89+
if orig.shape != new.shape:
90+
return False
91+
return bool(jnp.allclose(orig, new, equal_nan=True))
13292

13393
# Handle xarray objects before numpy to avoid boolean context errors
134-
if HAS_XARRAY and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
135-
return orig.identical(new)
94+
if HAS_XARRAY:
95+
import xarray # type: ignore # noqa: PGH003
96+
97+
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
98+
return orig.identical(new)
13699

137100
if HAS_SQLALCHEMY:
101+
import sqlalchemy # type: ignore # noqa: PGH003
102+
138103
try:
139104
insp = sqlalchemy.inspection.inspect(orig)
140105
insp = sqlalchemy.inspection.inspect(new) # noqa: F841
@@ -149,6 +114,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
149114

150115
except sqlalchemy.exc.NoInspectionAvailable:
151116
pass
117+
118+
if HAS_SCIPY:
119+
import scipy # type: ignore # noqa: PGH003
152120
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
153121
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
154122
if superset_obj:
@@ -162,27 +130,30 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
162130
return False
163131
return True
164132

165-
if HAS_NUMPY and isinstance(orig, np.ndarray):
166-
if orig.dtype != new.dtype:
167-
return False
168-
if orig.shape != new.shape:
169-
return False
170-
try:
171-
return np.allclose(orig, new, equal_nan=True)
172-
except Exception:
173-
# fails at "ufunc 'isfinite' not supported for the input types"
174-
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
133+
if HAS_NUMPY:
134+
import numpy as np # type: ignore # noqa: PGH003
175135

176-
if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)):
177-
return np.isclose(orig, new)
136+
if isinstance(orig, np.ndarray):
137+
if orig.dtype != new.dtype:
138+
return False
139+
if orig.shape != new.shape:
140+
return False
141+
try:
142+
return np.allclose(orig, new, equal_nan=True)
143+
except Exception:
144+
# fails at "ufunc 'isfinite' not supported for the input types"
145+
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
178146

179-
if HAS_NUMPY and isinstance(orig, (np.integer, np.bool_, np.byte)):
180-
return orig == new
147+
if isinstance(orig, (np.floating, np.complex64, np.complex128)):
148+
return np.isclose(orig, new)
181149

182-
if HAS_NUMPY and isinstance(orig, np.void):
183-
if orig.dtype != new.dtype:
184-
return False
185-
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
150+
if isinstance(orig, (np.integer, np.bool_, np.byte)):
151+
return orig == new
152+
153+
if isinstance(orig, np.void):
154+
if orig.dtype != new.dtype:
155+
return False
156+
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
186157

187158
if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
188159
if orig.dtype != new.dtype:
@@ -191,15 +162,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
191162
return False
192163
return (orig != new).nnz == 0
193164

194-
if HAS_PANDAS and isinstance(
195-
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
196-
):
197-
return orig.equals(new)
165+
if HAS_PANDAS:
166+
import pandas # type: ignore # noqa: ICN001, PGH003
198167

199-
if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
200-
return orig == new
201-
if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new):
202-
return True
168+
if isinstance(
169+
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
170+
):
171+
return orig.equals(new)
172+
173+
if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
174+
return orig == new
175+
if pandas.isna(orig) and pandas.isna(new):
176+
return True
203177

204178
if isinstance(orig, array.array):
205179
if orig.typecode != new.typecode:
@@ -220,31 +194,37 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
220194
except Exception: # noqa: S110
221195
pass
222196

223-
if HAS_TORCH and isinstance(orig, torch.Tensor):
224-
if orig.dtype != new.dtype:
225-
return False
226-
if orig.shape != new.shape:
227-
return False
228-
if orig.requires_grad != new.requires_grad:
229-
return False
230-
if orig.device != new.device:
231-
return False
232-
return torch.allclose(orig, new, equal_nan=True)
197+
if HAS_TORCH:
198+
import torch # type: ignore # noqa: PGH003
233199

234-
if HAS_PYRSISTENT and isinstance(
235-
orig,
236-
(
237-
pyrsistent.PMap,
238-
pyrsistent.PVector,
239-
pyrsistent.PSet,
240-
pyrsistent.PRecord,
241-
pyrsistent.PClass,
242-
pyrsistent.PBag,
243-
pyrsistent.PList,
244-
pyrsistent.PDeque,
245-
),
246-
):
247-
return orig == new
200+
if isinstance(orig, torch.Tensor):
201+
if orig.dtype != new.dtype:
202+
return False
203+
if orig.shape != new.shape:
204+
return False
205+
if orig.requires_grad != new.requires_grad:
206+
return False
207+
if orig.device != new.device:
208+
return False
209+
return torch.allclose(orig, new, equal_nan=True)
210+
211+
if HAS_PYRSISTENT:
212+
import pyrsistent # type: ignore # noqa: PGH003
213+
214+
if isinstance(
215+
orig,
216+
(
217+
pyrsistent.PMap,
218+
pyrsistent.PVector,
219+
pyrsistent.PSet,
220+
pyrsistent.PRecord,
221+
pyrsistent.PClass,
222+
pyrsistent.PBag,
223+
pyrsistent.PList,
224+
pyrsistent.PDeque,
225+
),
226+
):
227+
return orig == new
248228

249229
if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
250230
orig_dict = {}

tests/scripts/end_to_end_test_bubblesort_unittest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def run_test(expected_improvement_pct: int) -> bool:
88
config = TestConfig(
9-
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.40
9+
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.30
1010
)
1111
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
1212
return run_codeflash_command(cwd, config, expected_improvement_pct)

0 commit comments

Comments
 (0)