Skip to content

Commit e86105b

Browse files
authored
refactor: make embeddings optional in AnswerCorrectness when using pure factuality mode (#2414)
## Issue Link / Problem Description <!-- Link to related issue or describe the problem this PR solves --> - Fixes #2408
1 parent cffb1e9 commit e86105b

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

src/ragas/metrics/collections/_answer_correctness.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ class AnswerCorrectness(BaseMetric):
8888

8989
# Type hints for linter (attributes are set in __init__)
9090
llm: "InstructorBaseRagasLLM"
91-
embeddings: "BaseRagasEmbedding"
91+
embeddings: t.Optional["BaseRagasEmbedding"]
9292

9393
def __init__(
9494
self,
9595
llm: "InstructorBaseRagasLLM",
96-
embeddings: "BaseRagasEmbedding",
96+
embeddings: t.Optional["BaseRagasEmbedding"] = None,
9797
name: str = "answer_correctness",
9898
weights: List[float] = [0.75, 0.25],
9999
beta: float = 1.0,
@@ -104,9 +104,21 @@ def __init__(
104104
105105
Args:
106106
llm: Modern instructor-based LLM for statement generation and classification
107-
embeddings: Modern embeddings model for similarity calculation
107+
embeddings: Modern embeddings model for similarity calculation. Optional if similarity
108+
weight is 0 (pure factuality evaluation). Required if similarity weight > 0.
109+
name: The metric name
108110
weights: [factuality_weight, similarity_weight]. Must sum to > 0.
109111
beta: F-beta score parameter. β>1 favors recall, β<1 favors precision.
112+
113+
Raises:
114+
ValueError: If weights are invalid or embeddings are missing when needed for similarity scoring.
115+
116+
Examples:
117+
Pure factuality (no embeddings needed):
118+
>>> metric = AnswerCorrectness(llm=llm, weights=[1.0, 0.0])
119+
120+
Factuality + Similarity (embeddings required):
121+
>>> metric = AnswerCorrectness(llm=llm, embeddings=embeddings, weights=[0.75, 0.25])
110122
"""
111123
# Set attributes explicitly before calling super()
112124
self.llm = llm
@@ -124,6 +136,14 @@ def __init__(
124136
if not all([w >= 0 for w in weights]):
125137
raise ValueError("Weights must be non-negative")
126138

139+
# Validate embeddings availability when similarity weight > 0
140+
if weights[1] > 0 and embeddings is None:
141+
raise ValueError(
142+
"Embeddings are required for semantic similarity scoring. "
143+
"Either provide embeddings or set similarity weight to 0 (weights=[1.0, 0.0]) "
144+
"for pure factuality-only evaluation."
145+
)
146+
127147
# Validate beta
128148
if not isinstance(beta, float):
129149
raise ValueError(
@@ -133,6 +153,17 @@ def __init__(
133153
# Call super() for validation (without passing llm/embeddings in kwargs)
134154
super().__init__(name=name, **kwargs)
135155

156+
def _validate_embeddings(self) -> None:
157+
"""Override base validation to allow optional embeddings.
158+
159+
AnswerCorrectness metric allows embeddings to be None when using
160+
pure factuality evaluation (weights=[1.0, 0.0]). The main validation
161+
of embeddings availability happens in __init__ based on weights.
162+
"""
163+
# Only validate embeddings if similarity weight > 0
164+
# (validation logic already in __init__)
165+
pass
166+
136167
async def ascore(
137168
self, user_input: str, response: str, reference: str
138169
) -> MetricResult:

tests/e2e/metrics_migration/test_answer_correctness_migration.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,10 @@ def test_answer_correctness_parameter_validation(self):
339339
"""Test that v2 implementation properly validates parameters."""
340340
from unittest.mock import Mock
341341

342-
mock_llm = Mock()
342+
from ragas.llms.base import InstructorBaseRagasLLM
343+
344+
# Create proper mocks that inherit from the required base class
345+
mock_llm = Mock(spec=InstructorBaseRagasLLM)
343346
mock_embeddings = Mock()
344347

345348
# Test invalid weights
@@ -360,6 +363,15 @@ def test_answer_correctness_parameter_validation(self):
360363
with pytest.raises(ValueError, match="Beta must be a float"):
361364
AnswerCorrectness(llm=mock_llm, embeddings=mock_embeddings, beta="invalid") # type: ignore
362365

366+
# Test optional embeddings - should work with pure factuality (weight=0)
367+
metric = AnswerCorrectness(llm=mock_llm, weights=[1.0, 0.0])
368+
assert metric.embeddings is None
369+
print("✅ Optional embeddings working for pure factuality!")
370+
371+
# Test embeddings required when similarity weight > 0
372+
with pytest.raises(ValueError, match="Embeddings are required"):
373+
AnswerCorrectness(llm=mock_llm, embeddings=None, weights=[0.75, 0.25])
374+
363375
print("✅ Parameter validation working correctly!")
364376

365377
def test_answer_correctness_migration_requirements_documented(self):

0 commit comments

Comments
 (0)