
Commit b17c8f1

ankursharmas authored and copybara-github committed
chore: Marked expected_invocation as optional field on evaluator interface
ADK already has a set of metrics that don't rely on expected_invocations. Also, for eval cases with a conversation scenario, this would be the mainline case.

PiperOrigin-RevId: 825101481
1 parent 9ab17f2 commit b17c8f1

15 files changed, +281 -101 lines
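
For orientation, here is a minimal sketch of what the relaxed interface permits: a reference-free metric implementing Evaluator.evaluate_invocations without any golden invocations. The ResponseLengthEvaluator class and its scoring rule are hypothetical; only the Evaluator, EvaluationResult, and PerInvocationResult names and the optional expected_invocations parameter come from this commit.

from typing import Optional

from google.adk.evaluation.evaluator import EvaluationResult
from google.adk.evaluation.evaluator import Evaluator
from google.adk.evaluation.evaluator import PerInvocationResult


class ResponseLengthEvaluator(Evaluator):
  """Hypothetical metric that needs no expected (golden) invocations."""

  def evaluate_invocations(
      self,
      actual_invocations,
      expected_invocations: Optional[list] = None,
  ) -> EvaluationResult:
    # Score each actual invocation on its own; expected_invocations may be None.
    results = [
        PerInvocationResult(actual_invocation=actual, score=1.0)
        for actual in actual_invocations
    ]
    return EvaluationResult(per_invocation_results=results)

Metrics that do need references keep working: as final_response_match_v1.py now does, they can simply raise when expected_invocations is None.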

src/google/adk/cli/cli_eval.py

Lines changed: 9 additions & 7 deletions
@@ -210,21 +210,23 @@ def pretty_print_eval_result(eval_result: EvalCaseResult):
 
   data = []
   for per_invocation_result in eval_result.eval_metric_result_per_invocation:
+    actual_invocation = per_invocation_result.actual_invocation
+    expected_invocation = per_invocation_result.expected_invocation
     row_data = {
-        "prompt": _convert_content_to_text(
-            per_invocation_result.expected_invocation.user_content
-        ),
+        "prompt": _convert_content_to_text(actual_invocation.user_content),
         "expected_response": _convert_content_to_text(
-            per_invocation_result.expected_invocation.final_response
+            expected_invocation.final_response if expected_invocation else None
         ),
         "actual_response": _convert_content_to_text(
-            per_invocation_result.actual_invocation.final_response
+            actual_invocation.final_response
         ),
         "expected_tool_calls": _convert_tool_calls_to_text(
-            per_invocation_result.expected_invocation.intermediate_data
+            expected_invocation.intermediate_data
+            if expected_invocation
+            else None
         ),
         "actual_tool_calls": _convert_tool_calls_to_text(
-            per_invocation_result.actual_invocation.intermediate_data
+            actual_invocation.intermediate_data
         ),
     }
     for metric_result in per_invocation_result.eval_metric_results:

src/google/adk/evaluation/eval_metrics.py

Lines changed: 3 additions & 2 deletions
@@ -216,10 +216,11 @@ class EvalMetricResultPerInvocation(EvalBaseModel):
           )
       )
 
-  expected_invocation: Invocation = Field(
+  expected_invocation: Optional[Invocation] = Field(
+      default=None,
       description=(
           "The expected invocation, usually the reference or golden invocation."
-      )
+      ),
   )
 
   eval_metric_results: list[EvalMetricResult] = Field(
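
The same widening pattern, shown on a tiny standalone Pydantic model (not the ADK classes) to make the compatibility story concrete: adding default=None means payloads that omit the field still validate, while callers that already set it are unaffected.

from typing import Optional

from pydantic import BaseModel, Field


class ResultRow(BaseModel):
  # Before: `reference: str = Field(description=...)` -- required.
  # After: optional, so rows without a golden reference still parse.
  reference: Optional[str] = Field(
      default=None,
      description="The expected value, if a golden reference exists.",
  )


print(ResultRow())                         # reference=None
print(ResultRow(reference="golden text"))  # unchanged for existing callers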

src/google/adk/evaluation/evaluator.py

Lines changed: 12 additions & 3 deletions
@@ -33,7 +33,7 @@ class PerInvocationResult(BaseModel):
   """Metric evaluation score per invocation."""
 
   actual_invocation: Invocation
-  expected_invocation: Invocation
+  expected_invocation: Optional[Invocation] = None
   score: Optional[float] = None
   eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
   rubric_scores: Optional[list[RubricScore]] = None
@@ -61,7 +61,16 @@ class Evaluator(ABC):
   def evaluate_invocations(
       self,
       actual_invocations: list[Invocation],
-      expected_invocations: list[Invocation],
+      expected_invocations: Optional[list[Invocation]],
   ) -> EvaluationResult:
-    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
+    """Returns EvaluationResult after performing evaluations using actual and expected invocations.
+
+    Args:
+      actual_invocations: The invocations obtained from the agent under test.
+      expected_invocations: An optional list of invocations that, if specified,
+        usually acts as the benchmark/golden response. When specified, this
+        list is usually expected to have the same length as
+        actual_invocations.
+    """
     raise NotImplementedError()

src/google/adk/evaluation/final_response_match_v1.py

Lines changed: 4 additions & 1 deletion
@@ -59,8 +59,11 @@ def get_metric_info() -> MetricInfo:
   def evaluate_invocations(
       self,
       actual_invocations: list[Invocation],
-      expected_invocations: list[Invocation],
+      expected_invocations: Optional[list[Invocation]],
   ) -> EvaluationResult:
+    if expected_invocations is None:
+      raise ValueError("expected_invocations is required for this metric.")
+
     total_score = 0.0
     num_invocations = 0
     per_invocation_results = []

src/google/adk/evaluation/final_response_match_v2.py

Lines changed: 11 additions & 2 deletions
@@ -147,7 +147,11 @@ def __init__(
       self,
       eval_metric: EvalMetric,
   ):
-    super().__init__(eval_metric, FinalResponseMatchV2Evaluator.criterion_type)
+    super().__init__(
+        eval_metric,
+        FinalResponseMatchV2Evaluator.criterion_type,
+        expected_invocations_required=True,
+    )
     self._auto_rater_prompt_template = _FINAL_RESPONSE_MATCH_V2_PROMPT
 
   @staticmethod
@@ -166,8 +170,13 @@ def get_metric_info() -> MetricInfo:
 
   @override
   def format_auto_rater_prompt(
-      self, actual_invocation: Invocation, expected_invocation: Invocation
+      self,
+      actual_invocation: Invocation,
+      expected_invocation: Optional[Invocation],
   ) -> str:
+    if expected_invocation is None:
+      raise ValueError("expected_invocation is required for this metric.")
+
     reference = get_text_from_content(expected_invocation.final_response)
     response = get_text_from_content(actual_invocation.final_response)
     user_prompt = get_text_from_content(expected_invocation.user_content)

src/google/adk/evaluation/hallucinations_v1.py

Lines changed: 12 additions & 3 deletions
@@ -395,7 +395,8 @@ def _create_context_for_step(
       },
       {
           "name": "get_weather",
-          "description": '''Gets the weather of the given place at the given time.
+          "description": '''Gets the weather of the given place at the given
+          time.
 
 Args:
   location: The location for which to retrieve weather information.
@@ -408,7 +409,8 @@ def _create_context_for_step(
           "type": "object",
           "properties": {
               "location": {
-                  "description": "The location for which to retrieve weather information.",
+                  "description": "The location for which to retrieve weather
+                  information.",
                   "type": "string"
               },
               "time": {
@@ -711,8 +713,15 @@ def _aggregate_invocation_results(
   async def evaluate_invocations(
       self,
       actual_invocations: list[Invocation],
-      expected_invocations: list[Invocation],
+      expected_invocations: Optional[list[Invocation]],
   ) -> EvaluationResult:
+    # expected_invocations are not required by this metric; if they are not
+    # supplied, we provide a list of None to the rest of the code.
+    expected_invocations = (
+        [None] * len(actual_invocations)
+        if expected_invocations is None
+        else expected_invocations
+    )
     per_invocation_results = []
     for actual, expected in zip(actual_invocations, expected_invocations):
       step_evaluations = self._get_steps_to_evaluate(actual)

src/google/adk/evaluation/llm_as_judge.py

Lines changed: 18 additions & 3 deletions
@@ -60,9 +60,13 @@ class LlmAsJudge(Evaluator):
   """
 
   def __init__(
-      self, eval_metric: EvalMetric, criterion_type: type[BaseCriterion]
+      self,
+      eval_metric: EvalMetric,
+      criterion_type: type[BaseCriterion],
+      expected_invocations_required=False,
   ):
     self._eval_metric = eval_metric
+    self._expected_invocations_required = expected_invocations_required
 
     expected_criterion_type_error = ValueError(
         f"`{eval_metric.metric_name}` metric expects a criterion of type"
@@ -84,7 +88,7 @@ def __init__(
 
   @abstractmethod
   def format_auto_rater_prompt(
-      self, actual: Invocation, expected: Invocation
+      self, actual: Invocation, expected: Optional[Invocation]
   ) -> str:
     """Formats the auto-rater prompt to evaluate the given invocation."""
 
@@ -112,8 +116,19 @@ def aggregate_invocation_results(
   async def evaluate_invocations(
       self,
       actual_invocations: list[Invocation],
-      expected_invocations: list[Invocation],
+      expected_invocations: Optional[list[Invocation]],
   ) -> EvaluationResult:
+    if self._expected_invocations_required and expected_invocations is None:
+      raise ValueError("expected_invocations is needed by this metric.")
+
+    # If expected_invocations are not required by the metric and they are not
+    # supplied, we provide a list of None.
+    expected_invocations = (
+        [None] * len(actual_invocations)
+        if expected_invocations is None
+        else expected_invocations
+    )
+
     per_invocation_results = []
     for actual, expected in zip(actual_invocations, expected_invocations):
       auto_rater_prompt = self.format_auto_rater_prompt(actual, expected)
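
Condensed restatement of the two knobs added above, as a standalone helper (illustrative only, not part of the ADK API): a judge that opted in via expected_invocations_required rejects a missing reference list, while any other judge gets a list of None so the per-invocation zip still lines up.

from typing import Optional, Sequence


def _prepare_expected(
    actual: Sequence,
    expected: Optional[Sequence],
    required: bool = False,
) -> Sequence:
  if required and expected is None:
    raise ValueError("expected_invocations is needed by this metric.")
  # Reference-free metrics are paired with None placeholders.
  return [None] * len(actual) if expected is None else expected


# A reference-free judge: every actual turn is paired with None.
pairs = list(zip(["turn-1", "turn-2"], _prepare_expected(["turn-1", "turn-2"], None)))
assert pairs == [("turn-1", None), ("turn-2", None)]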

src/google/adk/evaluation/local_eval_service.py

Lines changed: 42 additions & 55 deletions
@@ -22,8 +22,6 @@
 from typing import Optional
 import uuid
 
-from google.genai.types import Content
-from google.genai.types import Part
 from typing_extensions import override
 
 from ..agents.base_agent import BaseAgent
@@ -51,6 +49,7 @@
 from .evaluation_generator import EvaluationGenerator
 from .evaluator import EvalStatus
 from .evaluator import EvaluationResult
+from .evaluator import PerInvocationResult
 from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
 from .metric_evaluator_registry import MetricEvaluatorRegistry
 from .user_simulator_provider import UserSimulatorProvider
@@ -222,69 +221,51 @@ async def _evaluate_single_inference_result(
         else 'test_user_id'
     )
 
-    if eval_case.conversation_scenario:
-      logger.warning(
-          'Skipping evaluation of variable-length conversation scenario in eval'
-          ' set/case %s/%s.',
-          inference_result.eval_set_id,
-          inference_result.eval_case_id,
-      )
-      for actual_invocation in inference_result.inferences:
-        eval_metric_result_per_invocation.append(
-            EvalMetricResultPerInvocation(
-                actual_invocation=actual_invocation,
-                expected_invocation=Invocation(
-                    user_content=actual_invocation.user_content,
-                    final_response=Content(
-                        parts=[Part(text='N/A')], role='model'
-                    ),
-                ),
-            )
-        )
-      eval_case_result = EvalCaseResult(
-          eval_set_file=inference_result.eval_set_id,
-          eval_set_id=inference_result.eval_set_id,
-          eval_id=inference_result.eval_case_id,
-          final_eval_status=EvalStatus.NOT_EVALUATED,
-          overall_eval_metric_results=overall_eval_metric_results,
-          eval_metric_result_per_invocation=eval_metric_result_per_invocation,
-          session_id=inference_result.session_id,
-          session_details=await self._session_service.get_session(
-              app_name=inference_result.app_name,
-              user_id=user_id,
-              session_id=inference_result.session_id,
-          ),
-          user_id=user_id,
-      )
-      return (inference_result, eval_case_result)
-
-    if len(inference_result.inferences) != len(eval_case.conversation):
+    if eval_case.conversation_scenario is None and len(
+        inference_result.inferences
+    ) != len(eval_case.conversation):
       raise ValueError(
          'Inferences should match conversations in eval case. Found'
          f'{len(inference_result.inferences)} inferences '
          f'{len(eval_case.conversation)} conversations in eval cases.'
      )
 
     # Pre-creating the EvalMetricResults entries for each invocation.
-    for actual, expected in zip(
-        inference_result.inferences, eval_case.conversation
-    ):
+    for idx, actual in enumerate(inference_result.inferences):
       eval_metric_result_per_invocation.append(
           EvalMetricResultPerInvocation(
              actual_invocation=actual,
-              expected_invocation=expected,
+              expected_invocation=eval_case.conversation[idx]
+              if eval_case.conversation
+              else None,
              # We will fill this as we evaluate each metric per invocation.
              eval_metric_results=[],
          )
      )
 
     for eval_metric in evaluate_config.eval_metrics:
       # Perform evaluation of the metric.
-      evaluation_result = await self._evaluate_metric(
-          eval_metric=eval_metric,
-          actual_invocations=inference_result.inferences,
-          expected_invocations=eval_case.conversation,
-      )
+      try:
+        evaluation_result = await self._evaluate_metric(
+            eval_metric=eval_metric,
+            actual_invocations=inference_result.inferences,
+            expected_invocations=eval_case.conversation,
+        )
+      except Exception as e:
+        # We intentionally catch the Exception as we don't want failures to
+        # affect other metric evaluations.
+        logger.error(
+            "Metric evaluation failed for metric `%s` for eval case id '%s'"
+            ' with following error `%s`',
+            eval_metric.metric_name,
+            eval_case.eval_id,
+            e,
+            exc_info=True,
+        )
+        # We use an empty result.
+        evaluation_result = EvaluationResult(
+            overall_eval_status=EvalStatus.NOT_EVALUATED
+        )
 
       # Track overall score across all invocations.
       eval_metric_result_details = EvalMetricResultDetails(
@@ -299,8 +280,10 @@ async def _evaluate_single_inference_result(
           )
       )
 
-      if len(evaluation_result.per_invocation_results) != len(
-          eval_metric_result_per_invocation
+      if (
+          evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
+          and len(evaluation_result.per_invocation_results)
+          != len(eval_metric_result_per_invocation)
      ):
        raise ValueError(
            'Eval metric should return results for each invocation. Found '
@@ -309,10 +292,14 @@ async def _evaluate_single_inference_result(
       )
 
       # Track score across individual invocations.
-      for invocation_result, invocation in zip(
-          evaluation_result.per_invocation_results,
-          eval_metric_result_per_invocation,
-      ):
+      for idx, invocation in enumerate(eval_metric_result_per_invocation):
+        invocation_result = (
+            evaluation_result.per_invocation_results[idx]
+            if evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
+            else PerInvocationResult(
+                actual_invocation=invocation.actual_invocation
+            )
+        )
        eval_metric_result_details = EvalMetricResultDetails(
            rubric_scores=invocation_result.rubric_scores
        )
@@ -351,7 +338,7 @@ async def _evaluate_metric(
       self,
       eval_metric: EvalMetric,
       actual_invocations: list[Invocation],
-      expected_invocations: list[Invocation],
+      expected_invocations: Optional[list[Invocation]],
   ) -> EvaluationResult:
     """Returns EvaluationResult obtained from evaluating a metric using an Evaluator."""
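
The try/except block above is the main behavioral change in this file: one failing metric no longer aborts the whole eval case, it simply yields NOT_EVALUATED placeholders while the remaining metrics still run. A simplified, self-contained sketch of that isolation pattern (function and variable names here are illustrative, not the ADK API):

import logging

logger = logging.getLogger(__name__)


def run_all_metrics(metric_fns, actual_invocations):
  """Runs every metric; a metric that raises produces a placeholder result."""
  results = {}
  for name, metric_fn in metric_fns.items():
    try:
      results[name] = metric_fn(actual_invocations)
    except Exception as e:  # broad on purpose, mirroring the diff above
      logger.error("Metric `%s` failed: %s", name, e, exc_info=True)
      results[name] = None  # stands in for an empty NOT_EVALUATED result
  return results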

src/google/adk/evaluation/response_evaluator.py

Lines changed: 4 additions & 2 deletions
@@ -100,7 +100,7 @@ def get_metric_info(metric_name: str) -> MetricInfo:
   def evaluate_invocations(
       self,
       actual_invocations: list[Invocation],
-      expected_invocations: list[Invocation],
+      expected_invocations: Optional[list[Invocation]],
   ) -> EvaluationResult:
     # If the metric is response_match_score, just use the RougeEvaluator.
     if self._metric_name == PrebuiltMetrics.RESPONSE_MATCH_SCORE.value:
@@ -112,5 +112,7 @@ def evaluate_invocations(
       )
 
     return _VertexAiEvalFacade(
-        threshold=self._threshold, metric_name=self._metric_name
+        threshold=self._threshold,
+        metric_name=self._metric_name,
+        expected_invocations_required=True,
     ).evaluate_invocations(actual_invocations, expected_invocations)

src/google/adk/evaluation/rubric_based_final_response_quality_v1.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 
 import logging
 from typing import ClassVar
+from typing import Optional
 
 from typing_extensions import override
 
@@ -281,7 +282,7 @@ def get_metric_info() -> MetricInfo:
 
   @override
   def format_auto_rater_prompt(
-      self, actual_invocation: Invocation, _: Invocation
+      self, actual_invocation: Invocation, _: Optional[Invocation]
   ) -> str:
     """Returns the autorater prompt."""