Commit a54b1c9

Metrics migration, migrate rouge + answer relevance (#2335)
Parent: ef7892c

10 files changed: +963 -0 lines changed
Lines changed: 11 additions & 0 deletions

"""Collections of metrics using modern component architecture."""

from ragas.metrics.collections._answer_relevancy import AnswerRelevancy
from ragas.metrics.collections._rouge_score import RougeScore
from ragas.metrics.collections.base import BaseMetric

__all__ = [
    "AnswerRelevancy",  # Class-based answer relevancy
    "RougeScore",  # Class-based rouge score
    "BaseMetric",  # Base class for creating new v2 metrics
]
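
A minimal sketch of the resulting public surface, assuming a ragas build that includes this commit: both concrete metrics are importable from ragas.metrics.collections and share the BaseMetric base class, and RougeScore can be constructed without any LLM or embedding components.

# Sketch only - assumes a ragas install that contains this commit.
from ragas.metrics.collections import AnswerRelevancy, BaseMetric, RougeScore

# RougeScore needs no LLM or embeddings, so it can be constructed directly.
rouge = RougeScore(rouge_type="rougeL", mode="fmeasure")

print(isinstance(rouge, BaseMetric))            # True
print(issubclass(AnswerRelevancy, BaseMetric))  # True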
Lines changed: 135 additions & 0 deletions

"""Answer Relevancy metric v2 - Class-based implementation with modern components."""

import typing as t

import numpy as np
from pydantic import BaseModel

from ragas.metrics.collections.base import BaseMetric
from ragas.metrics.result import MetricResult
from ragas.prompt.metrics.answer_relevance import answer_relevancy_prompt

if t.TYPE_CHECKING:
    from ragas.embeddings.base import BaseRagasEmbedding
    from ragas.llms.base import InstructorBaseRagasLLM


class AnswerRelevanceOutput(BaseModel):
    """Structured output for answer relevance question generation."""

    question: str
    noncommittal: int


class AnswerRelevancy(BaseMetric):
    """
    Evaluate answer relevancy by generating questions from the response and comparing them to the original question.

    This implementation uses modern instructor LLMs with structured output and modern embeddings.
    Only modern components are supported; legacy wrappers are rejected with clear error messages.

    Usage:
        >>> import instructor
        >>> from openai import AsyncOpenAI
        >>> from ragas.llms.base import instructor_llm_factory
        >>> from ragas.embeddings.base import embedding_factory
        >>> from ragas.metrics.collections import AnswerRelevancy
        >>>
        >>> # Setup dependencies
        >>> client = AsyncOpenAI()
        >>> llm = instructor_llm_factory("openai", client=client, model="gpt-4o-mini")
        >>> embeddings = embedding_factory("openai", model="text-embedding-ada-002", client=client, interface="modern")
        >>>
        >>> # Create metric instance
        >>> metric = AnswerRelevancy(llm=llm, embeddings=embeddings, strictness=3)
        >>>
        >>> # Single evaluation
        >>> result = await metric.ascore(
        ...     user_input="What is the capital of France?",
        ...     response="Paris is the capital of France."
        ... )
        >>> print(f"Score: {result.value}")
        >>>
        >>> # Batch evaluation
        >>> results = await metric.abatch_score([
        ...     {"user_input": "Q1", "response": "A1"},
        ...     {"user_input": "Q2", "response": "A2"},
        ... ])

    Attributes:
        llm: Modern instructor-based LLM for question generation
        embeddings: Modern embeddings model with embed_text() and embed_texts() methods
        name: The metric name
        strictness: Number of questions to generate per answer (3-5 recommended)
        allowed_values: Score range (0.0 to 1.0)
    """

    # Type hints for linter (attributes are set in __init__)
    llm: "InstructorBaseRagasLLM"
    embeddings: "BaseRagasEmbedding"

    def __init__(
        self,
        llm: "InstructorBaseRagasLLM",
        embeddings: "BaseRagasEmbedding",
        name: str = "answer_relevancy",
        strictness: int = 3,
        **kwargs,
    ):
        """Initialize AnswerRelevancy metric with required components."""
        # Set attributes explicitly before calling super()
        self.llm = llm
        self.embeddings = embeddings
        self.strictness = strictness

        # Call super() for validation (without passing llm/embeddings in kwargs)
        super().__init__(name=name, **kwargs)

    async def ascore(self, user_input: str, response: str) -> MetricResult:
        """
        Calculate the answer relevancy score asynchronously.

        Components are guaranteed to be validated and non-None by the base class.

        Args:
            user_input: The original question
            response: The response to evaluate

        Returns:
            MetricResult with relevancy score (0.0-1.0)
        """
        prompt = answer_relevancy_prompt(response)

        generated_questions = []
        noncommittal_flags = []

        for _ in range(self.strictness):
            result = await self.llm.agenerate(prompt, AnswerRelevanceOutput)

            if result.question:
                generated_questions.append(result.question)
                noncommittal_flags.append(result.noncommittal)

        if not generated_questions:
            return MetricResult(value=0.0)

        all_noncommittal = np.all(noncommittal_flags)

        question_vec = np.asarray(self.embeddings.embed_text(user_input)).reshape(1, -1)
        gen_question_vec = np.asarray(
            self.embeddings.embed_texts(generated_questions)
        ).reshape(len(generated_questions), -1)

        norm = np.linalg.norm(gen_question_vec, axis=1) * np.linalg.norm(
            question_vec, axis=1
        )
        cosine_sim = (
            np.dot(gen_question_vec, question_vec.T).reshape(
                -1,
            )
            / norm
        )

        score = cosine_sim.mean() * int(not all_noncommittal)

        return MetricResult(value=float(score))
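
The final score in ascore() is plain vector arithmetic: the mean cosine similarity between the generated-question embeddings and the original-question embedding, zeroed out when every generation was flagged noncommittal. A self-contained sketch of just that step, with made-up 3-dimensional vectors standing in for real embedding output:

# Toy illustration of the scoring math; the vectors and flags are invented.
import numpy as np

question_vec = np.asarray([0.1, 0.9, 0.2]).reshape(1, -1)   # original question embedding
gen_question_vec = np.asarray([                              # embeddings of generated questions
    [0.1, 0.8, 0.3],
    [0.2, 0.9, 0.1],
    [0.0, 1.0, 0.2],
])
noncommittal_flags = [0, 0, 0]                               # 1 = evasive/noncommittal answer

norm = np.linalg.norm(gen_question_vec, axis=1) * np.linalg.norm(question_vec, axis=1)
cosine_sim = np.dot(gen_question_vec, question_vec.T).reshape(-1) / norm

all_noncommittal = np.all(noncommittal_flags)
score = cosine_sim.mean() * int(not all_noncommittal)
print(float(score))  # close to 1.0 for near-identical questions; 0.0 if every flag is set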
Lines changed: 85 additions & 0 deletions

"""Rouge Score metric v2 - Class-based implementation with automatic validation."""

import typing as t

from ragas.metrics.collections.base import BaseMetric
from ragas.metrics.result import MetricResult


class RougeScore(BaseMetric):
    """
    Calculate ROUGE score between reference and response texts.

    This implementation provides automatic validation and pure async design
    without requiring LLM or embedding components.

    Usage:
        >>> from ragas.metrics.collections import RougeScore
        >>>
        >>> # Create metric instance (no LLM/embeddings needed)
        >>> metric = RougeScore(rouge_type="rougeL", mode="fmeasure")
        >>>
        >>> # Single evaluation
        >>> result = await metric.ascore(
        ...     reference="The capital of France is Paris.",
        ...     response="Paris is the capital of France."
        ... )
        >>> print(f"Score: {result.value}")
        >>>
        >>> # Batch evaluation
        >>> results = await metric.abatch_score([
        ...     {"reference": "Text 1", "response": "Response 1"},
        ...     {"reference": "Text 2", "response": "Response 2"},
        ... ])

    Attributes:
        name: The metric name
        rouge_type: Type of ROUGE metric ("rouge1" for unigrams, "rougeL" for LCS)
        mode: Scoring mode ("fmeasure", "precision", or "recall")
        allowed_values: Score range (0.0 to 1.0)

    Note: This metric doesn't define llm or embeddings fields, so no validation is performed.
    """

    def __init__(
        self,
        name: str = "rouge_score",
        rouge_type: t.Literal["rouge1", "rougeL"] = "rougeL",
        mode: t.Literal["fmeasure", "precision", "recall"] = "fmeasure",
        **kwargs,
    ):
        """Initialize RougeScore metric."""
        super().__init__(name=name, **kwargs)
        self.rouge_type = rouge_type
        self.mode = mode

    async def ascore(
        self,
        reference: str,
        response: str,
    ) -> MetricResult:
        """
        Calculate ROUGE score asynchronously.

        Args:
            reference: The reference/ground truth text
            response: The response text to evaluate

        Returns:
            MetricResult with ROUGE score (0.0-1.0)
        """
        # Import and check dependencies
        try:
            from rouge_score import rouge_scorer
        except ImportError:
            raise ImportError(
                "rouge_score is required for ROUGE score calculation. "
                "Please install it using `pip install rouge_score`"
            )

        # Calculate ROUGE score
        scorer = rouge_scorer.RougeScorer([self.rouge_type], use_stemmer=True)
        scores = scorer.score(reference, response)
        score_value = getattr(scores[self.rouge_type], self.mode)

        return MetricResult(value=float(score_value))
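
The heavy lifting is delegated to the rouge_score package. A small sketch of the same call path outside the class, assuming rouge_score is installed, showing how rouge_type and mode select a single number from the returned Score tuple:

# Sketch of the underlying rouge_score call (assumes `pip install rouge_score`).
from rouge_score import rouge_scorer

rouge_type, mode = "rougeL", "fmeasure"            # same defaults as RougeScore
scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True)

# score(target, prediction): the reference plays the target, the response the prediction.
scores = scorer.score(
    "The capital of France is Paris.",             # reference
    "Paris is the capital of France.",             # response
)

# Each entry is a Score tuple with precision, recall and fmeasure fields;
# `mode` picks one of them, exactly as ascore() does via getattr().
print(getattr(scores[rouge_type], mode))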
Lines changed: 131 additions & 0 deletions

"""Base class for collections metrics with modern component validation."""

import asyncio
import typing as t

from ragas.embeddings.base import BaseRagasEmbedding
from ragas.llms.base import InstructorBaseRagasLLM
from ragas.metrics.base import SimpleBaseMetric
from ragas.metrics.result import MetricResult
from ragas.metrics.validators import NumericValidator


class BaseMetric(SimpleBaseMetric, NumericValidator):
    """
    Base class for collections metrics with modern component validation.

    This class inherits from SimpleBaseMetric and NumericValidator to provide:
    - All the base metric functionality (ascore, abatch_score, score, batch_score)
    - Numeric validation with configurable ranges
    - Modern LLM and embedding component validation (when defined by a subclass)
    - Rejection of legacy wrappers with helpful error messages
    - Consistent error handling and type safety

    Attributes:
        name: The metric name
        allowed_values: Score range for numeric validation (tuple of min, max)

    Note: Subclasses define llm and/or embeddings fields only if they need them.
    The base classes handle all the core metric functionality; this class just adds modern component validation.
    """

    def __init__(
        self,
        name: str = "base_metric",
        allowed_values: t.Tuple[float, float] = (0.0, 1.0),
        **kwargs,
    ):
        """Initialize the base metric with validation."""
        super().__init__(name=name, allowed_values=allowed_values)

        # Validate components only if the metric defines them
        # Check if this instance has these attributes after initialization
        if hasattr(self, "llm"):
            self._validate_llm()
        if hasattr(self, "embeddings"):
            self._validate_embeddings()

    async def ascore(self, **kwargs) -> MetricResult:
        """
        Default async scoring method - subclasses should override this.

        This base implementation just returns a placeholder result.
        Subclasses should override this method with their specific logic.

        The base class handles component validation during initialization.
        """
        return MetricResult(
            value=0.0, reason="Base metric placeholder - override ascore() in subclass"
        )

    def score(self, **kwargs) -> MetricResult:
        """
        Synchronous scoring method that wraps ascore().

        This is a convenience method for backward compatibility and sync usage.
        For better performance, prefer using ascore() directly in async contexts.

        Returns:
            MetricResult object
        """
        try:
            # Check if we're already in an async context
            asyncio.get_running_loop()
            # If we get here, there's already a running loop
            raise RuntimeError(
                "Cannot call sync score() from an async context. Use ascore() instead."
            )
        except RuntimeError as e:
            if "Use ascore() instead" in str(e):
                raise  # Re-raise our custom error
            # No running loop found, safe to use asyncio.run()
            return asyncio.run(self.ascore(**kwargs))

    def batch_score(
        self,
        inputs: t.List[t.Dict[str, t.Any]],
    ) -> t.List[MetricResult]:
        """
        Synchronous batch scoring that wraps abatch_score().

        This is a convenience method for backward compatibility and sync usage.
        For better performance, prefer using abatch_score() directly in async contexts.

        Args:
            inputs: List of input dictionaries for scoring

        Returns:
            List of MetricResult objects
        """
        try:
            # Check if we're already in an async context
            asyncio.get_running_loop()
            # If we get here, there's already a running loop
            raise RuntimeError(
                "Cannot call sync batch_score() from an async context. Use abatch_score() instead."
            )
        except RuntimeError as e:
            if "Use abatch_score() instead" in str(e):
                raise  # Re-raise our custom error
            # No running loop found, safe to use asyncio.run()
            return asyncio.run(self.abatch_score(inputs))

    def _validate_llm(self):
        """Validate that a modern InstructorLLM is provided."""
        llm = getattr(self, "llm", None)

        if not isinstance(llm, InstructorBaseRagasLLM):
            raise ValueError(
                f"Collections metrics only support modern InstructorLLM. Found: {type(llm).__name__}. "
                f"Use: instructor_llm_factory('openai', model='gpt-4o-mini', client=openai_client)"
            )

    def _validate_embeddings(self):
        """Validate that modern embeddings are provided."""
        embeddings = getattr(self, "embeddings", None)

        if not isinstance(embeddings, BaseRagasEmbedding):
            raise ValueError(
                f"Collections metrics only support modern embeddings. Found: {type(embeddings).__name__}. "
                f"Use: embedding_factory('openai', model='text-embedding-ada-002', client=openai_client, interface='modern')"
            )
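
Because component validation only runs for fields a subclass actually defines, a new v2 metric that needs neither an LLM nor embeddings stays small. A hedged sketch of such a subclass (ExactMatch below is hypothetical and not part of this commit):

# Hypothetical subclass, for illustration only: a component-free metric.
# It defines neither `llm` nor `embeddings`, so BaseMetric.__init__ runs no
# component validation; the inherited sync wrapper still works outside async code.
import asyncio

from ragas.metrics.collections.base import BaseMetric
from ragas.metrics.result import MetricResult


class ExactMatch(BaseMetric):
    """Return 1.0 when response and reference match exactly, else 0.0."""

    def __init__(self, name: str = "exact_match", **kwargs):
        super().__init__(name=name, **kwargs)

    async def ascore(self, reference: str, response: str) -> MetricResult:
        return MetricResult(value=float(reference.strip() == response.strip()))


metric = ExactMatch()
result = asyncio.run(metric.ascore(reference="Paris", response="Paris"))
print(result.value)                                              # 1.0
print(metric.score(reference="Paris", response="paris").value)   # 0.0, via the sync wrapper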
