|
1 | 1 | import typing as t |
2 | 2 | from difflib import SequenceMatcher |
3 | 3 |
|
4 | | -import litellm |
5 | | -import nltk # type: ignore[import-untyped] |
6 | | -from nltk.tokenize import word_tokenize # type: ignore[import-untyped] |
7 | | -from nltk.translate.bleu_score import sentence_bleu # type: ignore[import-untyped] |
8 | | -from rapidfuzz import distance, fuzz, utils |
9 | | -from sentence_transformers import SentenceTransformer, util |
10 | | -from sklearn.feature_extraction.text import TfidfVectorizer # type: ignore[import-untyped] |
11 | | -from sklearn.metrics.pairwise import ( # type: ignore # noqa: PGH003 |
12 | | - cosine_similarity as sklearn_cosine_similarity, |
13 | | -) |
14 | | - |
15 | 4 | from dreadnode.meta import Config |
16 | 5 | from dreadnode.metric import Metric |
17 | 6 | from dreadnode.scorers.base import Scorer |
18 | 7 | from dreadnode.scorers.util import cosine_similarity |
19 | | -from dreadnode.util import warn_at_user_stacklevel |
| 8 | +from dreadnode.util import generate_import_error_msg, warn_at_user_stacklevel |
| 9 | + |
| 10 | +if t.TYPE_CHECKING: |
| 11 | + from sentence_transformers import ( # type: ignore[import-not-found] |
| 12 | + SentenceTransformer, |
| 13 | + ) |
20 | 14 |
|
21 | 15 |
|
22 | 16 | def similarity( |
@@ -94,12 +88,9 @@ def similarity_with_rapidfuzz( |
94 | 88 | score_cutoff: Optional score cutoff below which to return 0.0. |
95 | 89 | name: Name of the scorer. |
96 | 90 | """ |
97 | | - rapidfuzz_import_error_msg = ( |
98 | | - "RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz" |
99 | | - ) |
100 | | - |
| 91 | + rapidfuzz_import_error_msg = generate_import_error_msg("rapidfuzz", "text") |
101 | 92 | try: |
102 | | - fuzz.ratio("test", "test") |
| 93 | + from rapidfuzz import fuzz, utils # type: ignore[import-not-found] |
103 | 94 | except ImportError: |
104 | 95 | warn_at_user_stacklevel(rapidfuzz_import_error_msg, UserWarning) |
105 | 96 |
|
@@ -191,11 +182,11 @@ def string_distance( |
191 | 182 | normalize: Normalize distances and convert to similarity scores. |
192 | 183 | name: Name of the scorer. |
193 | 184 | """ |
194 | | - rapidfuzz_import_error_msg = ( |
195 | | - "RapidFuzz dependency is not installed. Please install it with: pip install rapidfuzz" |
196 | | - ) |
| 185 | + rapidfuzz_import_error_msg = generate_import_error_msg("rapidfuzz", "text") |
197 | 186 |
|
198 | 187 | try: |
| 188 | + from rapidfuzz import distance # type: ignore[import-not-found] |
| 189 | + |
199 | 190 | distance.Levenshtein.distance("test", "test") |
200 | 191 | except ImportError: |
201 | 192 | warn_at_user_stacklevel(rapidfuzz_import_error_msg, UserWarning) |
@@ -260,12 +251,15 @@ def similarity_with_tf_idf(reference: str, *, name: str = "similarity") -> "Scor |
260 | 251 | reference: The reference text (e.g., expected output). |
261 | 252 | name: Name of the scorer. |
262 | 253 | """ |
263 | | - sklearn_import_error_msg = ( |
264 | | - "scikit-learn dependency is not installed. Please install it with: pip install scikit-learn" |
265 | | - ) |
| 254 | + sklearn_import_error_msg = generate_import_error_msg("scikit-learn", "text") |
266 | 255 |
|
267 | 256 | try: |
268 | | - TfidfVectorizer() |
| 257 | + from sklearn.feature_extraction.text import ( # type: ignore[import-not-found] |
| 258 | + TfidfVectorizer, |
| 259 | + ) |
| 260 | + from sklearn.metrics.pairwise import ( # type: ignore[import-not-found] |
| 261 | + cosine_similarity as sklearn_cosine_similarity, |
| 262 | + ) |
269 | 263 | except ImportError: |
270 | 264 | warn_at_user_stacklevel(sklearn_import_error_msg, UserWarning) |
271 | 265 |
|
@@ -309,10 +303,13 @@ def similarity_with_sentence_transformers( |
309 | 303 | model_name: The name of the sentence-transformer model to use. |
310 | 304 | name: Name of the scorer. |
311 | 305 | """ |
312 | | - sentence_transformers_error_msg = "Sentence transformers dependency is not installed. Please install it with: pip install sentence-transformers" |
| 306 | + sentence_transformers_error_msg = generate_import_error_msg("sentence-transformers", "training") |
313 | 307 |
|
314 | 308 | try: |
315 | | - SentenceTransformer(model_name) |
| 309 | + from sentence_transformers import ( # type: ignore[import-not-found] |
| 310 | + SentenceTransformer, |
| 311 | + util, |
| 312 | + ) |
316 | 313 | except ImportError: |
317 | 314 | warn_at_user_stacklevel(sentence_transformers_error_msg, UserWarning) |
318 | 315 |
|
@@ -370,6 +367,16 @@ def similarity_with_litellm( |
370 | 367 | or self-hosted models. |
371 | 368 | name: Name of the scorer. |
372 | 369 | """ |
| 370 | + litellm_import_error_msg = generate_import_error_msg("litellm", "text") |
| 371 | + try: |
| 372 | + import litellm |
| 373 | + except ImportError: |
| 374 | + warn_at_user_stacklevel(litellm_import_error_msg, UserWarning) |
| 375 | + |
| 376 | + def disabled_evaluate(_: t.Any) -> Metric: |
| 377 | + return Metric(value=0.0, attributes={"error": litellm_import_error_msg}) |
| 378 | + |
| 379 | + return Scorer(disabled_evaluate, name=name) |
373 | 380 |
|
374 | 381 | async def evaluate( |
375 | 382 | data: t.Any, |
@@ -423,14 +430,19 @@ def bleu( |
423 | 430 | weights: Weights for unigram, bigram, etc. Must sum to 1. |
424 | 431 | name: Name of the scorer. |
425 | 432 | """ |
426 | | - nltk_import_error_msg = "NLTK dependency is not installed. Install with: pip install nltk && python -m nltk.downloader punkt" |
| 433 | + nltk_import_error_msg = generate_import_error_msg("nltk", "text") |
427 | 434 |
|
428 | 435 | try: |
429 | | - # Check for the 'punkt' tokenizer data |
| 436 | + import nltk # type: ignore[import-not-found] |
| 437 | + from nltk.tokenize import ( # type: ignore[import-not-found] |
| 438 | + word_tokenize, |
| 439 | + ) |
| 440 | + from nltk.translate.bleu_score import ( # type: ignore[import-not-found] |
| 441 | + sentence_bleu, |
| 442 | + ) |
| 443 | + |
430 | 444 | try: |
431 | 445 | nltk.data.find("tokenizers/punkt") |
432 | | - word_tokenize("test") |
433 | | - sentence_bleu([["test"]], ["test"]) |
434 | 446 | except LookupError as e: |
435 | 447 | nltk_import_error_msg = ( |
436 | 448 | "NLTK 'punkt' tokenizer not found. Please run: python -m nltk.downloader punkt" |
|
0 commit comments