Skip to content

Commit 8d42d4a

Browse files
authored
Merge pull request #160 from dreadnode/fix/optional-text-requirements
fix: text extras import errors
2 parents 5d99ab3 + 5bc2355 commit 8d42d4a

File tree

10 files changed: 108 additions and 98 deletions

dreadnode/scorers/classification.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import typing as t
22

3-
from transformers import pipeline
4-
53
from dreadnode.meta import Config
64
from dreadnode.metric import Metric
75
from dreadnode.scorers import Scorer
8-
from dreadnode.util import clean_str, warn_at_user_stacklevel
6+
from dreadnode.util import clean_str, generate_import_error_msg, warn_at_user_stacklevel
97

108
# Global cache for pipelines
119
g_transformer_pipeline_cache: dict[str, t.Any] = {}
@@ -32,12 +30,10 @@ def zero_shot_classification(
3230
model_name: The name of the zero-shot model from Hugging Face Hub.
3331
name: Name of the scorer.
3432
"""
35-
transformers_error_msg = (
36-
"Transformers dependency is not installed. Install with: pip install transformers"
37-
)
33+
transformers_error_msg = generate_import_error_msg("transformers", "training")
3834

3935
try:
40-
pipeline("zero-shot-classification", model=model_name)
36+
from transformers import pipeline # type: ignore[import-not-found]
4137
except ImportError:
4238
warn_at_user_stacklevel(transformers_error_msg, UserWarning)
4339

dreadnode/scorers/pii.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import re
22
import typing as t
33

4-
from presidio_analyzer import AnalyzerEngine
5-
from presidio_analyzer.nlp_engine import NlpEngineProvider
6-
74
from dreadnode.metric import Metric
85
from dreadnode.scorers import Scorer
96
from dreadnode.scorers.contains import contains
10-
from dreadnode.util import warn_at_user_stacklevel
7+
from dreadnode.util import generate_import_error_msg, warn_at_user_stacklevel
118

129
if t.TYPE_CHECKING:
10+
from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
11+
1312
from dreadnode.types import JsonDict
1413

1514

@@ -66,6 +65,9 @@ def _get_presidio_analyzer() -> "AnalyzerEngine":
6665
"""Lazily initializes and returns a singleton Presidio AnalyzerEngine instance."""
6766
global g_analyzer_engine # noqa: PLW0603
6867

68+
from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
69+
from presidio_analyzer.nlp_engine import NlpEngineProvider # type: ignore[import-not-found]
70+
6971
if g_analyzer_engine is None:
7072
provider = NlpEngineProvider(
7173
nlp_configuration={
@@ -101,14 +103,11 @@ def detect_pii_with_presidio(
101103
invert: Invert the score (1.0 for no PII, 0.0 for PII detected).
102104
name: Name of the scorer.
103105
"""
104-
presidio_import_error_msg = (
105-
"Presidio dependencies are not installed. "
106-
"Install with: pip install presidio-analyzer presidio-anonymizer 'spacy[en_core_web_lg]'"
107-
)
106+
presidio_import_error_msg = generate_import_error_msg("presidio-analyzer", "text")
108107

109108
try:
110-
_get_presidio_analyzer()
111-
except (ImportError, OSError):
109+
import presidio_analyzer # type: ignore[import-not-found]
110+
except ImportError:
112111
warn_at_user_stacklevel(presidio_import_error_msg, UserWarning)
113112

114113
def disabled_evaluate(_: t.Any) -> Metric:

dreadnode/scorers/readability.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import typing as t
22

3-
import textstat # type: ignore[import-untyped]
4-
53
from dreadnode.metric import Metric
64
from dreadnode.scorers.base import Scorer
75
from dreadnode.util import warn_at_user_stacklevel
@@ -29,8 +27,8 @@ def readability(
2927
)
3028

3129
try:
32-
textstat.flesch_kincaid_grade("test")
33-
except (ImportError, AttributeError):
30+
import textstat # type: ignore[import-not-found]
31+
except ImportError:
3432
warn_at_user_stacklevel(textstat_import_error_msg, UserWarning)
3533

3634
def disabled_evaluate(_: t.Any) -> Metric:

dreadnode/scorers/sentiment.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import typing as t
33

44
import httpx
5-
from textblob import TextBlob # type: ignore[import-untyped]
65

76
from dreadnode.meta import Config
87
from dreadnode.metric import Metric
98
from dreadnode.scorers.base import Scorer
10-
from dreadnode.util import warn_at_user_stacklevel
9+
from dreadnode.util import generate_import_error_msg, warn_at_user_stacklevel
1110

1211
Sentiment = t.Literal["positive", "negative", "neutral"]
1312

@@ -30,11 +29,11 @@ def sentiment(
3029
target: The desired sentiment to score against.
3130
name: Name of the scorer.
3231
"""
33-
textblob_import_error_msg = "TextBlob dependency is not installed. Install with: pip install textblob && python -m textblob.download_corpora"
32+
textblob_import_error_msg = generate_import_error_msg("textblob", "text")
3433

3534
try:
36-
TextBlob("test").sentiment # noqa: B018
37-
except (ImportError, AttributeError):
35+
from textblob import TextBlob # type: ignore[import-not-found]
36+
except ImportError:
3837
warn_at_user_stacklevel(textblob_import_error_msg, UserWarning)
3938

4039
def disabled_evaluate(_: t.Any) -> Metric:

dreadnode/scorers/similarity.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
import typing as t
22
from difflib import SequenceMatcher
33

4-
import litellm
5-
import nltk # type: ignore[import-untyped]
6-
from nltk.tokenize import word_tokenize # type: ignore[import-untyped]
7-
from nltk.translate.bleu_score import sentence_bleu # type: ignore[import-untyped]
8-
from rapidfuzz import distance, fuzz, utils
9-
from sentence_transformers import SentenceTransformer, util
10-
from sklearn.feature_extraction.text import TfidfVectorizer # type: ignore[import-untyped]
11-
from sklearn.metrics.pairwise import ( # type: ignore # noqa: PGH003
12-
cosine_similarity as sklearn_cosine_similarity,
13-
)
14-
154
from dreadnode.meta import Config
165
from dreadnode.metric import Metric
176
from dreadnode.scorers.base import Scorer
187
from dreadnode.scorers.util import cosine_similarity
19-
from dreadnode.util import warn_at_user_stacklevel
8+
from dreadnode.util import generate_import_error_msg, warn_at_user_stacklevel
9+
10+
if t.TYPE_CHECKING:
11+
from sentence_transformers import ( # type: ignore[import-not-found]
12+
SentenceTransformer,
13+
)
2014

2115

2216
def similarity(
@@ -94,12 +88,9 @@ def similarity_with_rapidfuzz(
9488
score_cutoff: Optional score cutoff below which to return 0.0.
9589
name: Name of the scorer.
9690
"""
97-
rapidfuzz_import_error_msg = (
98-
"RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz"
99-
)
100-
91+
rapidfuzz_import_error_msg = generate_import_error_msg("rapidfuzz", "text")
10192
try:
102-
fuzz.ratio("test", "test")
93+
from rapidfuzz import fuzz, utils # type: ignore[import-not-found]
10394
except ImportError:
10495
warn_at_user_stacklevel(rapidfuzz_import_error_msg, UserWarning)
10596

@@ -191,11 +182,11 @@ def string_distance(
191182
normalize: Normalize distances and convert to similarity scores.
192183
name: Name of the scorer.
193184
"""
194-
rapidfuzz_import_error_msg = (
195-
"RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz"
196-
)
185+
rapidfuzz_import_error_msg = generate_import_error_msg("rapidfuzz", "text")
197186

198187
try:
188+
from rapidfuzz import distance # type: ignore[import-not-found]
189+
199190
distance.Levenshtein.distance("test", "test")
200191
except ImportError:
201192
warn_at_user_stacklevel(rapidfuzz_import_error_msg, UserWarning)
@@ -260,12 +251,15 @@ def similarity_with_tf_idf(reference: str, *, name: str = "similarity") -> "Scor
260251
reference: The reference text (e.g., expected output).
261252
name: Name of the scorer.
262253
"""
263-
sklearn_import_error_msg = (
264-
"scikit-learn dependency is not installed. Please install it with: pip install scikit-learn"
265-
)
254+
sklearn_import_error_msg = generate_import_error_msg("scikit-learn", "text")
266255

267256
try:
268-
TfidfVectorizer()
257+
from sklearn.feature_extraction.text import ( # type: ignore[import-not-found]
258+
TfidfVectorizer,
259+
)
260+
from sklearn.metrics.pairwise import ( # type: ignore[import-not-found]
261+
cosine_similarity as sklearn_cosine_similarity,
262+
)
269263
except ImportError:
270264
warn_at_user_stacklevel(sklearn_import_error_msg, UserWarning)
271265

@@ -309,10 +303,13 @@ def similarity_with_sentence_transformers(
309303
model_name: The name of the sentence-transformer model to use.
310304
name: Name of the scorer.
311305
"""
312-
sentence_transformers_error_msg = "Sentence transformers dependency is not installed. Please install it with: pip install sentence-transformers"
306+
sentence_transformers_error_msg = generate_import_error_msg("sentence-transformers", "training")
313307

314308
try:
315-
SentenceTransformer(model_name)
309+
from sentence_transformers import ( # type: ignore[import-not-found]
310+
SentenceTransformer,
311+
util,
312+
)
316313
except ImportError:
317314
warn_at_user_stacklevel(sentence_transformers_error_msg, UserWarning)
318315

@@ -370,6 +367,16 @@ def similarity_with_litellm(
370367
or self-hosted models.
371368
name: Name of the scorer.
372369
"""
370+
litellm_import_error_msg = generate_import_error_msg("litellm", "text")
371+
try:
372+
import litellm
373+
except ImportError:
374+
warn_at_user_stacklevel(litellm_import_error_msg, UserWarning)
375+
376+
def disabled_evaluate(_: t.Any) -> Metric:
377+
return Metric(value=0.0, attributes={"error": litellm_import_error_msg})
378+
379+
return Scorer(disabled_evaluate, name=name)
373380

374381
async def evaluate(
375382
data: t.Any,
@@ -423,14 +430,19 @@ def bleu(
423430
weights: Weights for unigram, bigram, etc. Must sum to 1.
424431
name: Name of the scorer.
425432
"""
426-
nltk_import_error_msg = "NLTK dependency is not installed. Install with: pip install nltk && python -m nltk.downloader punkt"
433+
nltk_import_error_msg = generate_import_error_msg("nltk", "text")
427434

428435
try:
429-
# Check for the 'punkt' tokenizer data
436+
import nltk # type: ignore[import-not-found]
437+
from nltk.tokenize import ( # type: ignore[import-not-found]
438+
word_tokenize,
439+
)
440+
from nltk.translate.bleu_score import ( # type: ignore[import-not-found]
441+
sentence_bleu,
442+
)
443+
430444
try:
431445
nltk.data.find("tokenizers/punkt")
432-
word_tokenize("test")
433-
sentence_bleu([["test"]], ["test"])
434446
except LookupError as e:
435447
nltk_import_error_msg = (
436448
"NLTK 'punkt' tokenizer not found. Please run: python -m nltk.downloader punkt"

dreadnode/transforms/ascii_art.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from art import text2art # type: ignore[import-untyped]
2-
31
from dreadnode.meta import Config
42
from dreadnode.transforms.base import Transform
53

@@ -8,8 +6,8 @@ def ascii_art(font: str = "rand", *, name: str = "ascii_art") -> Transform[str,
86
"""Converts text into ASCII art using the 'art' library."""
97

108
try:
11-
text2art("test") # Test if art is working
12-
except (ImportError, AttributeError):
9+
from art import text2art # type: ignore[import-not-found]
10+
except ImportError:
1311
raise ImportError(
1412
"ASCII art dependency is not installed. Install with: pip install art"
1513
) from ImportError("art library not available")

dreadnode/transforms/perturbation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import typing as t
44
import unicodedata
55

6-
from confusables import confusable_characters # type: ignore[import-untyped]
7-
86
from dreadnode.meta import Config
97
from dreadnode.transforms.base import Transform
108

@@ -226,8 +224,10 @@ def unicode_confusable(
226224
"""
227225

228226
try:
229-
confusable_characters("a")
230-
except (ImportError, AttributeError):
227+
from confusables import ( # type: ignore[import-not-found]
228+
confusable_characters,
229+
)
230+
except ImportError:
231231
raise ImportError(
232232
"Confusables dependency is not installed. Install with: pip install confusables"
233233
) from ImportError("confusables library not available")

dreadnode/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ def format_dict(data: dict[str, t.Any], max_length: int = 80) -> str:
154154
return f"{{{formatted}}}"
155155

156156

157+
def generate_import_error_msg(package_name: str, extras_name: str) -> str:
158+
return (
159+
f"Missing required package '{package_name}'. "
160+
f"Please install it with: pip install {package_name} or dreadnode[{extras_name}]"
161+
)
162+
163+
157164
# Types
158165

159166

0 commit comments

Comments
 (0)