Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import Counter, defaultdict
from functools import lru_cache
from typing import TYPE_CHECKING

import libcst as cst
Expand All @@ -13,6 +14,7 @@

if TYPE_CHECKING:
from collections.abc import Iterator

import enum
import re
import sys
Expand All @@ -23,11 +25,13 @@
from typing import Annotated, Optional, cast

from jedi.api.classes import Name
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
from pydantic import (AfterValidator, BaseModel, ConfigDict, PrivateAttr,
ValidationError)
from pydantic.dataclasses import dataclass

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code
from codeflash.code_utils.code_utils import (module_name_from_file_path,
validate_python_code)
from codeflash.code_utils.env_utils import is_end_to_end
from codeflash.verification.comparator import comparator

Expand Down Expand Up @@ -513,23 +517,22 @@ def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Option
return None

def get_src_code(self, test_path: Path) -> Optional[str]:
if not test_path.exists():
return None
test_src = test_path.read_text(encoding="utf-8")
module_node = cst.parse_module(test_src)

if self.test_class_name:
for stmt in module_node.body:
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
func_node = self.find_func_in_class(stmt, self.test_function_name)
if func_node:
return module_node.code_for_node(func_node).strip()
# class not found
module_node = self._parse_module_by_path(str(test_path))
if module_node is None:
return None

test_func_name = self.test_function_name
test_class_name = self.test_class_name
found_func = None

# Otherwise, look for a top level function
for stmt in module_node.body:
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name:
if test_class_name is not None and isinstance(stmt, cst.ClassDef) and stmt.name.value == test_class_name:
found_func = self.find_func_in_class(stmt, test_func_name)
if found_func:
return module_node.code_for_node(found_func).strip()
return None # Class found but function not found
if test_class_name is None and isinstance(stmt, cst.FunctionDef) and stmt.name.value == test_func_name:
return module_node.code_for_node(stmt).strip()
return None

Expand All @@ -552,6 +555,17 @@ def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId
iteration_id=iteration_id if iteration_id else components[3],
)

# All attribute definitions are preserved

@staticmethod
@lru_cache(maxsize=32)
def _parse_module_by_path(test_path_str: str) -> Optional[cst.Module]:
    """Parse the test file at ``test_path_str`` into a libcst ``Module``.

    Returns ``None`` when the file does not exist on disk.  Results are
    memoized with ``lru_cache`` (at most 32 parsed modules retained), so
    repeated lookups of the same test file skip re-reading and re-parsing.

    NOTE(review): the cache key is the path *string* only, not the file
    contents or mtime — if a test file is rewritten on disk between calls,
    a stale parse is returned.  Confirm callers never re-read a modified
    test file within one cache lifetime.
    """
    path = Path(test_path_str)
    if not path.exists():
        return None
    test_src = path.read_text(encoding="utf-8")
    return cst.parse_module(test_src)


@dataclass(frozen=True)
class FunctionTestInvocation:
Expand Down Expand Up @@ -631,7 +645,8 @@ def get_all_ids(self) -> set[InvocationId]:
return {test_result.id for test_result in self.test_results}

def get_all_unique_invocation_loop_ids(self) -> set[str]:
return {test_result.unique_invocation_loop_id for test_result in self.test_results}
# generator expression for memory efficiency
return set(tr.unique_invocation_loop_id for tr in self.test_results)

def number_of_loops(self) -> int:
if not self.test_results:
Expand Down
127 changes: 70 additions & 57 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import sentry_sdk

from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
from codeflash.picklepatch.pickle_placeholder import \
PicklePlaceholderAccessError

HAS_NUMPY = find_spec("numpy") is not None
HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None
Expand All @@ -34,11 +35,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
# distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
return False
if isinstance(orig, (list, tuple, deque, ChainMap)):
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

