Skip to content

Commit 22ef165

Browse files
committed
implement __eq__ for HashAlgorithm instances
1 parent 6e8fe85 commit 22ef165

File tree

4 files changed

+177
-1
lines changed

4 files changed

+177
-1
lines changed

src/cryptography/hazmat/primitives/hashes.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8+
import typing
89

910
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
1011
from cryptography.utils import Buffer
@@ -36,6 +37,13 @@
3637

3738

3839
class HashAlgorithm(metaclass=abc.ABCMeta):
40+
@abc.abstractmethod
41+
def __eq__(self, other: typing.Any) -> bool:
42+
"""
43+
Implement equality checking.
44+
"""
45+
...
46+
3947
@property
4048
@abc.abstractmethod
4149
def name(self) -> str:
@@ -103,66 +111,99 @@ class SHA1(HashAlgorithm):
103111
digest_size = 20
104112
block_size = 64
105113

114+
def __eq__(self, other: typing.Any) -> bool:
115+
return isinstance(other, SHA1)
116+
106117

107118
class SHA512_224(HashAlgorithm): # noqa: N801
108119
name = "sha512-224"
109120
digest_size = 28
110121
block_size = 128
111122

123+
def __eq__(self, other: typing.Any) -> bool:
124+
return isinstance(other, SHA512_224)
125+
112126

113127
class SHA512_256(HashAlgorithm): # noqa: N801
114128
name = "sha512-256"
115129
digest_size = 32
116130
block_size = 128
117131

132+
def __eq__(self, other: typing.Any) -> bool:
133+
return isinstance(other, SHA512_256)
134+
118135

119136
class SHA224(HashAlgorithm):
120137
name = "sha224"
121138
digest_size = 28
122139
block_size = 64
123140

141+
def __eq__(self, other: typing.Any) -> bool:
142+
return isinstance(other, SHA224)
143+
124144

125145
class SHA256(HashAlgorithm):
126146
name = "sha256"
127147
digest_size = 32
128148
block_size = 64
129149

150+
def __eq__(self, other: typing.Any) -> bool:
151+
return isinstance(other, SHA256)
152+
130153

131154
class SHA384(HashAlgorithm):
132155
name = "sha384"
133156
digest_size = 48
134157
block_size = 128
135158

159+
def __eq__(self, other: typing.Any) -> bool:
160+
return isinstance(other, SHA384)
161+
136162

137163
class SHA512(HashAlgorithm):
138164
name = "sha512"
139165
digest_size = 64
140166
block_size = 128
141167

168+
def __eq__(self, other: typing.Any) -> bool:
169+
return isinstance(other, SHA512)
170+
142171

143172
class SHA3_224(HashAlgorithm): # noqa: N801
144173
name = "sha3-224"
145174
digest_size = 28
146175
block_size = None
147176

177+
def __eq__(self, other: typing.Any) -> bool:
178+
return isinstance(other, SHA3_224)
179+
148180

149181
class SHA3_256(HashAlgorithm): # noqa: N801
150182
name = "sha3-256"
151183
digest_size = 32
152184
block_size = None
153185

186+
def __eq__(self, other: typing.Any) -> bool:
187+
return isinstance(other, SHA3_256)
188+
154189

155190
class SHA3_384(HashAlgorithm): # noqa: N801
156191
name = "sha3-384"
157192
digest_size = 48
158193
block_size = None
159194

195+
def __eq__(self, other: typing.Any) -> bool:
196+
return isinstance(other, SHA3_384)
197+
160198

161199
class SHA3_512(HashAlgorithm): # noqa: N801
162200
name = "sha3-512"
163201
digest_size = 64
164202
block_size = None
165203

204+
def __eq__(self, other: typing.Any) -> bool:
205+
return isinstance(other, SHA3_512)
206+
166207

167208
class SHAKE128(HashAlgorithm, ExtendableOutputFunction):
168209
name = "shake128"
@@ -177,6 +218,12 @@ def __init__(self, digest_size: int):
177218

178219
self._digest_size = digest_size
179220

221+
def __eq__(self, other: typing.Any) -> bool:
222+
return (
223+
isinstance(other, SHAKE128)
224+
and self._digest_size == other._digest_size
225+
)
226+
180227
@property
181228
def digest_size(self) -> int:
182229
return self._digest_size
@@ -195,6 +242,12 @@ def __init__(self, digest_size: int):
195242

196243
self._digest_size = digest_size
197244

245+
def __eq__(self, other: typing.Any) -> bool:
246+
return (
247+
isinstance(other, SHAKE256)
248+
and self._digest_size == other._digest_size
249+
)
250+
198251
@property
199252
def digest_size(self) -> int:
200253
return self._digest_size
@@ -205,6 +258,9 @@ class MD5(HashAlgorithm):
205258
digest_size = 16
206259
block_size = 64
207260

