1+ import tempfile
2+ from pathlib import Path
3+
4+ from codeflash .code_utils .code_extractor import add_needed_imports_from_module
5+
6+
7+ def test_add_needed_imports_with_none_aliases ():
8+ source_code = '''
9+ import json
10+ from typing import Dict as MyDict, Optional
11+ from collections import defaultdict
12+ '''
13+
14+ target_code = '''
15+ def target_function():
16+ pass
17+ '''
18+
19+ expected_output = '''
20+ def target_function():
21+ pass
22+ '''
23+
24+ with tempfile .TemporaryDirectory () as temp_dir :
25+ temp_path = Path (temp_dir )
26+ src_path = temp_path / "source.py"
27+ dst_path = temp_path / "target.py"
28+
29+ src_path .write_text (source_code )
30+ dst_path .write_text (target_code )
31+
32+ result = add_needed_imports_from_module (
33+ src_module_code = source_code ,
34+ dst_module_code = target_code ,
35+ src_path = src_path ,
36+ dst_path = dst_path ,
37+ project_root = temp_path
38+ )
39+
40+ assert result .strip () == expected_output .strip ()
41+
42+
43+ def test_add_needed_imports_complex_aliases ():
44+ source_code = '''
45+ import os
46+ import sys as system
47+ from typing import Dict, List as MyList, Optional as Opt
48+ from collections import defaultdict as dd, Counter
49+ from pathlib import Path
50+ '''
51+
52+ target_code = '''
53+ def my_function():
54+ return "test"
55+ '''
56+
57+ expected_output = '''
58+ def my_function():
59+ return "test"
60+ '''
61+
62+ with tempfile .TemporaryDirectory () as temp_dir :
63+ temp_path = Path (temp_dir )
64+ src_path = temp_path / "source.py"
65+ dst_path = temp_path / "target.py"
66+
67+ src_path .write_text (source_code )
68+ dst_path .write_text (target_code )
69+
70+ result = add_needed_imports_from_module (
71+ src_module_code = source_code ,
72+ dst_module_code = target_code ,
73+ src_path = src_path ,
74+ dst_path = dst_path ,
75+ project_root = temp_path
76+ )
77+
78+ assert result .strip () == expected_output .strip ()
79+
80+
81+ def test_add_needed_imports_with_usage ():
82+ source_code = '''
83+ import json
84+ from typing import Dict as MyDict, Optional
85+ from collections import defaultdict
86+
87+ '''
88+
89+ target_code = '''
90+ def target_function():
91+ data = json.loads('{"key": "value"}')
92+ my_dict: MyDict[str, str] = {}
93+ opt_value: Optional[str] = None
94+ dd = defaultdict(list)
95+ return data, my_dict, opt_value, dd
96+ '''
97+
98+ expected_output = '''import json
99+ from typing import Dict as MyDict, Optional
100+ from collections import defaultdict
101+
102+ def target_function():
103+ data = json.loads('{"key": "value"}')
104+ my_dict: MyDict[str, str] = {}
105+ opt_value: Optional[str] = None
106+ dd = defaultdict(list)
107+ return data, my_dict, opt_value, dd
108+ '''
109+
110+ with tempfile .TemporaryDirectory () as temp_dir :
111+ temp_path = Path (temp_dir )
112+ src_path = temp_path / "source.py"
113+ dst_path = temp_path / "target.py"
114+
115+ src_path .write_text (source_code )
116+ dst_path .write_text (target_code )
117+
118+ result = add_needed_imports_from_module (
119+ src_module_code = source_code ,
120+ dst_module_code = target_code ,
121+ src_path = src_path ,
122+ dst_path = dst_path ,
123+ project_root = temp_path
124+ )
125+
126+ # Assert exact expected output
127+ assert result .strip () == expected_output .strip ()
128+
129+
130+ def test_litellm_router_style_imports ():
131+ source_code = '''
132+ import asyncio
133+ import copy
134+ import json
135+ from collections import defaultdict
136+ from typing import Dict, List, Optional, Union
137+ from litellm.types.utils import ModelInfo
138+ from litellm.types.utils import ModelInfo as ModelMapInfo
139+ '''
140+
141+ target_code = '''
142+ def target_function():
143+ """Target function for testing."""
144+ pass
145+ '''
146+
147+ expected_output = '''
148+ def target_function():
149+ """Target function for testing."""
150+ pass
151+ '''
152+
153+ with tempfile .TemporaryDirectory () as temp_dir :
154+ temp_path = Path (temp_dir )
155+ src_path = temp_path / "complex_source.py"
156+ dst_path = temp_path / "target.py"
157+
158+ src_path .write_text (source_code )
159+ dst_path .write_text (target_code )
160+
161+ result = add_needed_imports_from_module (
162+ src_module_code = source_code ,
163+ dst_module_code = target_code ,
164+ src_path = src_path ,
165+ dst_path = dst_path ,
166+ project_root = temp_path
167+ )
168+
169+ assert result .strip () == expected_output .strip ()
170+
171+
172+ def test_edge_case_none_values_in_alias_pairs ():
173+ source_code = '''
174+ from typing import Dict as MyDict, List, Optional as Opt
175+ from collections import defaultdict, Counter as cnt
176+ from pathlib import Path
177+ '''
178+
179+ target_code = '''
180+ def my_test_function():
181+ return "test"
182+ '''
183+
184+ expected_output = '''
185+ def my_test_function():
186+ return "test"
187+ '''
188+
189+ with tempfile .TemporaryDirectory () as temp_dir :
190+ temp_path = Path (temp_dir )
191+ src_path = temp_path / "edge_case_source.py"
192+ dst_path = temp_path / "target.py"
193+
194+ src_path .write_text (source_code )
195+ dst_path .write_text (target_code )
196+
197+ result = add_needed_imports_from_module (
198+ src_module_code = source_code ,
199+ dst_module_code = target_code ,
200+ src_path = src_path ,
201+ dst_path = dst_path ,
202+ project_root = temp_path
203+ )
204+
205+ assert result .strip () == expected_output .strip ()
206+
207+
208+ def test_partial_import_usage ():
209+ source_code = '''
210+ import os
211+ import sys
212+ from typing import Dict, List, Optional
213+ from collections import defaultdict, Counter
214+ '''
215+
216+ target_code = '''
217+ def use_some_imports():
218+ path = os.path.join("a", "b")
219+ my_dict: Dict[str, int] = {}
220+ counter = Counter([1, 2, 3])
221+ return path, my_dict, counter
222+ '''
223+
224+ expected_output = '''import os
225+ from collections import Counter
226+ from typing import Dict
227+
228+ def use_some_imports():
229+ path = os.path.join("a", "b")
230+ my_dict: Dict[str, int] = {}
231+ counter = Counter([1, 2, 3])
232+ return path, my_dict, counter
233+ '''
234+
235+ with tempfile .TemporaryDirectory () as temp_dir :
236+ temp_path = Path (temp_dir )
237+ src_path = temp_path / "source.py"
238+ dst_path = temp_path / "target.py"
239+
240+ src_path .write_text (source_code )
241+ dst_path .write_text (target_code )
242+
243+ result = add_needed_imports_from_module (
244+ src_module_code = source_code ,
245+ dst_module_code = target_code ,
246+ src_path = src_path ,
247+ dst_path = dst_path ,
248+ project_root = temp_path
249+ )
250+
251+ assert result .strip () == expected_output .strip ()
252+
253+
254+ def test_alias_handling ():
255+ source_code = '''
256+ from typing import Dict as MyDict, List as MyList, Optional
257+ from collections import defaultdict as dd, Counter
258+ '''
259+
260+ target_code = '''
261+ def test_aliases():
262+ d: MyDict[str, int] = {}
263+ lst: MyList[str] = []
264+ dd_instance = dd(list)
265+ return d, lst, dd_instance
266+ '''
267+
268+ expected_output = '''from collections import defaultdict as dd
269+ from typing import Dict as MyDict, List as MyList
270+
271+ def test_aliases():
272+ d: MyDict[str, int] = {}
273+ lst: MyList[str] = []
274+ dd_instance = dd(list)
275+ return d, lst, dd_instance
276+ '''
277+
278+ with tempfile .TemporaryDirectory () as temp_dir :
279+ temp_path = Path (temp_dir )
280+ src_path = temp_path / "source.py"
281+ dst_path = temp_path / "target.py"
282+
283+ src_path .write_text (source_code )
284+ dst_path .write_text (target_code )
285+
286+ result = add_needed_imports_from_module (
287+ src_module_code = source_code ,
288+ dst_module_code = target_code ,
289+ src_path = src_path ,
290+ dst_path = dst_path ,
291+ project_root = temp_path
292+ )
293+
294+ assert result .strip () == expected_output .strip ()
295+
296+ def test_add_needed_imports_with_nonealiases ():
297+ source_code = '''
298+ import json
299+ from typing import Dict as MyDict, Optional
300+ from collections import defaultdict
301+
302+ '''
303+
304+ target_code = '''
305+ def target_function():
306+ pass
307+ '''
308+
309+ with tempfile .TemporaryDirectory () as temp_dir :
310+ temp_path = Path (temp_dir )
311+ src_path = temp_path / "source.py"
312+ dst_path = temp_path / "target.py"
313+
314+ src_path .write_text (source_code )
315+ dst_path .write_text (target_code )
316+
317+ # This should not raise a TypeError
318+ result = add_needed_imports_from_module (
319+ src_module_code = source_code ,
320+ dst_module_code = target_code ,
321+ src_path = src_path ,
322+ dst_path = dst_path ,
323+ project_root = temp_path
324+ )
325+
326+
327+ expected_output = '''
328+ def target_function():
329+ pass
330+ '''
331+ assert result .strip () == expected_output .strip ()
0 commit comments