# Cheap, common types first
if isinstance(
orig,
(
Expand All @@ -65,6 +63,14 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if math.isnan(orig) and math.isnan(new):
return True
return math.isclose(orig, new)
if isinstance(orig, (list, tuple, deque, ChainMap)):
if len(orig) != len(new):
return False
for elem1, elem2 in zip(orig, new):
if not comparator(elem1, elem2, superset_obj):
return False
return True

if isinstance(orig, BaseException):
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
Expand All @@ -78,15 +84,16 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
return comparator(orig_dict, new_dict, superset_obj)

# JAX, XARRAY, NUMPY, PANDAS, TORCH modules imported once per function call if needed
np = None
pandas = None
if HAS_JAX:
import jax # type: ignore # noqa: PGH003
import jax.numpy as jnp # type: ignore # noqa: PGH003

# Handle JAX arrays first to avoid boolean context errors in other conditions
if isinstance(orig, jax.Array):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
if orig.dtype != new.dtype or orig.shape != new.shape:
return False
return bool(jnp.allclose(orig, new, equal_nan=True))

Expand All @@ -101,11 +108,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
import sqlalchemy # type: ignore # noqa: PGH003

try:
insp = sqlalchemy.inspection.inspect(orig)
insp = sqlalchemy.inspection.inspect(new) # noqa: F841
sqlalchemy.inspection.inspect(orig)
sqlalchemy.inspection.inspect(new)
orig_keys = orig.__dict__
new_keys = new.__dict__
for key in list(orig_keys.keys()):
for key in orig_keys:
if key.startswith("_"):
continue
if key not in new_keys or not comparator(orig_keys[key], new_keys[key], superset_obj):
Expand All @@ -117,32 +124,36 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001

if HAS_SCIPY:
import scipy # type: ignore # noqa: PGH003
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):

# Dict support/Sparse
is_sparse = HAS_SCIPY and "scipy" in globals() and isinstance(orig, scipy.sparse.spmatrix)
if isinstance(orig, dict) and not is_sparse:
if superset_obj:
return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
for k, v in orig.items():
if k not in new or not comparator(v, new[k], superset_obj):
return False
return True
# Strict equality check
if len(orig) != len(new):
return False
for key in orig:
if key not in new:
return False
if not comparator(orig[key], new[key], superset_obj):
for k, v in orig.items():
if k not in new or not comparator(v, new[k], superset_obj):
return False
return True

if HAS_NUMPY:
import numpy as np # type: ignore # noqa: PGH003

if isinstance(orig, np.ndarray):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
if orig.dtype != new.dtype or orig.shape != new.shape:
return False
try:
return np.allclose(orig, new, equal_nan=True)
except Exception:
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
for x, y in zip(orig, new):
if not comparator(x, y, superset_obj):
return False
return True

if isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)
Expand All @@ -153,12 +164,24 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if isinstance(orig, np.void):
if orig.dtype != new.dtype:
return False
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
for field in orig.dtype.fields:
if not comparator(orig[field], new[field], superset_obj):
return False
return True
# nan/inf for numpy base types
try:
if np.isnan(orig):
return np.isnan(new)
except Exception:
pass
try:
if np.isinf(orig):
return np.isinf(new)
except Exception:
pass

if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
if orig.dtype != new.dtype:
return False
if orig.get_shape() != new.get_shape():
if is_sparse:
if orig.dtype != new.dtype or orig.get_shape() != new.get_shape():
return False
return (orig != new).nnz == 0

Expand All @@ -176,35 +199,23 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return True

if isinstance(orig, array.array):
if orig.typecode != new.typecode:
return False
if len(orig) != len(new):
if orig.typecode != new.typecode or len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

# This should be at the end of all numpy checking
try:
if HAS_NUMPY and np.isnan(orig):
return np.isnan(new)
except Exception: # noqa: S110
pass
try:
if HAS_NUMPY and np.isinf(orig):
return np.isinf(new)
except Exception: # noqa: S110
pass
for elem1, elem2 in zip(orig, new):
if not comparator(elem1, elem2, superset_obj):
return False
return True

if HAS_TORCH:
import torch # type: ignore # noqa: PGH003

if isinstance(orig, torch.Tensor):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
if orig.requires_grad != new.requires_grad:
return False
if orig.device != new.device:
if (
orig.dtype != new.dtype
or orig.shape != new.shape
or orig.requires_grad != new.requires_grad
or orig.device != new.device
):
return False
return torch.allclose(orig, new, equal_nan=True)

Expand Down Expand Up @@ -242,12 +253,12 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if attr.eq:
attr_name = attr.name
new_attrs_dict[attr_name] = getattr(new, attr_name, None)
return all(
k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items()
)
for k, v in orig_dict.items():
if k not in new_attrs_dict or not comparator(v, new_attrs_dict[k], superset_obj):
return False
return True
return comparator(orig_dict, new_dict, superset_obj)

# re.Pattern can be made better by DFA Minimization and then comparing
if isinstance(
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)
):
Expand Down Expand Up @@ -275,8 +286,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")}

if superset_obj:
# allow new object to be a superset of the original object
return all(k in new_keys and comparator(v, new_keys[k], superset_obj) for k, v in orig_keys.items())
for k, v in orig_keys.items():
if k not in new_keys or not comparator(v, new_keys[k], superset_obj):
return False
return True

if isinstance(orig, ast.AST):
orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}
Expand Down
Loading
Loading