
Commit 96284d7

anistark and theSalt authored

Devpod cn/main (#2309)

Issue Link / Problem Description: contd. #1973

Co-authored-by: Yin Liang <yinliang@devpod.cn>

1 parent fec8fea commit 96284d7

File tree

6 files changed (+194, -0 lines)


docs/concepts/metrics/available_metrics/context_precision.md

Lines changed: 33 additions & 0 deletions

@@ -162,3 +162,36 @@ Output

## ID Based Context Precision

IDBasedContextPrecision provides a direct and efficient way to measure precision by comparing the IDs of retrieved contexts with reference context IDs. This metric is particularly useful when you have a unique ID system for your documents and want to evaluate retrieval performance without comparing the actual content.

The metric computes precision using `retrieved_context_ids` and `reference_context_ids`, with values ranging between 0 and 1. Higher values indicate better performance. It works with both string and integer IDs.

The formula for calculating ID-based context precision is as follows:

$$ \text{ID-Based Context Precision} = \frac{\text{Number of retrieved context IDs found in reference context IDs}}{\text{Total number of retrieved context IDs}} $$

### Example

```python
from ragas import SingleTurnSample
from ragas.metrics import IDBasedContextPrecision

sample = SingleTurnSample(
    retrieved_context_ids=["doc_1", "doc_2", "doc_3", "doc_4"],
    reference_context_ids=["doc_1", "doc_4", "doc_5", "doc_6"]
)

id_precision = IDBasedContextPrecision()
await id_precision.single_turn_ascore(sample)
```

Output
```
0.5
```

In this example, out of the 4 retrieved context IDs, only 2 ("doc_1" and "doc_4") are found in the reference context IDs, resulting in a precision score of 0.5, or 50%.
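Since the section notes that integer IDs are also supported, here is a minimal sketch of the same metric with integer document IDs (the ID values below are hypothetical; the implementation converts IDs to strings internally, so string and integer IDs are scored the same way):

```python
from ragas import SingleTurnSample
from ragas.metrics import IDBasedContextPrecision

# Hypothetical integer document IDs; 2 of the 4 retrieved IDs (101 and 104)
# appear in the reference IDs, so the score is again 0.5.
sample = SingleTurnSample(
    retrieved_context_ids=[101, 102, 103, 104],
    reference_context_ids=[101, 104, 205],
)

id_precision = IDBasedContextPrecision()
await id_precision.single_turn_ascore(sample)
```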

docs/concepts/metrics/available_metrics/context_recall.md

Lines changed: 32 additions & 0 deletions

@@ -69,4 +69,36 @@ await context_recall.single_turn_ascore(sample)

## ID Based Context Recall

IDBasedContextRecall provides a direct and efficient way to measure recall by comparing the IDs of retrieved contexts with reference context IDs. This metric is particularly useful when you have a unique ID system for your documents and want to evaluate retrieval performance without comparing the actual content.

The metric computes recall using `retrieved_context_ids` and `reference_context_ids`, with values ranging between 0 and 1. Higher values indicate better performance. It works with both string and integer IDs.

The formula for calculating ID-based context recall is as follows:

$$ \text{ID-Based Context Recall} = \frac{\text{Number of reference context IDs found in retrieved context IDs}}{\text{Total number of reference context IDs}} $$

### Example

```python
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import IDBasedContextRecall

sample = SingleTurnSample(
    retrieved_context_ids=["doc_1", "doc_2", "doc_3"],
    reference_context_ids=["doc_1", "doc_4", "doc_5", "doc_6"]
)

id_recall = IDBasedContextRecall()
await id_recall.single_turn_ascore(sample)
```

Output
```
0.25
```

In this example, only 1 of the 4 reference context IDs ("doc_1") appears among the retrieved context IDs, resulting in a recall score of 0.25, or 25%.
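Because both ID-based metrics read the same two fields, a single sample can be scored for precision and recall together. A minimal sketch that combines the two documented metrics on the recall example above (no parameters beyond those shown in these docs):

```python
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import IDBasedContextPrecision, IDBasedContextRecall

sample = SingleTurnSample(
    retrieved_context_ids=["doc_1", "doc_2", "doc_3"],
    reference_context_ids=["doc_1", "doc_4", "doc_5", "doc_6"]
)

# Only "doc_1" overlaps: precision = 1/3 (one of three retrieved IDs is relevant),
# recall = 1/4 (one of four reference IDs was retrieved).
id_precision = await IDBasedContextPrecision().single_turn_ascore(sample)
id_recall = await IDBasedContextRecall().single_turn_ascore(sample)
```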

src/ragas/dataset_schema.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -63,6 +63,10 @@ class SingleTurnSample(BaseSample):
         List of contexts retrieved for the query.
     reference_contexts : Optional[List[str]]
         List of reference contexts for the query.
+    retrieved_context_ids : Optional[List[Union[str, int]]]
+        List of IDs for retrieved contexts.
+    reference_context_ids : Optional[List[Union[str, int]]]
+        List of IDs for reference contexts.
     response : Optional[str]
         The generated response for the query.
     multi_responses : Optional[List[str]]
@@ -76,6 +80,8 @@ class SingleTurnSample(BaseSample):
     user_input: t.Optional[str] = None
     retrieved_contexts: t.Optional[t.List[str]] = None
     reference_contexts: t.Optional[t.List[str]] = None
+    retrieved_context_ids: t.Optional[t.List[t.Union[str, int]]] = None
+    reference_context_ids: t.Optional[t.List[t.Union[str, int]]] = None
     response: t.Optional[str] = None
     multi_responses: t.Optional[t.List[str]] = None
     reference: t.Optional[str] = None
```
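For context, a sketch of how a sample might carry both the existing text fields and the new ID fields (the field names come from the schema above; the values are purely illustrative):

```python
from ragas.dataset_schema import SingleTurnSample

# Illustrative values only; the new *_context_ids fields accept str or int entries.
sample = SingleTurnSample(
    user_input="Where is the Eiffel Tower located?",
    retrieved_contexts=["The Eiffel Tower is located in Paris."],
    reference_contexts=["Paris is the home of the Eiffel Tower."],
    retrieved_context_ids=["doc_1"],
    reference_context_ids=[42],  # integer IDs are also allowed
    response="The Eiffel Tower is located in Paris.",
    reference="The Eiffel Tower is located in Paris, France.",
)
```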

src/ragas/metrics/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -19,13 +19,15 @@
 from ragas.metrics._context_precision import (
     ContextPrecision,
     ContextUtilization,
+    IDBasedContextPrecision,
     LLMContextPrecisionWithoutReference,
     LLMContextPrecisionWithReference,
     NonLLMContextPrecisionWithReference,
     context_precision,
 )
 from ragas.metrics._context_recall import (
     ContextRecall,
+    IDBasedContextRecall,
     LLMContextRecall,
     NonLLMContextRecall,
     context_recall,
@@ -130,8 +132,10 @@
     "LLMContextPrecisionWithoutReference",
     "NonLLMContextPrecisionWithReference",
     "LLMContextPrecisionWithoutReference",
+    "IDBasedContextPrecision",
     "LLMContextRecall",
     "NonLLMContextRecall",
+    "IDBasedContextRecall",
     "FactualCorrectness",
     "InstanceRubrics",
     "NonLLMStringSimilarity",
```

src/ragas/metrics/_context_precision.py

Lines changed: 60 additions & 0 deletions

```diff
@@ -248,6 +248,66 @@ def _calculate_average_precision(self, verdict_list: t.List[int]) -> float:
         return score
 
 
+@dataclass
+class IDBasedContextPrecision(SingleTurnMetric):
+    """
+    Calculates context precision by directly comparing retrieved context IDs with reference context IDs.
+    The score represents what proportion of the retrieved context IDs are actually relevant (present in reference).
+
+    This metric works with both string and integer IDs.
+
+    Attributes
+    ----------
+    name : str
+        Name of the metric
+    """
+
+    name: str = "id_based_context_precision"
+    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
+        default_factory=lambda: {
+            MetricType.SINGLE_TURN: {
+                "retrieved_context_ids",
+                "reference_context_ids",
+            }
+        }
+    )
+    output_type: MetricOutputType = MetricOutputType.CONTINUOUS
+
+    def init(self, run_config: RunConfig) -> None: ...
+
+    async def _single_turn_ascore(
+        self, sample: SingleTurnSample, callbacks: Callbacks
+    ) -> float:
+        retrieved_context_ids = sample.retrieved_context_ids
+        reference_context_ids = sample.reference_context_ids
+        assert retrieved_context_ids is not None, "retrieved_context_ids is empty"
+        assert reference_context_ids is not None, "reference_context_ids is empty"
+
+        # Convert all IDs to strings to ensure consistent comparison
+        retrieved_ids_set = set(str(id) for id in retrieved_context_ids)
+        reference_ids_set = set(str(id) for id in reference_context_ids)
+
+        # Calculate precision score
+        total_retrieved = len(retrieved_ids_set)
+        if total_retrieved == 0:
+            logger.warning(
+                "No retrieved context IDs provided, cannot calculate precision."
+            )
+            return np.nan
+
+        # Count how many retrieved IDs match reference IDs
+        hits = sum(
+            1 for ret_id in retrieved_ids_set if str(ret_id) in reference_ids_set
+        )
+
+        # For precision, we calculate: relevant retrieved / total retrieved
+        score = hits / total_retrieved
+        return score
+
+    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
+        return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)
+
+
 @dataclass
 class ContextPrecision(LLMContextPrecisionWithReference):
     name: str = "context_precision"
```
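Note that `_single_turn_ascore` above builds a set of retrieved IDs before counting, so duplicate retrieved IDs are counted only once. A minimal sketch of that behavior, using hypothetical IDs and the same public API as in the docs:

```python
from ragas import SingleTurnSample
from ragas.metrics import IDBasedContextPrecision

# "doc_1" is retrieved twice, but the set conversion collapses it to one entry,
# so the denominator is 2 unique retrieved IDs and the score is 1/2, not 1/3.
sample = SingleTurnSample(
    retrieved_context_ids=["doc_1", "doc_1", "doc_2"],
    reference_context_ids=["doc_1"],
)

id_precision = IDBasedContextPrecision()
await id_precision.single_turn_ascore(sample)  # 0.5
```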

src/ragas/metrics/_context_recall.py

Lines changed: 59 additions & 0 deletions

```diff
@@ -235,4 +235,63 @@ def _compute_score(self, verdict_list: t.List[float]) -> float:
         return score
 
 
+@dataclass
+class IDBasedContextRecall(SingleTurnMetric):
+    """
+    Calculates context recall by directly comparing retrieved context IDs with reference context IDs.
+    The score represents what proportion of the reference IDs were successfully retrieved.
+
+    This metric works with both string and integer IDs.
+
+    Attributes
+    ----------
+    name : str
+        Name of the metric
+    """
+
+    name: str = "id_based_context_recall"
+    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
+        default_factory=lambda: {
+            MetricType.SINGLE_TURN: {
+                "retrieved_context_ids",
+                "reference_context_ids",
+            }
+        }
+    )
+    output_type: MetricOutputType = MetricOutputType.CONTINUOUS
+
+    def init(self, run_config: RunConfig) -> None: ...
+
+    async def _single_turn_ascore(
+        self, sample: SingleTurnSample, callbacks: Callbacks
+    ) -> float:
+        retrieved_context_ids = sample.retrieved_context_ids
+        reference_context_ids = sample.reference_context_ids
+        assert retrieved_context_ids is not None, "retrieved_context_ids is empty"
+        assert reference_context_ids is not None, "reference_context_ids is empty"
+
+        # Convert all IDs to strings to ensure consistent comparison
+        retrieved_ids_set = set(str(id) for id in retrieved_context_ids)
+        reference_ids_set = set(str(id) for id in reference_context_ids)
+
+        # Calculate how many reference IDs appear in retrieved IDs
+        hits = sum(
+            1 for ref_id in reference_ids_set if str(ref_id) in retrieved_ids_set
+        )
+
+        # Calculate recall score
+        total_refs = len(reference_ids_set)
+        score = hits / total_refs if total_refs > 0 else np.nan
+
+        if np.isnan(score):
+            logger.warning(
+                "No reference context IDs provided, cannot calculate recall."
+            )
+
+        return score
+
+    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
+        return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)
+
+
 context_recall = ContextRecall()
```
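As the implementation above shows, an empty (but non-`None`) reference ID list yields `np.nan` plus a logged warning rather than a zero score. A minimal sketch of that edge case, assuming an empty list is accepted by the `SingleTurnSample` validator and that the public scoring call returns the NaN unchanged:

```python
import math

from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import IDBasedContextRecall

# With no reference IDs there is nothing to recall: the metric logs a warning
# and returns NaN instead of dividing by zero.
sample = SingleTurnSample(
    retrieved_context_ids=["doc_1", "doc_2"],
    reference_context_ids=[],
)

score = await IDBasedContextRecall().single_turn_ascore(sample)
math.isnan(score)  # True
```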
