11import math
22import datetime
3- from deepdiff . base import BaseProtocol
3+ from typing import TYPE_CHECKING , Callable , Protocol , Any
44from deepdiff .deephash import DeepHash
55from deepdiff .helper import (
66 DELTA_VIEW , numbers , strings , add_to_frozen_set , not_found , only_numbers , np , np_float64 , time_to_seconds ,
77 cartesian_product_numpy , np_ndarray , np_array_factory , get_homogeneous_numpy_compatible_type_of_seq , dict_ ,
88 CannotCompare )
99from collections .abc import Mapping , Iterable
1010
if TYPE_CHECKING:
    # Everything in this guard exists for static analysis only. Keeping the
    # protocol class inside the guard (next to the import of its base) means
    # importing this module at runtime never needs ``deepdiff.diff`` and never
    # evaluates the annotations below.
    from deepdiff.diff import DeepDiffProtocol

    class DistanceProtocol(DeepDiffProtocol, Protocol):
        """Structural type used to annotate ``self`` in the ``DistanceMixin`` methods.

        It extends ``DeepDiffProtocol`` with the attributes and methods that the
        distance helpers in this module access on ``self``. It is referenced
        only through the string annotation ``"DistanceProtocol"``, so it does
        not need to exist at runtime.
        """

        # Hash store handed to ``DeepHash`` by the distance helpers.
        hashes: dict
        deephash_parameters: dict
        iterable_compare_func: Callable | None
        math_epsilon: float
        cutoff_distance_for_pairs: float

        # NOTE(review): names with two leading underscores are name-mangled per
        # class (``_DistanceProtocol__...`` here vs ``_DistanceMixin__...`` in
        # the mixin) -- confirm the type checker in use matches these
        # declarations against the mixin's private methods.
        def __get_item_rough_length(self, item, parent: str = "root") -> float:
            ...

        def _to_delta_dict(
            self,
            directed: bool = True,
            report_repetition_required: bool = True,
            always_include_values: bool = False,
        ) -> dict:
            ...

        def __calculate_item_deephash(self, item: Any) -> None:
            ...
1534
1635
17- class DistanceMixin (BaseProtocol ):
1836
19- def _get_rough_distance (self ):
37+ DISTANCE_CALCS_NEEDS_CACHE = "Distance calculation can not happen once the cache is purged. Try with _cache='keep'"
38+
39+
40+ class DistanceMixin :
41+
42+ def _get_rough_distance (self : "DistanceProtocol" ):
2043 """
2144 Gives a numeric value for the distance of t1 and t2 based on how many operations are needed to convert
2245 one to the other.
@@ -51,7 +74,7 @@ def _get_rough_distance(self):
5174
5275 return diff_length / (t1_len + t2_len )
5376
54- def __get_item_rough_length (self , item , parent = 'root' ):
77+ def __get_item_rough_length (self : "DistanceProtocol" , item , parent = 'root' ):
5578 """
5679 Get the rough length of an item.
5780 It is used as a part of calculating the rough distance between objects.
@@ -69,7 +92,7 @@ def __get_item_rough_length(self, item, parent='root'):
6992 length = DeepHash .get_key (self .hashes , key = item , default = None , extract_index = 1 )
7093 return length
7194
72- def __calculate_item_deephash (self , item ) :
95+ def __calculate_item_deephash (self : "DistanceProtocol" , item : Any ) -> None :
7396 DeepHash (
7497 item ,
7598 hashes = self .hashes ,
@@ -79,8 +102,7 @@ def __calculate_item_deephash(self, item):
79102 )
80103
81104 def _precalculate_distance_by_custom_compare_func (
82- self , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
83-
105+ self : "DistanceProtocol" , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
84106 pre_calced_distances = dict_ ()
85107 for added_hash in hashes_added :
86108 for removed_hash in hashes_removed :
@@ -99,7 +121,7 @@ def _precalculate_distance_by_custom_compare_func(
99121 return pre_calced_distances
100122
101123 def _precalculate_numpy_arrays_distance (
102- self , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
124+ self : "DistanceProtocol" , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
103125
104126 # We only want to deal with 1D arrays.
105127 if isinstance (t2_hashtable [next (iter (hashes_added ))].item , (np_ndarray , list )):
@@ -203,7 +225,7 @@ def _get_numbers_distance(num1, num2, max_=1, use_log_scale=False, log_scale_sim
203225 return 0
204226 if use_log_scale :
205227 distance = logarithmic_distance (num1 , num2 )
206- if distance < logarithmic_distance :
228+ if distance < 0 :
207229 return 0
208230 return distance
209231 if not isinstance (num1 , float ):
@@ -246,7 +268,7 @@ def numpy_apply_log_keep_sign(array, offset=MATH_LOG_OFFSET):
246268 return signed_log_values
247269
248270
249- def logarithmic_similarity (a : numbers , b : numbers , threshold : float = 0.1 ):
271+ def logarithmic_similarity (a : numbers , b : numbers , threshold : float = 0.1 ) -> float :
250272 """
251273 A threshold of 0.1 translates to about 10.5% difference.
252274 A threshold of 0.5 translates to about 65% difference.
@@ -255,7 +277,7 @@ def logarithmic_similarity(a: numbers, b: numbers, threshold: float=0.1):
255277 return logarithmic_distance (a , b ) < threshold
256278
257279
258- def logarithmic_distance (a : numbers , b : numbers ):
280+ def logarithmic_distance (a : numbers , b : numbers ) -> float :
259281 # Apply logarithm to the absolute values and consider the sign
260282 a = float (a )
261283 b = float (b )
0 commit comments