261+
def __eq__(self, other: typing.Any) -> bool:
262+
return isinstance(other, MD5)
263+
208264

209265
class BLAKE2b(HashAlgorithm):
210266
name = "blake2b"
@@ -218,6 +274,12 @@ def __init__(self, digest_size: int):
218274

219275
self._digest_size = digest_size
220276

277+
def __eq__(self, other: typing.Any) -> bool:
278+
return (
279+
isinstance(other, BLAKE2b)
280+
and self._digest_size == other._digest_size
281+
)
282+
221283
@property
222284
def digest_size(self) -> int:
223285
return self._digest_size
@@ -235,6 +297,12 @@ def __init__(self, digest_size: int):
235297

236298
self._digest_size = digest_size
237299

300+
def __eq__(self, other: typing.Any) -> bool:
301+
return (
302+
isinstance(other, BLAKE2s)
303+
and self._digest_size == other._digest_size
304+
)
305+
238306
@property
239307
def digest_size(self) -> int:
240308
return self._digest_size
@@ -244,3 +312,6 @@ class SM3(HashAlgorithm):
244312
name = "sm3"
245313
digest_size = 32
246314
block_size = 64
315+
316+
def __eq__(self, other: typing.Any) -> bool:
317+
return isinstance(other, SM3)

tests/doubles.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
33
# for complete details.
44

5+
import typing
56

67
from cryptography.hazmat.primitives import hashes, serialization
78
from cryptography.hazmat.primitives.asymmetric import padding
@@ -40,6 +41,12 @@ class DummyHashAlgorithm(hashes.HashAlgorithm):
4041
def __init__(self, digest_size: int = 32) -> None:
4142
self._digest_size = digest_size
4243

44+
def __eq__(self, other: typing.Any) -> bool:
45+
return (
46+
isinstance(self, DummyHashAlgorithm)
47+
and self._digest_size == other._digest_size
48+
)
49+
4350
@property
4451
def digest_size(self) -> int:
4552
return self._digest_size

tests/hazmat/primitives/test_hashes.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from ...doubles import DummyHashAlgorithm
1414
from ...utils import raises_unsupported_algorithm
15-
from .utils import generate_base_hash_test
15+
from .utils import generate_base_hash_test, generate_eq_hash_test
1616

1717

1818
class TestHashContext:
@@ -52,6 +52,7 @@ class TestSHA1:
5252
hashes.SHA1(),
5353
digest_size=20,
5454
)
55+
test_sha1_eq = generate_eq_hash_test(hashes.SHA1())
5556

5657

5758
@pytest.mark.supported(
@@ -63,6 +64,7 @@ class TestSHA224:
6364
hashes.SHA224(),
6465
digest_size=28,
6566
)
67+
test_sha224_eq = generate_eq_hash_test(hashes.SHA224())
6668

6769

6870
@pytest.mark.supported(
@@ -74,6 +76,7 @@ class TestSHA256:
7476
hashes.SHA256(),
7577
digest_size=32,
7678
)
79+
test_sha256_eq = generate_eq_hash_test(hashes.SHA256())
7780

7881

7982
@pytest.mark.supported(
@@ -85,6 +88,7 @@ class TestSHA384:
8588
hashes.SHA384(),
8689
digest_size=48,
8790
)
91+
test_sha384_eq = generate_eq_hash_test(hashes.SHA384())
8892

8993

