Skip to content

Commit 3f5ac4a

Browse files
committed
litellm fix
1 parent 104a5ac commit 3f5ac4a

File tree

2 files changed

+348
-2
lines changed

2 files changed

+348
-2
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,19 +529,28 @@ def add_needed_imports_from_module(
529529
try:
530530
for mod in gatherer.module_imports:
531531
# Skip __future__ imports as they cannot be imported directly
532-
# __future__ imports should only be imported with specific objects
532+
# __future__ imports should only be imported with specific objects i.e from __future__ import annotations
533533
if mod == "__future__":
534534
continue
535535
if mod not in dotted_import_collector.imports:
536536
AddImportsVisitor.add_needed_import(dst_context, mod)
537537
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
538+
aliased_objects = set()
539+
for mod, alias_pairs in gatherer.alias_mapping.items():
540+
for alias_pair in alias_pairs:
541+
if alias_pair[0] and alias_pair[1]: # Both name and alias exist
542+
aliased_objects.add(f"{mod}.{alias_pair[0]}")
543+
538544
for mod, obj_seq in gatherer.object_mapping.items():
539545
for obj in obj_seq:
540546
if (
541547
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
542548
):
543549
continue # Skip adding imports for helper functions already in the context
544550

551+
if f"{mod}.{obj}" in aliased_objects:
552+
continue
553+
545554
# Handle star imports by resolving them to actual symbol names
546555
if obj == "*":
547556
resolved_symbols = resolve_star_import(mod, project_root)
@@ -563,6 +572,8 @@ def add_needed_imports_from_module(
563572
return dst_module_code
564573

565574
for mod, asname in gatherer.module_aliases.items():
575+
if not asname:
576+
continue
566577
if f"{mod}.{asname}" not in dotted_import_collector.imports:
567578
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
568579
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
@@ -572,12 +583,16 @@ def add_needed_imports_from_module(
572583
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
573584
continue
574585

586+
if not alias_pair[0] or not alias_pair[1]:
587+
continue
588+
575589
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
576590
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
577591
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
578592

579593
try:
580-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
594+
add_imports_visitor = AddImportsVisitor(dst_context)
595+
transformed_module = add_imports_visitor.transform_module(parsed_dst_module)
581596
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
582597
return transformed_module.code.lstrip("\n")
583598
except Exception as e:
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
5+
6+
7+
def test_add_needed_imports_with_none_aliases():
8+
source_code = '''
9+
import json
10+
from typing import Dict as MyDict, Optional
11+
from collections import defaultdict
12+
'''
13+
14+
target_code = '''
15+
def target_function():
16+
pass
17+
'''
18+
19+
expected_output = '''
20+
def target_function():
21+
pass
22+
'''
23+
24+
with tempfile.TemporaryDirectory() as temp_dir:
25+
temp_path = Path(temp_dir)
26+
src_path = temp_path / "source.py"
27+
dst_path = temp_path / "target.py"
28+
29+
src_path.write_text(source_code)
30+
dst_path.write_text(target_code)
31+
32+
result = add_needed_imports_from_module(
33+
src_module_code=source_code,
34+
dst_module_code=target_code,
35+
src_path=src_path,
36+
dst_path=dst_path,
37+
project_root=temp_path
38+
)
39+
40+
assert result.strip() == expected_output.strip()
41+
42+
43+
def test_add_needed_imports_complex_aliases():
44+
source_code = '''
45+
import os
46+
import sys as system
47+
from typing import Dict, List as MyList, Optional as Opt
48+
from collections import defaultdict as dd, Counter
49+
from pathlib import Path
50+
'''
51+
52+
target_code = '''
53+
def my_function():
54+
return "test"
55+
'''
56+
57+
expected_output = '''
58+
def my_function():
59+
return "test"
60+
'''
61+
62+
with tempfile.TemporaryDirectory() as temp_dir:
63+
temp_path = Path(temp_dir)
64+
src_path = temp_path / "source.py"
65+
dst_path = temp_path / "target.py"
66+
67+
src_path.write_text(source_code)
68+
dst_path.write_text(target_code)
69+
70+
result = add_needed_imports_from_module(
71+
src_module_code=source_code,
72+
dst_module_code=target_code,
73+
src_path=src_path,
74+
dst_path=dst_path,
75+
project_root=temp_path
76+
)
77+
78+
assert result.strip() == expected_output.strip()
79+
80+
81+
def test_add_needed_imports_with_usage():
82+
source_code = '''
83+
import json
84+
from typing import Dict as MyDict, Optional
85+
from collections import defaultdict
86+
87+
'''
88+
89+
target_code = '''
90+
def target_function():
91+
data = json.loads('{"key": "value"}')
92+
my_dict: MyDict[str, str] = {}
93+
opt_value: Optional[str] = None
94+
dd = defaultdict(list)
95+
return data, my_dict, opt_value, dd
96+
'''
97+
98+
expected_output = '''import json
99+
from typing import Dict as MyDict, Optional
100+
from collections import defaultdict
101+
102+
def target_function():
103+
data = json.loads('{"key": "value"}')
104+
my_dict: MyDict[str, str] = {}
105+
opt_value: Optional[str] = None
106+
dd = defaultdict(list)
107+
return data, my_dict, opt_value, dd
108+
'''
109+
110+
with tempfile.TemporaryDirectory() as temp_dir:
111+
temp_path = Path(temp_dir)
112+
src_path = temp_path / "source.py"
113+
dst_path = temp_path / "target.py"
114+
115+
src_path.write_text(source_code)
116+
dst_path.write_text(target_code)
117+
118+
result = add_needed_imports_from_module(
119+
src_module_code=source_code,
120+
dst_module_code=target_code,
121+
src_path=src_path,
122+
dst_path=dst_path,
123+
project_root=temp_path
124+
)
125+
126+
# Assert exact expected output
127+
assert result.strip() == expected_output.strip()
128+
129+
130+
def test_litellm_router_style_imports():
131+
source_code = '''
132+
import asyncio
133+
import copy
134+
import json
135+
from collections import defaultdict
136+
from typing import Dict, List, Optional, Union
137+
from litellm.types.utils import ModelInfo
138+
from litellm.types.utils import ModelInfo as ModelMapInfo
139+
'''
140+
141+
target_code = '''
142+
def target_function():
143+
"""Target function for testing."""
144+
pass
145+
'''
146+
147+
expected_output = '''
148+
def target_function():
149+
"""Target function for testing."""
150+
pass
151+
'''
152+
153+
with tempfile.TemporaryDirectory() as temp_dir:
154+
temp_path = Path(temp_dir)
155+
src_path = temp_path / "complex_source.py"
156+
dst_path = temp_path / "target.py"
157+
158+
src_path.write_text(source_code)
159+
dst_path.write_text(target_code)
160+
161+
result = add_needed_imports_from_module(
162+
src_module_code=source_code,
163+
dst_module_code=target_code,
164+
src_path=src_path,
165+
dst_path=dst_path,
166+
project_root=temp_path
167+
)
168+
169+
assert result.strip() == expected_output.strip()
170+
171+
172+
def test_edge_case_none_values_in_alias_pairs():
173+
source_code = '''
174+
from typing import Dict as MyDict, List, Optional as Opt
175+
from collections import defaultdict, Counter as cnt
176+
from pathlib import Path
177+
'''
178+
179+
target_code = '''
180+
def my_test_function():
181+
return "test"
182+
'''
183+
184+
expected_output = '''
185+
def my_test_function():
186+
return "test"
187+
'''
188+
189+
with tempfile.TemporaryDirectory() as temp_dir:
190+
temp_path = Path(temp_dir)
191+
src_path = temp_path / "edge_case_source.py"
192+
dst_path = temp_path / "target.py"
193+
194+
src_path.write_text(source_code)
195+
dst_path.write_text(target_code)
196+
197+
result = add_needed_imports_from_module(
198+
src_module_code=source_code,
199+
dst_module_code=target_code,
200+
src_path=src_path,
201+
dst_path=dst_path,
202+
project_root=temp_path
203+
)
204+
205+
assert result.strip() == expected_output.strip()
206+
207+
208+
def test_partial_import_usage():
209+
source_code = '''
210+
import os
211+
import sys
212+
from typing import Dict, List, Optional
213+
from collections import defaultdict, Counter
214+
'''
215+
216+
target_code = '''
217+
def use_some_imports():
218+
path = os.path.join("a", "b")
219+
my_dict: Dict[str, int] = {}
220+
counter = Counter([1, 2, 3])
221+
return path, my_dict, counter
222+
'''
223+
224+
expected_output = '''import os
225+
from collections import Counter
226+
from typing import Dict
227+
228+
def use_some_imports():
229+
path = os.path.join("a", "b")
230+
my_dict: Dict[str, int] = {}
231+
counter = Counter([1, 2, 3])
232+
return path, my_dict, counter
233+
'''
234+
235+
with tempfile.TemporaryDirectory() as temp_dir:
236+
temp_path = Path(temp_dir)
237+
src_path = temp_path / "source.py"
238+
dst_path = temp_path / "target.py"
239+
240+
src_path.write_text(source_code)
241+
dst_path.write_text(target_code)
242+
243+
result = add_needed_imports_from_module(
244+
src_module_code=source_code,
245+
dst_module_code=target_code,
246+
src_path=src_path,
247+
dst_path=dst_path,
248+
project_root=temp_path
249+
)
250+
251+
assert result.strip() == expected_output.strip()
252+
253+
254+
def test_alias_handling():
255+
source_code = '''
256+
from typing import Dict as MyDict, List as MyList, Optional
257+
from collections import defaultdict as dd, Counter
258+
'''
259+
260+
target_code = '''
261+
def test_aliases():
262+
d: MyDict[str, int] = {}
263+
lst: MyList[str] = []
264+
dd_instance = dd(list)
265+
return d, lst, dd_instance
266+
'''
267+
268+
expected_output = '''from collections import defaultdict as dd
269+
from typing import Dict as MyDict, List as MyList
270+
271+
def test_aliases():
272+
d: MyDict[str, int] = {}
273+
lst: MyList[str] = []
274+
dd_instance = dd(list)
275+
return d, lst, dd_instance
276+
'''
277+
278+
with tempfile.TemporaryDirectory() as temp_dir:
279+
temp_path = Path(temp_dir)
280+
src_path = temp_path / "source.py"
281+
dst_path = temp_path / "target.py"
282+
283+
src_path.write_text(source_code)
284+
dst_path.write_text(target_code)
285+
286+
result = add_needed_imports_from_module(
287+
src_module_code=source_code,
288+
dst_module_code=target_code,
289+
src_path=src_path,
290+
dst_path=dst_path,
291+
project_root=temp_path
292+
)
293+
294+
assert result.strip() == expected_output.strip()
295+
296+
def test_add_needed_imports_with_nonealiases():
297+
source_code = '''
298+
import json
299+
from typing import Dict as MyDict, Optional
300+
from collections import defaultdict
301+
302+
'''
303+
304+
target_code = '''
305+
def target_function():
306+
pass
307+
'''
308+
309+
with tempfile.TemporaryDirectory() as temp_dir:
310+
temp_path = Path(temp_dir)
311+
src_path = temp_path / "source.py"
312+
dst_path = temp_path / "target.py"
313+
314+
src_path.write_text(source_code)
315+
dst_path.write_text(target_code)
316+
317+
# This should not raise a TypeError
318+
result = add_needed_imports_from_module(
319+
src_module_code=source_code,
320+
dst_module_code=target_code,
321+
src_path=src_path,
322+
dst_path=dst_path,
323+
project_root=temp_path
324+
)
325+
326+
327+
expected_output = '''
328+
def target_function():
329+
pass
330+
'''
331+
assert result.strip() == expected_output.strip()

0 commit comments

Comments
 (0)