Skip to content

Commit 36fd051

Browse files
authored
Merge branch 'main' into chore/asyncio-optimization
2 parents e07b5b4 + f296a0f commit 36fd051

File tree

4 files changed

+179
-2
lines changed

4 files changed

+179
-2
lines changed

codeflash/tracing/tracing_new_process.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
import time
1515
from collections import defaultdict
16+
from importlib.util import find_spec
1617
from pathlib import Path
1718
from typing import TYPE_CHECKING, Any, Callable, ClassVar
1819

@@ -47,6 +48,17 @@ def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None:
4748
self.f_locals: dict = {}
4849

4950

51+
def patch_ap_scheduler() -> None:
52+
if find_spec("apscheduler"):
53+
import apscheduler.schedulers.background as bg
54+
import apscheduler.schedulers.blocking as bb
55+
from apscheduler.schedulers import base
56+
57+
bg.BackgroundScheduler.start = lambda _, *_a, **_k: None
58+
bb.BlockingScheduler.start = lambda _, *_a, **_k: None
59+
base.BaseScheduler.add_job = lambda _, *_a, **_k: None
60+
61+
5062
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
5163
class Tracer:
5264
"""Use this class as a 'with' context manager to trace a function call.
@@ -820,6 +832,7 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An
820832
if __name__ == "__main__":
821833
args_dict = json.loads(sys.argv[-1])
822834
sys.argv = sys.argv[1:-1]
835+
patch_ap_scheduler()
823836
if args_dict["module"]:
824837
import runpy
825838

codeflash/verification/comparator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,27 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
235235
):
236236
return orig == new
237237

238+
if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
239+
orig_dict = {}
240+
new_dict = {}
241+
242+
for attr in orig.__attrs_attrs__:
243+
if attr.eq:
244+
attr_name = attr.name
245+
orig_dict[attr_name] = getattr(orig, attr_name, None)
246+
new_dict[attr_name] = getattr(new, attr_name, None)
247+
248+
if superset_obj:
249+
new_attrs_dict = {}
250+
for attr in new.__attrs_attrs__:
251+
if attr.eq:
252+
attr_name = attr.name
253+
new_attrs_dict[attr_name] = getattr(new, attr_name, None)
254+
return all(
255+
k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items()
256+
)
257+
return comparator(orig_dict, new_dict, superset_obj)
258+
238259
# re.Pattern can be made better by DFA Minimization and then comparing
239260
if isinstance(
240261
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)

codeflash/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# These version placeholders will be replaced by uv-dynamic-versioning during build.
2-
__version__ = "0.17.0"
2+
__version__ = "0.16.7.post46.dev0+444ff121"

tests/test_comparator.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1502,4 +1502,147 @@ def test_collections() -> None:
15021502
d = "hello"
15031503
assert comparator(a, b)
15041504
assert not comparator(a, c)
1505-
assert not comparator(a, d)
1505+
assert not comparator(a, d)
1506+
1507+
1508+
def test_attrs():
1509+
try:
1510+
import attrs # type: ignore
1511+
except ImportError:
1512+
pytest.skip()
1513+
1514+
@attrs.define
1515+
class Person:
1516+
name: str
1517+
age: int = 10
1518+
1519+
a = Person("Alice", 25)
1520+
b = Person("Alice", 25)
1521+
c = Person("Bob", 25)
1522+
d = Person("Alice", 30)
1523+
assert comparator(a, b)
1524+
assert not comparator(a, c)
1525+
assert not comparator(a, d)
1526+
1527+
@attrs.frozen
1528+
class Point:
1529+
x: int
1530+
y: int
1531+
1532+
p1 = Point(1, 2)
1533+
p2 = Point(1, 2)
1534+
p3 = Point(2, 3)
1535+
assert comparator(p1, p2)
1536+
assert not comparator(p1, p3)
1537+
1538+
@attrs.define(slots=True)
1539+
class Vehicle:
1540+
brand: str
1541+
model: str
1542+
year: int = 2020
1543+
1544+
v1 = Vehicle("Toyota", "Camry", 2021)
1545+
v2 = Vehicle("Toyota", "Camry", 2021)
1546+
v3 = Vehicle("Honda", "Civic", 2021)
1547+
assert comparator(v1, v2)
1548+
assert not comparator(v1, v3)
1549+
1550+
@attrs.define
1551+
class ComplexClass:
1552+
public_field: str
1553+
private_field: str = attrs.field(repr=False)
1554+
non_eq_field: int = attrs.field(eq=False, default=0)
1555+
computed: str = attrs.field(init=False, eq=True)
1556+
1557+
def __attrs_post_init__(self):
1558+
self.computed = f"{self.public_field}_{self.private_field}"
1559+
1560+
c1 = ComplexClass("test", "secret")
1561+
c2 = ComplexClass("test", "secret")
1562+
c3 = ComplexClass("different", "secret")
1563+
1564+
c1.non_eq_field = 100
1565+
c2.non_eq_field = 200
1566+
1567+
assert comparator(c1, c2)
1568+
assert not comparator(c1, c3)
1569+
1570+
@attrs.define
1571+
class Address:
1572+
street: str
1573+
city: str
1574+
1575+
@attrs.define
1576+
class PersonWithAddress:
1577+
name: str
1578+
address: Address
1579+
1580+
addr1 = Address("123 Main St", "Anytown")
1581+
addr2 = Address("123 Main St", "Anytown")
1582+
addr3 = Address("456 Oak Ave", "Anytown")
1583+
1584+
person1 = PersonWithAddress("John", addr1)
1585+
person2 = PersonWithAddress("John", addr2)
1586+
person3 = PersonWithAddress("John", addr3)
1587+
1588+
assert comparator(person1, person2)
1589+
assert not comparator(person1, person3)
1590+
1591+
@attrs.define
1592+
class Container:
1593+
items: list
1594+
metadata: dict
1595+
1596+
cont1 = Container([1, 2, 3], {"type": "numbers"})
1597+
cont2 = Container([1, 2, 3], {"type": "numbers"})
1598+
cont3 = Container([1, 2, 4], {"type": "numbers"})
1599+
1600+
assert comparator(cont1, cont2)
1601+
assert not comparator(cont1, cont3)
1602+
1603+
@attrs.define
1604+
class BaseClass:
1605+
name: str
1606+
value: int
1607+
1608+
@attrs.define
1609+
class ExtendedClass:
1610+
name: str
1611+
value: int
1612+
extra_field: str = "default"
1613+
1614+
base = BaseClass("test", 42)
1615+
extended = ExtendedClass("test", 42, "extra")
1616+
1617+
assert not comparator(base, extended)
1618+
1619+
@attrs.define
1620+
class WithNonEqFields:
1621+
name: str
1622+
timestamp: float = attrs.field(eq=False) # Should be ignored
1623+
debug_info: str = attrs.field(eq=False, default="debug")
1624+
1625+
obj1 = WithNonEqFields("test", 1000.0, "info1")
1626+
obj2 = WithNonEqFields("test", 9999.0, "info2") # Different non-eq fields
1627+
obj3 = WithNonEqFields("different", 1000.0, "info1")
1628+
1629+
assert comparator(obj1, obj2) # Should be equal despite different timestamp/debug_info
1630+
assert not comparator(obj1, obj3) # Should be different due to name
1631+
@attrs.define
1632+
class MinimalClass:
1633+
name: str
1634+
value: int
1635+
1636+
@attrs.define
1637+
class ExtendedClass:
1638+
name: str
1639+
value: int
1640+
extra_field: str = "default"
1641+
metadata: dict = attrs.field(factory=dict)
1642+
timestamp: float = attrs.field(eq=False, default=0.0) # This should be ignored
1643+
1644+
minimal = MinimalClass("test", 42)
1645+
extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0)
1646+
1647+
assert not comparator(minimal, extended)
1648+

0 commit comments

Comments
 (0)