11from __future__ import annotations
22
33import ast
4+ import platform
45from dataclasses import dataclass
56from pathlib import Path
67from typing import TYPE_CHECKING
@@ -318,6 +319,7 @@ def iter_ast_calls(node): # noqa: ANN202, ANN001
318319 return return_statement
319320
320321 def visit_ClassDef (self , node : ast .ClassDef ) -> ast .ClassDef :
322+ # TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
321323 for inner_node in ast .walk (node ):
322324 if isinstance (inner_node , ast .FunctionDef ):
323325 self .visit_FunctionDef (inner_node , node .name )
@@ -327,6 +329,17 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
327329 def visit_FunctionDef (self , node : ast .FunctionDef , test_class_name : str | None = None ) -> ast .FunctionDef :
328330 if node .name .startswith ("test_" ):
329331 did_update = False
332+ if self .test_framework == "unittest" and platform .system () != "Windows" :
333+ # Only add timeout decorator on non-Windows platforms
334+ # Windows doesn't support SIGALRM signal required by timeout_decorator
335+
336+ node .decorator_list .append (
337+ ast .Call (
338+ func = ast .Name (id = "timeout_decorator.timeout" , ctx = ast .Load ()),
339+ args = [ast .Constant (value = 15 )],
340+ keywords = [],
341+ )
342+ )
330343 i = len (node .body ) - 1
331344 while i >= 0 :
332345 line_node = node .body [i ]
@@ -492,6 +505,25 @@ def __init__(
492505 self .class_name = function .top_level_parent_name
493506
494507 def visit_ClassDef (self , node : ast .ClassDef ) -> ast .ClassDef :
508+ # Add timeout decorator for unittest test classes if needed
509+ if self .test_framework == "unittest" :
510+ timeout_decorator = ast .Call (
511+ func = ast .Name (id = "timeout_decorator.timeout" , ctx = ast .Load ()),
512+ args = [ast .Constant (value = 15 )],
513+ keywords = [],
514+ )
515+ for item in node .body :
516+ if (
517+ isinstance (item , ast .FunctionDef )
518+ and item .name .startswith ("test_" )
519+ and not any (
520+ isinstance (d , ast .Call )
521+ and isinstance (d .func , ast .Name )
522+ and d .func .id == "timeout_decorator.timeout"
523+ for d in item .decorator_list
524+ )
525+ ):
526+ item .decorator_list .append (timeout_decorator )
495527 return self .generic_visit (node )
496528
497529 def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ) -> ast .AsyncFunctionDef :
@@ -510,6 +542,25 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
510542 def _process_test_function (
511543 self , node : ast .AsyncFunctionDef | ast .FunctionDef
512544 ) -> ast .AsyncFunctionDef | ast .FunctionDef :
545+ # Optimize the search for decorator presence
546+ if self .test_framework == "unittest" :
547+ found_timeout = False
548+ for d in node .decorator_list :
549+ # Avoid isinstance(d.func, ast.Name) if d is not ast.Call
550+ if isinstance (d , ast .Call ):
551+ f = d .func
552+ # Avoid attribute lookup if f is not ast.Name
553+ if isinstance (f , ast .Name ) and f .id == "timeout_decorator.timeout" :
554+ found_timeout = True
555+ break
556+ if not found_timeout :
557+ timeout_decorator = ast .Call (
558+ func = ast .Name (id = "timeout_decorator.timeout" , ctx = ast .Load ()),
559+ args = [ast .Constant (value = 15 )],
560+ keywords = [],
561+ )
562+ node .decorator_list .append (timeout_decorator )
563+
513564 # Initialize counter for this test function
514565 if node .name not in self .async_call_counter :
515566 self .async_call_counter [node .name ] = 0
@@ -664,6 +715,8 @@ def inject_async_profiling_into_existing_test(
664715
665716 # Add necessary imports
666717 new_imports = [ast .Import (names = [ast .alias (name = "os" )])]
718+ if test_framework == "unittest" :
719+ new_imports .append (ast .Import (names = [ast .alias (name = "timeout_decorator" )]))
667720
668721 tree .body = [* new_imports , * tree .body ]
669722 return True , sort_imports (ast .unparse (tree ), float_to_top = True )
@@ -709,6 +762,8 @@ def inject_profiling_into_existing_test(
709762 ast .Import (names = [ast .alias (name = "dill" , asname = "pickle" )]),
710763 ]
711764 )
765+ if test_framework == "unittest" and platform .system () != "Windows" :
766+ new_imports .append (ast .Import (names = [ast .alias (name = "timeout_decorator" )]))
712767 additional_functions = [create_wrapper_function (mode )]
713768
714769 tree .body = [* new_imports , * additional_functions , * tree .body ]
0 commit comments