Skip to content

Commit a474c22

Browse files
validating generated optimized code
Signed-off-by: ali <mohammed18200118@gmail.com>
1 parent b77f50e commit a474c22

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

codeflash/api/aiservice.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,18 @@ def optimize_python_code( # noqa: D417
135135
console.rule()
136136
end_time = time.perf_counter()
137137
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
138-
return [
139-
OptimizedCandidate(
140-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
141-
explanation=opt["explanation"],
142-
optimization_id=opt["optimization_id"],
138+
candidates = []
139+
140+
for opt in optimizations_json:
141+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
142+
if not code.code_strings:
143+
continue
144+
candidates.append(
145+
OptimizedCandidate(
146+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
147+
)
143148
)
144-
for opt in optimizations_json
145-
]
149+
return candidates
146150
try:
147151
error = response.json()["error"]
148152
except Exception:
@@ -205,14 +209,17 @@ def optimize_python_code_line_profiler( # noqa: D417
205209
optimizations_json = response.json()["optimizations"]
206210
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
207211
console.rule()
208-
return [
209-
OptimizedCandidate(
210-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
211-
explanation=opt["explanation"],
212-
optimization_id=opt["optimization_id"],
212+
candidates = []
213+
for opt in optimizations_json:
214+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
215+
if not code.code_strings:
216+
continue
217+
candidates.append(
218+
OptimizedCandidate(
219+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
220+
)
213221
)
214-
for opt in optimizations_json
215-
]
222+
return candidates
216223
try:
217224
error = response.json()["error"]
218225
except Exception:
@@ -262,14 +269,19 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
262269
refined_optimizations = response.json()["refinements"]
263270
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
264271
console.rule()
265-
return [
266-
OptimizedCandidate(
267-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
268-
explanation=opt["explanation"],
269-
optimization_id=opt["optimization_id"][:-4] + "refi",
272+
candidates = []
273+
for opt in refined_optimizations:
274+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
275+
if not code.code_strings:
276+
continue
277+
candidates.append(
278+
OptimizedCandidate(
279+
source_code=code,
280+
explanation=opt["explanation"],
281+
optimization_id=opt["optimization_id"][:-4] + "refi",
282+
)
270283
)
271-
for opt in refined_optimizations
272-
]
284+
return candidates
273285
try:
274286
error = response.json()["error"]
275287
except Exception:

codeflash/models/models.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
22-
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
22+
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
2323
from pydantic.dataclasses import dataclass
2424

2525
from codeflash.cli_cmds.console import console, logger
@@ -239,10 +239,14 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
239239
"""
240240
matches = markdown_pattern.findall(markdown_code)
241241
results = CodeStringsMarkdown()
242-
for file_path, code in matches:
243-
path = file_path.strip()
244-
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
245-
return results
242+
try:
243+
for file_path, code in matches:
244+
path = file_path.strip()
245+
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
246+
return results # noqa: TRY300
247+
except ValidationError:
248+
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
249+
return CodeStringsMarkdown()
246250

247251

248252
class CodeOptimizationContext(BaseModel):

0 commit comments

Comments
 (0)