Skip to content

Commit 9ebf8b3

Browse files
committed
Merge branch 'fix/optional-text-requirements' into fix/data-types-imports
2 parents 949eddf + 5bc2355 commit 9ebf8b3

File tree

5 files changed

+21
-19
lines changed

5 files changed

+21
-19
lines changed

dreadnode/scorers/classification.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dreadnode.meta import Config
44
from dreadnode.metric import Metric
55
from dreadnode.scorers import Scorer
6-
from dreadnode.util import clean_str, warn_at_user_stacklevel
6+
from dreadnode.util import clean_str, generate_import_error_msg, warn_at_user_stacklevel
77

88
# Global cache for pipelines
99
g_transformer_pipeline_cache: dict[str, t.Any] = {}
@@ -30,9 +30,7 @@ def zero_shot_classification(
3030
model_name: The name of the zero-shot model from Hugging Face Hub.
3131
name: Name of the scorer.
3232
"""
33-
transformers_error_msg = (
34-
"Transformers dependency is not installed. Install with: pip install transformers"
35-
)
33+
transformers_error_msg = generate_import_error_msg("transformers", "training")
3634

3735
try:
3836
from transformers import pipeline # type: ignore[import-not-found]

dreadnode/scorers/pii.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dreadnode.metric import Metric
55
from dreadnode.scorers import Scorer
66
from dreadnode.scorers.contains import contains
7-
from dreadnode.util import warn_at_user_stacklevel
7+
from dreadnode.util import generate_import_error_msg, warn_at_user_stacklevel
88

99
if t.TYPE_CHECKING:
1010
from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
@@ -103,13 +103,10 @@ def detect_pii_with_presidio(
103103
invert: Invert the score (1.0 for no PII, 0.0 for PII detected).
104104
name: Name of the scorer.
105105
"""
106-
presidio_import_error_msg = (
107-
"Presidio dependencies are not installed. "
108-
"Install with: pip install presidio-analyzer presidio-anonymizer 'spacy[en_core_web_lg]'"
109-
)
106+
presidio_import_error_msg = generate_import_error_msg("presidio-analyzer", "text")
110107

111108
try:
112-
import presidio_analyzer # type: ignore[import-not-found,unused-ignore]
109+
import presidio_analyzer # type: ignore[import-not-found]
113110
except ImportError:
114111
warn_at_user_stacklevel(presidio_import_error_msg, UserWarning)
115112

dreadnode/scorers/sentiment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dreadnode.meta import Config
77
from dreadnode.metric import Metric
88
from dreadnode.scorers.base import Scorer
9-
from dreadnode.util import warn_at_user_stacklevel
9+
from dreadnode.util import generate_import_error_msg, warn_at_user_stacklevel
1010

1111
Sentiment = t.Literal["positive", "negative", "neutral"]
1212

@@ -29,7 +29,7 @@ def sentiment(
2929
target: The desired sentiment to score against.
3030
name: Name of the scorer.
3131
"""
32-
textblob_import_error_msg = "TextBlob dependency is not installed. Install with: pip install textblob && python -m textblob.download_corpora"
32+
textblob_import_error_msg = generate_import_error_msg("textblob", "text")
3333

3434
try:
3535
from textblob import TextBlob # type: ignore[import-not-found]

dreadnode/scorers/similarity.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dreadnode.metric import Metric
66
from dreadnode.scorers.base import Scorer
77
from dreadnode.scorers.util import cosine_similarity
8-
from dreadnode.util import warn_at_user_stacklevel
8+
from dreadnode.util import generate_import_error_msg, warn_at_user_stacklevel
99

1010
if t.TYPE_CHECKING:
1111
from sentence_transformers import ( # type: ignore[import-not-found]
@@ -88,7 +88,7 @@ def similarity_with_rapidfuzz(
8888
score_cutoff: Optional score cutoff below which to return 0.0.
8989
name: Name of the scorer.
9090
"""
91-
rapidfuzz_import_error_msg = "RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz or dreadnode[text]"
91+
rapidfuzz_import_error_msg = generate_import_error_msg("rapidfuzz", "text")
9292
try:
9393
from rapidfuzz import fuzz, utils # type: ignore[import-not-found]
9494
except ImportError:
@@ -182,7 +182,7 @@ def string_distance(
182182
normalize: Normalize distances and convert to similarity scores.
183183
name: Name of the scorer.
184184
"""
185-
rapidfuzz_import_error_msg = "RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz or dreadnode[text]"
185+
rapidfuzz_import_error_msg = generate_import_error_msg("rapidfuzz", "text")
186186

187187
try:
188188
from rapidfuzz import distance # type: ignore[import-not-found]
@@ -251,7 +251,7 @@ def similarity_with_tf_idf(reference: str, *, name: str = "similarity") -> "Scor
251251
reference: The reference text (e.g., expected output).
252252
name: Name of the scorer.
253253
"""
254-
sklearn_import_error_msg = "scikit-learn dependency is not installed. Please install it with: pip install scikit-learn or dreadnode[text]"
254+
sklearn_import_error_msg = generate_import_error_msg("scikit-learn", "text")
255255

256256
try:
257257
from sklearn.feature_extraction.text import ( # type: ignore[import-not-found]
@@ -303,7 +303,7 @@ def similarity_with_sentence_transformers(
303303
model_name: The name of the sentence-transformer model to use.
304304
name: Name of the scorer.
305305
"""
306-
sentence_transformers_error_msg = "Sentence transformers dependency is not installed. Please install it with: pip install sentence-transformers or dreadnode[training]"
306+
sentence_transformers_error_msg = generate_import_error_msg("sentence-transformers", "training")
307307

308308
try:
309309
from sentence_transformers import ( # type: ignore[import-not-found]
@@ -367,7 +367,7 @@ def similarity_with_litellm(
367367
or self-hosted models.
368368
name: Name of the scorer.
369369
"""
370-
litellm_import_error_msg = "litellm dependency is not installed. Please install it with: pip install litellm or dreadnode[text]"
370+
litellm_import_error_msg = generate_import_error_msg("litellm", "text")
371371
try:
372372
import litellm
373373
except ImportError:
@@ -430,7 +430,7 @@ def bleu(
430430
weights: Weights for unigram, bigram, etc. Must sum to 1.
431431
name: Name of the scorer.
432432
"""
433-
nltk_import_error_msg = "NLTK dependency is not installed. Install with: pip install nltk && python -m nltk.downloader punkt"
433+
nltk_import_error_msg = generate_import_error_msg("nltk", "text")
434434

435435
try:
436436
import nltk # type: ignore[import-not-found]

dreadnode/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ def format_dict(data: dict[str, t.Any], max_length: int = 80) -> str:
154154
return f"{{{formatted}}}"
155155

156156

157+
def generate_import_error_msg(package_name: str, extras_name: str) -> str:
158+
return (
159+
f"Missing required package '{package_name}'. "
160+
f"Please install it with: pip install {package_name} or dreadnode[{extras_name}]"
161+
)
162+
163+
157164
# Types
158165

159166

0 commit comments

Comments
 (0)