Commit 0b16f8a: fx lcb metric (#981)
1 parent 16c2630

File tree

  • src/lighteval/tasks/extended/lcb

1 file changed (+28, -26 lines)

src/lighteval/tasks/extended/lcb/main.py

Lines changed: 28 additions & 26 deletions
```diff
@@ -36,6 +36,7 @@
 from aenum import extend_enum
 
 from lighteval.metrics.metrics import Metrics, SampleLevelMetric
+from lighteval.metrics.metrics_sample import SampleLevelComputation
 from lighteval.models.model_output import ModelResponse
 from lighteval.tasks.extended.lcb.codegen_metrics import (
     codegen_metrics,
```
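The added import exists because the second hunk turns the metric function into a class. Below is a minimal sketch of the interface that `CodegenMetric` appears to subclass, inferred only from the `compute` signature in this diff; the real `SampleLevelComputation` lives in `lighteval.metrics.metrics_sample` and its definition is not part of this commit.

```python
# Hypothetical sketch, inferred from this diff: a base class exposing a
# single compute() hook that the evaluation loop calls once per document.
from abc import ABC, abstractmethod


class SampleLevelComputation(ABC):
    @abstractmethod
    def compute(self, model_response, doc, **kwargs):
        """Return the sample-level score(s) for one doc's model responses."""
        ...
```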
```diff
@@ -80,38 +81,39 @@ def lcb_codegeneration_prompt_fn(line, task_name: str = "lcb:codegeneration") ->
 )
 
 
-def codegen_metric(model_response: ModelResponse, doc: Doc, **kwargs) -> float:
-    """Estimates the Pass@1 metric for the code generation task.
-    Extract the code from each prediction, Runs it for each sample and generations,
-    and computes the Pass@1 over the outputs.
-    """
-    assert doc.specific is not None, "Doc specific field is required for codegen_metric"
-
-    predictions = model_response.final_text
-    # Extract generated code snippets
-    generated_code_snippets = [[extract_code(pred) for pred in predictions]]  # noqa: F841
-    evaluation_sample = {  # noqa: F841
-        "inputs": doc.specific["inputs"],
-        "outputs": doc.specific["outputs"],
-        "fn_name": doc.specific["fn_name"],
-    }
-    # This is a list of lists because
-    evaluation_sample = [{"input_output": json.dumps(evaluation_sample)}]
-
-    metrics, _ = codegen_metrics(
-        evaluation_sample,
-        generated_code_snippets,
-        k_list=[1],  # Only run for Pass@1
-        num_process_evaluate=8,
-    )
-    return metrics["pass@1"]
+class CodegenMetric(SampleLevelComputation):
+    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> dict:
+        """Estimates the Pass@1 metric for the code generation task.
+        Extract the code from each prediction, Runs it for each sample and generations,
+        and computes the Pass@1 over the outputs.
+        """
+        assert doc.specific is not None, "Doc specific field is required for codegen_metric"
+
+        predictions = model_response.final_text
+        # Extract generated code snippets
+        generated_code_snippets = [[extract_code(pred) for pred in predictions]]  # noqa: F841
+        evaluation_sample = {  # noqa: F841
+            "inputs": doc.specific["inputs"],
+            "outputs": doc.specific["outputs"],
+            "fn_name": doc.specific["fn_name"],
+        }
+        # This is a list of lists because
+        evaluation_sample = [{"input_output": json.dumps(evaluation_sample)}]
+
+        metrics, _ = codegen_metrics(
+            evaluation_sample,
+            generated_code_snippets,
+            k_list=[1],  # Only run for Pass@1
+            num_process_evaluate=8,
+        )
+        return metrics["pass@1"]
 
 
 lcb_codegen_metric = SampleLevelMetric(
     metric_name="codegen_pass@1:16",  # This is the way of informing the number of generations currently
     category=SamplingMethod.GENERATIVE,
     higher_is_better=True,
-    sample_level_fn=codegen_metric,
+    sample_level_fn=CodegenMetric(),
     corpus_level_fn=np.mean,
     batched_compute=False,
 )
```
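The only change to `SampleLevelMetric` itself is that `sample_level_fn` now receives an instance, `CodegenMetric()`, whose `compute` method the framework invokes, rather than the bare `codegen_metric` function. A hedged sketch of the registration step that typically follows in this file, assuming the usual `aenum.extend_enum` pattern implied by the import in the first hunk (the registration line sits outside both hunks):

```python
# Assumed follow-up (not shown in this diff): register the metric object on
# the Metrics enum so LCB task configs can refer to it by name. Relies on
# lcb_codegen_metric being defined as in the hunk above.
from aenum import extend_enum

from lighteval.metrics.metrics import Metrics

extend_enum(Metrics, "lcb_codegen_metric", lcb_codegen_metric)

# At scoring time the framework then calls, roughly:
#   lcb_codegen_metric.sample_level_fn.compute(model_response=resp, doc=doc)
# i.e. CodegenMetric.compute(...) instead of calling codegen_metric(...) directly.
```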