9094
@pytest.mark.supported(
@@ -96,6 +100,79 @@ class TestSHA512:
96100
hashes.SHA512(),
97101
digest_size=64,
98102
)
103+
test_sha512_eq = generate_eq_hash_test(hashes.SHA512())
104+
105+
106+
@pytest.mark.supported(
107+
only_if=lambda backend: backend.hash_supported(hashes.SHA512_224()),
108+
skip_message="Does not support SHA512 224",
109+
)
110+
class TestSHA512224:
111+
test_sha512_224 = generate_base_hash_test(
112+
hashes.SHA512_224(),
113+
digest_size=28,
114+
)
115+
test_sha512_224_eq = generate_eq_hash_test(hashes.SHA512_224())
116+
117+
118+
@pytest.mark.supported(
119+
only_if=lambda backend: backend.hash_supported(hashes.SHA512_256()),
120+
skip_message="Does not support SHA512 256",
121+
)
122+
class TestSHA512256:
123+
test_sha512_256 = generate_base_hash_test(
124+
hashes.SHA512_256(),
125+
digest_size=32,
126+
)
127+
test_sha512_256_eq = generate_eq_hash_test(hashes.SHA512_256())
128+
129+
130+
@pytest.mark.supported(
131+
only_if=lambda backend: backend.hash_supported(hashes.SHA3_224()),
132+
skip_message="Does not support SHA3 224",
133+
)
134+
class TestSHA3224:
135+
test_sha3_224 = generate_base_hash_test(
136+
hashes.SHA3_224(),
137+
digest_size=28,
138+
)
139+
test_sha3_224_eq = generate_eq_hash_test(hashes.SHA3_224())
140+
141+
142+
@pytest.mark.supported(
143+
only_if=lambda backend: backend.hash_supported(hashes.SHA3_256()),
144+
skip_message="Does not support SHA3 256",
145+
)
146+
class TestSHA3256:
147+
test_sha3_256 = generate_base_hash_test(
148+
hashes.SHA3_256(),
149+
digest_size=32,
150+
)
151+
test_sha3_256_eq = generate_eq_hash_test(hashes.SHA3_256())
152+
153+
154+
@pytest.mark.supported(
155+
only_if=lambda backend: backend.hash_supported(hashes.SHA3_384()),
156+
skip_message="Does not support SHA3 384",
157+
)
158+
class TestSHA3384:
159+
test_sha3_384 = generate_base_hash_test(
160+
hashes.SHA3_384(),
161+
digest_size=48,
162+
)
163+
test_sha3_384_eq = generate_eq_hash_test(hashes.SHA3_384())
164+
165+
166+
@pytest.mark.supported(
167+
only_if=lambda backend: backend.hash_supported(hashes.SHA3_512()),
168+
skip_message="Does not support SHA3 512",
169+
)
170+
class TestSHA3512:
171+
test_sha3_512 = generate_base_hash_test(
172+
hashes.SHA3_512(),
173+
digest_size=64,
174+
)
175+
test_sha3_512_eq = generate_eq_hash_test(hashes.SHA3_512())
99176

100177

101178
@pytest.mark.supported(
@@ -107,6 +184,7 @@ class TestMD5:
107184
hashes.MD5(),
108185
digest_size=16,
109186
)
187+
test_md5_eq = generate_eq_hash_test(hashes.MD5())
110188

111189

112190
@pytest.mark.supported(
@@ -120,6 +198,7 @@ class TestBLAKE2b:
120198
hashes.BLAKE2b(digest_size=64),
121199
digest_size=64,
122200
)
201+
test_blake2b_eq = generate_eq_hash_test(hashes.BLAKE2b(digest_size=64))
123202

124203
def test_invalid_digest_size(self, backend):
125204
with pytest.raises(ValueError):
@@ -143,6 +222,7 @@ class TestBLAKE2s:
143222
hashes.BLAKE2s(digest_size=32),
144223
digest_size=32,
145224
)
225+
test_blake2s_eq = generate_eq_hash_test(hashes.BLAKE2s(digest_size=32))
146226

147227
def test_invalid_digest_size(self, backend):
148228
with pytest.raises(ValueError):
@@ -165,6 +245,14 @@ def test_buffer_protocol_hash(backend):
165245

166246

167247
class TestSHAKE:
248+
@pytest.mark.parametrize("xof", [hashes.SHAKE128, hashes.SHAKE256])
249+
def test_eq(self, xof):
250+
value_one = xof(digest_size=32)
251+
value_two = xof(digest_size=32) # identical
252+
value_three = xof(digest_size=64)
253+
assert value_one == value_two
254+
assert value_one != value_three
255+
168256
@pytest.mark.parametrize("xof", [hashes.SHAKE128, hashes.SHAKE256])
169257
def test_invalid_digest_type(self, xof):
170258
with pytest.raises(TypeError):
@@ -188,3 +276,4 @@ class TestSM3:
188276
hashes.SM3(),
189277
digest_size=32,
190278
)
279+
test_sm3_eq = generate_eq_hash_test(hashes.SM3())

tests/hazmat/primitives/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Mode,
3636
)
3737

38+
from ...doubles import DummyHashAlgorithm
3839
from ...utils import load_vectors_from_file
3940

4041

@@ -207,6 +208,14 @@ def test_base_hash(self, backend):
207208
return test_base_hash
208209

209210

211+
def generate_eq_hash_test(algorithm):
212+
def test_eq(self):
213+
assert algorithm == algorithm
214+
assert algorithm != DummyHashAlgorithm()
215+
216+
return test_eq
217+
218+
210219
def base_hash_test(backend, algorithm, digest_size):
211220
m = hashes.Hash(algorithm, backend=backend)
212221
assert m.algorithm.digest_size == digest_size

0 commit comments

Comments
 (0)