Skip to content

Commit 2b2fd42

Browse files
committed
Revert "first pass"
This reverts commit b507770.
1 parent 1c2fb36 commit 2b2fd42

File tree

9 files changed

+343
-17
lines changed

9 files changed

+343
-17
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ def collect_setup_info() -> CLISetupInfo:
472472
for choice in test_framework_choices:
473473
if choice == "pytest":
474474
framework_choices.append(("🧪 pytest", "pytest"))
475+
elif choice == "unittest":
476+
framework_choices.append(("🐍 unittest", "unittest"))
475477

476478
framework_questions = [
477479
inquirer.List(
@@ -615,6 +617,31 @@ def detect_test_framework_from_config_files(curdir: Path) -> Optional[str]:
615617
return test_framework
616618

617619

620+
def detect_test_framework_from_test_files(tests_root: Path) -> Optional[str]:
621+
test_framework = None
622+
# Check if any python files contain a class that inherits from unittest.TestCase
623+
for filename in tests_root.iterdir():
624+
if filename.suffix == ".py":
625+
with filename.open(encoding="utf8") as file:
626+
contents = file.read()
627+
try:
628+
node = ast.parse(contents)
629+
except SyntaxError:
630+
continue
631+
if any(
632+
isinstance(item, ast.ClassDef)
633+
and any(
634+
(isinstance(base, ast.Attribute) and base.attr == "TestCase")
635+
or (isinstance(base, ast.Name) and base.id == "TestCase")
636+
for base in item.bases
637+
)
638+
for item in node.body
639+
):
640+
test_framework = "unittest"
641+
break
642+
return test_framework
643+
644+
618645
def check_for_toml_or_setup_file() -> str | None:
619646
click.echo()
620647
click.echo("Checking for pyproject.toml or setup.py…\r", nl=False)
@@ -1298,7 +1325,26 @@ def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
12981325
arr[j + 1] = temp
12991326
return arr
13001327
"""
1301-
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
1328+
if args.test_framework == "unittest":
1329+
bubble_sort_test_content = f"""import unittest
1330+
from {os.path.basename(args.module_root)}.bubble_sort import sorter # Keep usage of os.path.basename to avoid pathlib potential incompatibility https://github.com/codeflash-ai/codeflash/pull/1066#discussion_r1801628022
1331+
1332+
class TestBubbleSort(unittest.TestCase):
1333+
def test_sort(self):
1334+
input = [5, 4, 3, 2, 1, 0]
1335+
output = sorter(input)
1336+
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1337+
1338+
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1339+
output = sorter(input)
1340+
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1341+
1342+
input = list(reversed(range(100)))
1343+
output = sorter(input)
1344+
self.assertEqual(output, list(range(100)))
1345+
""" # noqa: PTH119
1346+
elif args.test_framework == "pytest":
1347+
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
13021348
13031349
def test_sort():
13041350
input = [5, 4, 3, 2, 1, 0]

codeflash/code_utils/concolic_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,44 @@ def _transform_assert_line(self, line: str) -> Optional[str]:
2828
expression = expression.rstrip(",;")
2929
return f"{indent}{expression}"
3030

31+
unittest_match = self.unittest_re.match(line)
32+
if unittest_match:
33+
indent, assert_method, args = unittest_match.groups()
34+
35+
if args:
36+
arg_parts = self._split_top_level_args(args)
37+
if arg_parts and arg_parts[0]:
38+
return f"{indent}{arg_parts[0]}"
39+
3140
return None
3241

42+
def _split_top_level_args(self, args_str: str) -> list[str]:
43+
result = []
44+
current = []
45+
depth = 0
46+
47+
for char in args_str:
48+
if char in "([{":
49+
depth += 1
50+
current.append(char)
51+
elif char in ")]}":
52+
depth -= 1
53+
current.append(char)
54+
elif char == "," and depth == 0:
55+
result.append("".join(current).strip())
56+
current = []
57+
else:
58+
current.append(char)
59+
60+
if current:
61+
result.append("".join(current).strip())
62+
63+
return result
64+
3365
def __init__(self) -> None:
3466
# Pre-compiling regular expressions for faster execution
3567
self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$")
68+
self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$")
3669

3770

3871
def clean_concolic_tests(test_suite_code: str) -> str:

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import platform
45
from dataclasses import dataclass
56
from pathlib import Path
67
from typing import TYPE_CHECKING
@@ -318,6 +319,7 @@ def iter_ast_calls(node): # noqa: ANN202, ANN001
318319
return return_statement
319320

320321
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
322+
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
321323
for inner_node in ast.walk(node):
322324
if isinstance(inner_node, ast.FunctionDef):
323325
self.visit_FunctionDef(inner_node, node.name)
@@ -327,6 +329,17 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
327329
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
328330
if node.name.startswith("test_"):
329331
did_update = False
332+
if self.test_framework == "unittest" and platform.system() != "Windows":
333+
# Only add timeout decorator on non-Windows platforms
334+
# Windows doesn't support SIGALRM signal required by timeout_decorator
335+
336+
node.decorator_list.append(
337+
ast.Call(
338+
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
339+
args=[ast.Constant(value=15)],
340+
keywords=[],
341+
)
342+
)
330343
i = len(node.body) - 1
331344
while i >= 0:
332345
line_node = node.body[i]
@@ -492,6 +505,25 @@ def __init__(
492505
self.class_name = function.top_level_parent_name
493506

494507
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
508+
# Add timeout decorator for unittest test classes if needed
509+
if self.test_framework == "unittest":
510+
timeout_decorator = ast.Call(
511+
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
512+
args=[ast.Constant(value=15)],
513+
keywords=[],
514+
)
515+
for item in node.body:
516+
if (
517+
isinstance(item, ast.FunctionDef)
518+
and item.name.startswith("test_")
519+
and not any(
520+
isinstance(d, ast.Call)
521+
and isinstance(d.func, ast.Name)
522+
and d.func.id == "timeout_decorator.timeout"
523+
for d in item.decorator_list
524+
)
525+
):
526+
item.decorator_list.append(timeout_decorator)
495527
return self.generic_visit(node)
496528

497529
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
@@ -510,6 +542,25 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
510542
def _process_test_function(
511543
self, node: ast.AsyncFunctionDef | ast.FunctionDef
512544
) -> ast.AsyncFunctionDef | ast.FunctionDef:
545+
# Optimize the search for decorator presence
546+
if self.test_framework == "unittest":
547+
found_timeout = False
548+
for d in node.decorator_list:
549+
# Avoid isinstance(d.func, ast.Name) if d is not ast.Call
550+
if isinstance(d, ast.Call):
551+
f = d.func
552+
# Avoid attribute lookup if f is not ast.Name
553+
if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
554+
found_timeout = True
555+
break
556+
if not found_timeout:
557+
timeout_decorator = ast.Call(
558+
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
559+
args=[ast.Constant(value=15)],
560+
keywords=[],
561+
)
562+
node.decorator_list.append(timeout_decorator)
563+
513564
# Initialize counter for this test function
514565
if node.name not in self.async_call_counter:
515566
self.async_call_counter[node.name] = 0
@@ -664,6 +715,8 @@ def inject_async_profiling_into_existing_test(
664715

665716
# Add necessary imports
666717
new_imports = [ast.Import(names=[ast.alias(name="os")])]
718+
if test_framework == "unittest":
719+
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
667720

668721
tree.body = [*new_imports, *tree.body]
669722
return True, sort_imports(ast.unparse(tree), float_to_top=True)
@@ -709,6 +762,8 @@ def inject_profiling_into_existing_test(
709762
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
710763
]
711764
)
765+
if test_framework == "unittest" and platform.system() != "Windows":
766+
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
712767
additional_functions = [create_wrapper_function(mode)]
713768

714769
tree.body = [*new_imports, *additional_functions, *tree.body]

0 commit comments

Comments
 (0)