Skip to content

Commit 8753e54

Browse files
authored
Merge pull request #705 from codeflash-ai/deque-comparator
Deque Comparator
2 parents 2a1096b + b7a52bc commit 8753e54

File tree

4 files changed

+114
-2
lines changed

4 files changed

+114
-2
lines changed

codeflash/code_utils/git_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from functools import cache
1010
from io import StringIO
1111
from pathlib import Path
12-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, Optional
1313

1414
import git
1515
from rich.prompt import Confirm

codeflash/lsp/beta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class ProvideApiKeyParams:
5353
class OnPatchAppliedParams:
5454
patch_id: str
5555

56+
5657
@dataclass
5758
class OptimizableFunctionsInCommitParams:
5859
commit_hash: str

codeflash/verification/comparator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import math
88
import re
99
import types
10+
from collections import ChainMap, OrderedDict, deque
1011
from typing import Any
1112

1213
import sentry_sdk
@@ -70,7 +71,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
7071
# distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
7172
if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
7273
return False
73-
if isinstance(orig, (list, tuple)):
74+
if isinstance(orig, (list, tuple, deque, ChainMap)):
7475
if len(orig) != len(new):
7576
return False
7677
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
@@ -93,6 +94,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
9394
enum.Enum,
9495
type,
9596
range,
97+
OrderedDict,
9698
),
9799
):
98100
return orig == new

tests/test_comparator.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import datetime
55
import decimal
66
import re
7+
from collections import ChainMap, Counter, UserDict, UserList, UserString, defaultdict, deque, namedtuple, OrderedDict
8+
79
import sys
810
import uuid
911
from enum import Enum, Flag, IntFlag, auto
@@ -1394,3 +1396,110 @@ def raise_specific_exception():
13941396
module2 = ast.parse(code2)
13951397

13961398
assert not comparator(module7, module2)
1399+
1400+
def test_collections() -> None:
1401+
# Deque
1402+
a = deque([1, 2, 3])
1403+
b = deque([1, 2, 3])
1404+
c = deque([1, 2, 4])
1405+
d = deque([1, 2])
1406+
e = [1, 2, 3]
1407+
f = deque([1, 2, 3], maxlen=5)
1408+
assert comparator(a, b)
1409+
assert comparator(a, f) # same elements, different maxlen is ok
1410+
assert not comparator(a, c)
1411+
assert not comparator(a, d)
1412+
assert not comparator(a, e)
1413+
1414+
g = deque([{"a": 1}, {"b": 2}])
1415+
h = deque([{"a": 1}, {"b": 2}])
1416+
i = deque([{"a": 1}, {"b": 3}])
1417+
assert comparator(g, h)
1418+
assert not comparator(g, i)
1419+
1420+
empty_deque1 = deque()
1421+
empty_deque2 = deque()
1422+
assert comparator(empty_deque1, empty_deque2)
1423+
assert not comparator(empty_deque1, a)
1424+
1425+
# namedtuple
1426+
Point = namedtuple('Point', ['x', 'y'])
1427+
a = Point(x=1, y=2)
1428+
b = Point(x=1, y=2)
1429+
c = Point(x=1, y=3)
1430+
assert comparator(a, b)
1431+
assert not comparator(a, c)
1432+
1433+
Point2 = namedtuple('Point2', ['x', 'y'])
1434+
d = Point2(x=1, y=2)
1435+
assert not comparator(a, d)
1436+
1437+
e = (1, 2)
1438+
assert not comparator(a, e)
1439+
1440+
# ChainMap
1441+
map1 = {'a': 1, 'b': 2}
1442+
map2 = {'c': 3, 'd': 4}
1443+
a = ChainMap(map1, map2)
1444+
b = ChainMap(map1, map2)
1445+
c = ChainMap(map2, map1)
1446+
d = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
1447+
assert comparator(a, b)
1448+
assert not comparator(a, c)
1449+
assert not comparator(a, d)
1450+
1451+
# Counter
1452+
a = Counter(['a', 'b', 'a', 'c', 'b', 'a'])
1453+
b = Counter({'a': 3, 'b': 2, 'c': 1})
1454+
c = Counter({'a': 3, 'b': 2, 'c': 2})
1455+
d = {'a': 3, 'b': 2, 'c': 1}
1456+
assert comparator(a, b)
1457+
assert not comparator(a, c)
1458+
assert not comparator(a, d)
1459+
1460+
# OrderedDict
1461+
a = OrderedDict([('a', 1), ('b', 2)])
1462+
b = OrderedDict([('a', 1), ('b', 2)])
1463+
c = OrderedDict([('b', 2), ('a', 1)])
1464+
d = {'a': 1, 'b': 2}
1465+
assert comparator(a, b)
1466+
assert not comparator(a, c)
1467+
assert not comparator(a, d)
1468+
1469+
# defaultdict
1470+
a = defaultdict(int, {'a': 1, 'b': 2})
1471+
b = defaultdict(int, {'a': 1, 'b': 2})
1472+
c = defaultdict(list, {'a': 1, 'b': 2})
1473+
d = {'a': 1, 'b': 2}
1474+
e = defaultdict(int, {'a': 1, 'b': 3})
1475+
assert comparator(a, b)
1476+
assert comparator(a, c)
1477+
assert not comparator(a, d)
1478+
assert not comparator(a, e)
1479+
1480+
# UserDict
1481+
a = UserDict({'a': 1, 'b': 2})
1482+
b = UserDict({'a': 1, 'b': 2})
1483+
c = UserDict({'a': 1, 'b': 3})
1484+
d = {'a': 1, 'b': 2}
1485+
assert comparator(a, b)
1486+
assert not comparator(a, c)
1487+
assert not comparator(a, d)
1488+
1489+
# UserList
1490+
a = UserList([1, 2, 3])
1491+
b = UserList([1, 2, 3])
1492+
c = UserList([1, 2, 4])
1493+
d = [1, 2, 3]
1494+
assert comparator(a, b)
1495+
assert not comparator(a, c)
1496+
assert not comparator(a, d)
1497+
1498+
# UserString
1499+
a = UserString("hello")
1500+
b = UserString("hello")
1501+
c = UserString("world")
1502+
d = "hello"
1503+
assert comparator(a, b)
1504+
assert not comparator(a, c)
1505+
assert not comparator(a, d)

0 commit comments

Comments
 (0)