Skip to content

Commit a4dc3ef

Browse files
refactoring and test
1 parent a474c22 commit a4dc3ef

File tree

3 files changed

+54
-35
lines changed

3 files changed

+54
-35
lines changed

codeflash/api/aiservice.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ def make_ai_service_request(
8181
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8282
return response
8383

84+
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
85+
candidates: list[OptimizedCandidate] = []
86+
for opt in optimizations_json:
87+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
88+
if not code.code_strings:
89+
continue
90+
candidates.append(
91+
OptimizedCandidate(
92+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
93+
)
94+
)
95+
return candidates
96+
8497
def optimize_python_code( # noqa: D417
8598
self,
8699
source_code: str,
@@ -135,18 +148,7 @@ def optimize_python_code( # noqa: D417
135148
console.rule()
136149
end_time = time.perf_counter()
137150
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
138-
candidates = []
139-
140-
for opt in optimizations_json:
141-
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
142-
if not code.code_strings:
143-
continue
144-
candidates.append(
145-
OptimizedCandidate(
146-
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
147-
)
148-
)
149-
return candidates
151+
return self._get_valid_candidates(optimizations_json)
150152
try:
151153
error = response.json()["error"]
152154
except Exception:
@@ -209,17 +211,7 @@ def optimize_python_code_line_profiler( # noqa: D417
209211
optimizations_json = response.json()["optimizations"]
210212
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
211213
console.rule()
212-
candidates = []
213-
for opt in optimizations_json:
214-
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
215-
if not code.code_strings:
216-
continue
217-
candidates.append(
218-
OptimizedCandidate(
219-
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
220-
)
221-
)
222-
return candidates
214+
return self._get_valid_candidates(optimizations_json)
223215
try:
224216
error = response.json()["error"]
225217
except Exception:
@@ -269,19 +261,17 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
269261
refined_optimizations = response.json()["refinements"]
270262
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
271263
console.rule()
272-
candidates = []
273-
for opt in refined_optimizations:
274-
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
275-
if not code.code_strings:
276-
continue
277-
candidates.append(
278-
OptimizedCandidate(
279-
source_code=code,
280-
explanation=opt["explanation"],
281-
optimization_id=opt["optimization_id"][:-4] + "refi",
282-
)
264+
265+
refinements = self._get_valid_candidates(refined_optimizations)
266+
return [
267+
OptimizedCandidate(
268+
source_code=c.source_code,
269+
explanation=c.explanation,
270+
optimization_id=c.optimization_id[:-4] + "refi",
283271
)
284-
return candidates
272+
for c in refinements
273+
]
274+
285275
try:
286276
error = response.json()["error"]
287277
except Exception:

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,7 @@ def process_review(
13001300
return
13011301

13021302
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
1303+
logger.info("Reverting code and helpers...")
13031304
self.write_code_and_helpers(
13041305
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
13051306
)

tests/test_validate_python_code.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from pydantic import ValidationError
33

4+
from codeflash.api.aiservice import AiServiceClient
45
from codeflash.models.models import CodeString
56

67

@@ -41,3 +42,30 @@ def test_whitespace_only():
4142
whitespace_code = " "
4243
cs = CodeString(code=whitespace_code)
4344
assert cs.code == whitespace_code
45+
46+
def test_generated_candidates_validation():
47+
ai_service = AiServiceClient()
48+
code = """```python:file.py
49+
print name
50+
```"""
51+
mock_generate_candidates = [
52+
{
53+
"source_code": code,
54+
"explanation": "",
55+
"optimization_id": ""
56+
}
57+
]
58+
candidates = ai_service._get_valid_candidates(mock_generate_candidates)
59+
assert len(candidates) == 0
60+
code = """```python:file.py
61+
print('Hello, World!')
62+
```"""
63+
mock_generate_candidates = [
64+
{
65+
"source_code": code,
66+
"explanation": "",
67+
"optimization_id": ""
68+
}
69+
]
70+
candidates = ai_service._get_valid_candidates(mock_generate_candidates)
71+
assert len(candidates) == 1

0 commit comments

Comments
 (0)