Skip to content

Commit c9db887

Browse files
committed
fix: text extras import errors
1 parent 5d99ab3 commit c9db887

File tree

9 files changed

+94
-85
lines changed

9 files changed

+94
-85
lines changed

dreadnode/scorers/classification.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import typing as t
22

3-
from transformers import pipeline
4-
53
from dreadnode.meta import Config
64
from dreadnode.metric import Metric
75
from dreadnode.scorers import Scorer
@@ -37,7 +35,7 @@ def zero_shot_classification(
3735
)
3836

3937
try:
40-
pipeline("zero-shot-classification", model=model_name)
38+
from transformers import pipeline # type: ignore[import-not-found]
4139
except ImportError:
4240
warn_at_user_stacklevel(transformers_error_msg, UserWarning)
4341

dreadnode/scorers/pii.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import re
22
import typing as t
33

4-
from presidio_analyzer import AnalyzerEngine
5-
from presidio_analyzer.nlp_engine import NlpEngineProvider
6-
74
from dreadnode.metric import Metric
85
from dreadnode.scorers import Scorer
96
from dreadnode.scorers.contains import contains
107
from dreadnode.util import warn_at_user_stacklevel
118

129
if t.TYPE_CHECKING:
10+
from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
11+
1312
from dreadnode.types import JsonDict
1413

1514

@@ -66,6 +65,9 @@ def _get_presidio_analyzer() -> "AnalyzerEngine":
6665
"""Lazily initializes and returns a singleton Presidio AnalyzerEngine instance."""
6766
global g_analyzer_engine # noqa: PLW0603
6867

68+
from presidio_analyzer import AnalyzerEngine # type: ignore[import-not-found]
69+
from presidio_analyzer.nlp_engine import NlpEngineProvider # type: ignore[import-not-found]
70+
6971
if g_analyzer_engine is None:
7072
provider = NlpEngineProvider(
7173
nlp_configuration={
@@ -107,8 +109,8 @@ def detect_pii_with_presidio(
107109
)
108110

109111
try:
110-
_get_presidio_analyzer()
111-
except (ImportError, OSError):
112+
import presidio_analyzer # type: ignore[import-not-found,unused-ignore]
113+
except ImportError:
112114
warn_at_user_stacklevel(presidio_import_error_msg, UserWarning)
113115

114116
def disabled_evaluate(_: t.Any) -> Metric:

dreadnode/scorers/readability.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import typing as t
22

3-
import textstat # type: ignore[import-untyped]
4-
53
from dreadnode.metric import Metric
64
from dreadnode.scorers.base import Scorer
75
from dreadnode.util import warn_at_user_stacklevel
@@ -29,8 +27,8 @@ def readability(
2927
)
3028

3129
try:
32-
textstat.flesch_kincaid_grade("test")
33-
except (ImportError, AttributeError):
30+
import textstat # type: ignore[import-not-found]
31+
except ImportError:
3432
warn_at_user_stacklevel(textstat_import_error_msg, UserWarning)
3533

3634
def disabled_evaluate(_: t.Any) -> Metric:

dreadnode/scorers/sentiment.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import typing as t
33

44
import httpx
5-
from textblob import TextBlob # type: ignore[import-untyped]
65

76
from dreadnode.meta import Config
87
from dreadnode.metric import Metric
@@ -33,8 +32,8 @@ def sentiment(
3332
textblob_import_error_msg = "TextBlob dependency is not installed. Install with: pip install textblob && python -m textblob.download_corpora"
3433

3534
try:
36-
TextBlob("test").sentiment # noqa: B018
37-
except (ImportError, AttributeError):
35+
from textblob import TextBlob # type: ignore[import-not-found]
36+
except ImportError:
3837
warn_at_user_stacklevel(textblob_import_error_msg, UserWarning)
3938

4039
def disabled_evaluate(_: t.Any) -> Metric:

dreadnode/scorers/similarity.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,17 @@
11
import typing as t
22
from difflib import SequenceMatcher
33

4-
import litellm
5-
import nltk # type: ignore[import-untyped]
6-
from nltk.tokenize import word_tokenize # type: ignore[import-untyped]
7-
from nltk.translate.bleu_score import sentence_bleu # type: ignore[import-untyped]
8-
from rapidfuzz import distance, fuzz, utils
9-
from sentence_transformers import SentenceTransformer, util
10-
from sklearn.feature_extraction.text import TfidfVectorizer # type: ignore[import-untyped]
11-
from sklearn.metrics.pairwise import ( # type: ignore # noqa: PGH003
12-
cosine_similarity as sklearn_cosine_similarity,
13-
)
14-
154
from dreadnode.meta import Config
165
from dreadnode.metric import Metric
176
from dreadnode.scorers.base import Scorer
187
from dreadnode.scorers.util import cosine_similarity
198
from dreadnode.util import warn_at_user_stacklevel
209

10+
if t.TYPE_CHECKING:
11+
from sentence_transformers import ( # type: ignore[import-not-found]
12+
SentenceTransformer,
13+
)
14+
2115

2216
def similarity(
2317
reference: str,
@@ -94,12 +88,9 @@ def similarity_with_rapidfuzz(
9488
score_cutoff: Optional score cutoff below which to return 0.0.
9589
name: Name of the scorer.
9690
"""
97-
rapidfuzz_import_error_msg = (
98-
"RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz"
99-
)
100-
91+
rapidfuzz_import_error_msg = "RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz or dreadnode[text]"
10192
try:
102-
fuzz.ratio("test", "test")
93+
from rapidfuzz import fuzz, utils # type: ignore[import-not-found]
10394
except ImportError:
10495
warn_at_user_stacklevel(rapidfuzz_import_error_msg, UserWarning)
10596

@@ -191,11 +182,11 @@ def string_distance(
191182
normalize: Normalize distances and convert to similarity scores.
192183
name: Name of the scorer.
193184
"""
194-
rapidfuzz_import_error_msg = (
195-
"RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz"
196-
)
185+
rapidfuzz_import_error_msg = "RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz or dreadnode[text]"
197186

198187
try:
188+
from rapidfuzz import distance # type: ignore[import-not-found]
189+
199190
distance.Levenshtein.distance("test", "test")
200191
except ImportError:
201192
warn_at_user_stacklevel(rapidfuzz_import_error_msg, UserWarning)
@@ -260,12 +251,15 @@ def similarity_with_tf_idf(reference: str, *, name: str = "similarity") -> "Scor
260251
reference: The reference text (e.g., expected output).
261252
name: Name of the scorer.
262253
"""
263-
sklearn_import_error_msg = (
264-
"scikit-learn dependency is not installed. Please install it with: pip install scikit-learn"
265-
)
254+
sklearn_import_error_msg = "scikit-learn dependency is not installed. Please install it with: pip install scikit-learn or dreadnode[text]"
266255

267256
try:
268-
TfidfVectorizer()
257+
from sklearn.feature_extraction.text import ( # type: ignore[import-not-found]
258+
TfidfVectorizer,
259+
)
260+
from sklearn.metrics.pairwise import ( # type: ignore[import-not-found]
261+
cosine_similarity as sklearn_cosine_similarity,
262+
)
269263
except ImportError:
270264
warn_at_user_stacklevel(sklearn_import_error_msg, UserWarning)
271265

@@ -275,6 +269,7 @@ def disabled_evaluate(_: t.Any) -> Metric:
275269
return Scorer(disabled_evaluate, name=name)
276270

277271
vectorizer = TfidfVectorizer(stop_words="english")
272+
a = 1
278273

279274
def evaluate(data: t.Any, *, reference: str = reference) -> Metric:
280275
candidate_text = str(data)
@@ -309,10 +304,13 @@ def similarity_with_sentence_transformers(
309304
model_name: The name of the sentence-transformer model to use.
310305
name: Name of the scorer.
311306
"""
312-
sentence_transformers_error_msg = "Sentence transformers dependency is not installed. Please install it with: pip install sentence-transformers"
307+
sentence_transformers_error_msg = "Sentence transformers dependency is not installed. Please install it with: pip install sentence-transformers or dreadnode[training]"
313308

314309
try:
315-
SentenceTransformer(model_name)
310+
from sentence_transformers import ( # type: ignore[import-not-found]
311+
SentenceTransformer,
312+
util,
313+
)
316314
except ImportError:
317315
warn_at_user_stacklevel(sentence_transformers_error_msg, UserWarning)
318316

@@ -370,6 +368,16 @@ def similarity_with_litellm(
370368
or self-hosted models.
371369
name: Name of the scorer.
372370
"""
371+
litellm_import_error_msg = "litellm dependency is not installed. Please install it with: pip install litellm or dreadnode[text]"
372+
try:
373+
import litellm
374+
except ImportError:
375+
warn_at_user_stacklevel(litellm_import_error_msg, UserWarning)
376+
377+
def disabled_evaluate(_: t.Any) -> Metric:
378+
return Metric(value=0.0, attributes={"error": litellm_import_error_msg})
379+
380+
return Scorer(disabled_evaluate, name=name)
373381

374382
async def evaluate(
375383
data: t.Any,
@@ -426,11 +434,16 @@ def bleu(
426434
nltk_import_error_msg = "NLTK dependency is not installed. Install with: pip install nltk && python -m nltk.downloader punkt"
427435

428436
try:
429-
# Check for the 'punkt' tokenizer data
437+
import nltk # type: ignore[import-not-found]
438+
from nltk.tokenize import ( # type: ignore[import-not-found]
439+
word_tokenize,
440+
)
441+
from nltk.translate.bleu_score import ( # type: ignore[import-not-found]
442+
sentence_bleu,
443+
)
444+
430445
try:
431446
nltk.data.find("tokenizers/punkt")
432-
word_tokenize("test")
433-
sentence_bleu([["test"]], ["test"])
434447
except LookupError as e:
435448
nltk_import_error_msg = (
436449
"NLTK 'punkt' tokenizer not found. Please run: python -m nltk.downloader punkt"

dreadnode/transforms/ascii_art.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from art import text2art # type: ignore[import-untyped]
2-
31
from dreadnode.meta import Config
42
from dreadnode.transforms.base import Transform
53

@@ -8,8 +6,8 @@ def ascii_art(font: str = "rand", *, name: str = "ascii_art") -> Transform[str,
86
"""Converts text into ASCII art using the 'art' library."""
97

108
try:
11-
text2art("test") # Test if art is working
12-
except (ImportError, AttributeError):
9+
from art import text2art # type: ignore[import-not-found]
10+
except ImportError:
1311
raise ImportError(
1412
"ASCII art dependency is not installed. Install with: pip install art"
1513
) from ImportError("art library not available")

dreadnode/transforms/perturbation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import typing as t
44
import unicodedata
55

6-
from confusables import confusable_characters # type: ignore[import-untyped]
7-
86
from dreadnode.meta import Config
97
from dreadnode.transforms.base import Transform
108

@@ -226,8 +224,10 @@ def unicode_confusable(
226224
"""
227225

228226
try:
229-
confusable_characters("a")
230-
except (ImportError, AttributeError):
227+
from confusables import ( # type: ignore[import-not-found]
228+
confusable_characters,
229+
)
230+
except ImportError:
231231
raise ImportError(
232232
"Confusables dependency is not installed. Install with: pip install confusables"
233233
) from ImportError("confusables library not available")

0 commit comments

Comments
 (0)