Skip to content

Commit 24afc18

Browse files
authored
Merge pull request #231 from tomato42/fix_equality_tests
Fix (in)equality tests
2 parents e0626d0 + 7c5b3da commit 24afc18

File tree

6 files changed

+66
-31
lines changed

6 files changed

+66
-31
lines changed

src/ecdsa/ecdsa.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,20 @@ def __init__(self, generator, point, verify=True):
145145
raise InvalidPointError("Generator point order is bad.")
146146

147147
def __eq__(self, other):
148+
"""Return True if the keys are identical, False otherwise.
149+
150+
Note: for comparison, only placement on the same curve and point
151+
equality is considered, use of the same generator point is not
152+
considered.
153+
"""
148154
if isinstance(other, Public_key):
149-
"""Return True if the points are identical, False otherwise."""
150155
return self.curve == other.curve and self.point == other.point
151156
return NotImplemented
152157

158+
def __ne__(self, other):
159+
"""Return False if the keys are identical, True otherwise."""
160+
return not self == other
161+
153162
def verifies(self, hash, signature):
154163
"""Verify that signature is a valid signature of hash.
155164
Return True if the signature is valid.
@@ -188,14 +197,18 @@ def __init__(self, public_key, secret_multiplier):
188197
self.secret_multiplier = secret_multiplier
189198

190199
def __eq__(self, other):
200+
"""Return True if the points are identical, False otherwise."""
191201
if isinstance(other, Private_key):
192-
"""Return True if the points are identical, False otherwise."""
193202
return (
194203
self.public_key == other.public_key
195204
and self.secret_multiplier == other.secret_multiplier
196205
)
197206
return NotImplemented
198207

208+
def __ne__(self, other):
209+
"""Return False if the points are identical, True otherwise."""
210+
return not self == other
211+
199212
def sign(self, hash, random_k):
200213
"""Return a signature for the provided hash, using the provided
201214
random nonce. It is absolutely vital that random_k be an unpredictable

src/ecdsa/ellipticcurve.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@
2525
# Signature checking (5.4.2):
2626
# - Verify that r and s are in [1,n-1].
2727
#
28-
# Version of 2008.11.25.
29-
#
3028
# Revision history:
3129
# 2005.12.31 - Initial version.
3230
# 2008.11.25 - Change CurveFp.is_on to contains_point.
3331
#
3432
# Written in 2005 by Peter Pearson and placed in the public domain.
33+
# Modified extensively as part of python-ecdsa.
3534

3635
from __future__ import division
3736

@@ -92,8 +91,14 @@ def __init__(self, p, a, b, h=None):
9291
self.__h = h
9392

9493
def __eq__(self, other):
94+
"""Return True if other is an identical curve, False otherwise.
95+
96+
Note: the value of the cofactor of the curve is not taken into account
97+
when comparing curves, as it's derived from the base point and
98+
intrinsic curve characteristic (but it's complex to compute),
99+
only the prime and curve parameters are considered.
100+
"""
95101
if isinstance(other, CurveFp):
96-
"""Return True if the curves are identical, False otherwise."""
97102
return (
98103
self.__p == other.__p
99104
and self.__a == other.__a
@@ -102,7 +107,8 @@ def __eq__(self, other):
102107
return NotImplemented
103108

104109
def __ne__(self, other):
105-
return not (self == other)
110+
"""Return False if other is an identical curve, True otherwise."""
111+
return not self == other
106112

107113
def __hash__(self):
108114
return hash((self.__p, self.__a, self.__b))
@@ -158,7 +164,7 @@ def __init__(self, curve, x, y, z, order=None, generator=False):
158164
generator=True
159165
:param bool generator: the point provided is a curve generator, as
160166
such, it will be commonly used with scalar multiplication. This will
161-
cause to precompute multiplication table for it
167+
cause to precompute multiplication table generation for it
162168
"""
163169
self.__curve = curve
164170
# since it's generally better (faster) to use scaled points vs unscaled
@@ -224,7 +230,10 @@ def __setstate__(self, state):
224230
self._update_lock = RWLock()
225231

226232
def __eq__(self, other):
227-
"""Compare two points with each-other."""
233+
"""Compare for equality two points with each-other.
234+
235+
Note: only points that lie on the same curve can be equal.
236+
"""
228237
try:
229238
self._update_lock.reader_acquire()
230239
if other is INFINITY:
@@ -256,6 +265,10 @@ def __eq__(self, other):
256265
y1 * zz2 * z2 - y2 * zz1 * z1
257266
) % p == 0
258267

268+
def __ne__(self, other):
269+
"""Compare for inequality two points with each-other."""
270+
return not self == other
271+
259272
def order(self):
260273
"""Return the order of the point.
261274
@@ -757,7 +770,10 @@ def __init__(self, curve, x, y, order=None):
757770
assert self * order == INFINITY
758771

759772
def __eq__(self, other):
760-
"""Return True if the points are identical, False otherwise."""
773+
"""Return True if the points are identical, False otherwise.
774+
775+
Note: only points that lie on the same curve can be equal.
776+
"""
761777
if isinstance(other, Point):
762778
return (
763779
self.__curve == other.__curve
@@ -766,6 +782,10 @@ def __eq__(self, other):
766782
)
767783
return NotImplemented
768784

785+
def __ne__(self, other):
786+
"""Returns False if points are identical, True otherwise."""
787+
return not self == other
788+
769789
def __neg__(self):
770790
return Point(self.__curve, self.__x, self.__curve.p() - self.__y)
771791

src/ecdsa/keys.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ def __eq__(self, other):
191191
return self.curve == other.curve and self.pubkey == other.pubkey
192192
return NotImplemented
193193

194+
def __ne__(self, other):
195+
"""Return False if the points are identical, True otherwise."""
196+
return not self == other
197+
194198
@classmethod
195199
def from_public_point(
196200
cls, point, curve=NIST192p, hashfunc=sha1, validate_point=True
@@ -817,6 +821,10 @@ def __eq__(self, other):
817821
)
818822
return NotImplemented
819823

824+
def __ne__(self, other):
825+
"""Return False if the points are identical, True otherwise."""
826+
return not self == other
827+
820828
@classmethod
821829
def generate(cls, curve=NIST192p, entropy=None, hashfunc=sha1):
822830
"""

src/ecdsa/test_der.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_minimal_with_high_bit_set(self):
4242
val, rem = remove_integer(b("\x02\x02\x00\x80"))
4343

4444
self.assertEqual(val, 0x80)
45-
self.assertFalse(rem)
45+
self.assertEqual(rem, b"")
4646

4747
def test_two_zero_bytes_with_high_bit_set(self):
4848
with self.assertRaises(UnexpectedDER):
@@ -60,19 +60,19 @@ def test_encoding_of_zero(self):
6060
val, rem = remove_integer(b("\x02\x01\x00"))
6161

6262
self.assertEqual(val, 0)
63-
self.assertFalse(rem)
63+
self.assertEqual(rem, b"")
6464

6565
def test_encoding_of_127(self):
6666
val, rem = remove_integer(b("\x02\x01\x7f"))
6767

6868
self.assertEqual(val, 127)
69-
self.assertFalse(rem)
69+
self.assertEqual(rem, b"")
7070

7171
def test_encoding_of_128(self):
7272
val, rem = remove_integer(b("\x02\x02\x00\x80"))
7373

7474
self.assertEqual(val, 128)
75-
self.assertFalse(rem)
75+
self.assertEqual(rem, b"")
7676

7777
def test_wrong_tag(self):
7878
with self.assertRaises(UnexpectedDER) as e:

src/ecdsa/test_keys.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,20 +198,16 @@ def test_equality_on_verifying_keys(self):
198198
self.assertEqual(self.vk, self.sk.get_verifying_key())
199199

200200
def test_inequality_on_verifying_keys(self):
201-
# use `==` to workaround instrumental <-> unittest compat issue
202-
self.assertFalse(self.vk == self.vk2)
201+
self.assertNotEqual(self.vk, self.vk2)
203202

204203
def test_inequality_on_verifying_keys_not_implemented(self):
205-
# use `==` to workaround instrumental <-> unittest compat issue
206-
self.assertFalse(self.vk == None)
204+
self.assertNotEqual(self.vk, None)
207205

208206
def test_VerifyingKey_inequality_on_same_curve(self):
209-
# use `==` to workaround instrumental <-> unittest compat issue
210-
self.assertFalse(self.vk == self.sk2.verifying_key)
207+
self.assertNotEqual(self.vk, self.sk2.verifying_key)
211208

212209
def test_SigningKey_inequality_on_same_curve(self):
213-
# use `==` to workaround instrumental <-> unittest compat issue
214-
self.assertFalse(self.sk == self.sk2)
210+
self.assertNotEqual(self.sk, self.sk2)
215211

216212

217213
class TestSigningKey(unittest.TestCase):
@@ -283,12 +279,10 @@ def test_verify_with_lazy_precompute(self):
283279
self.assertTrue(vk.verify(sig, b"other message"))
284280

285281
def test_inequality_on_signing_keys(self):
286-
# use `==` to workaround instrumental <-> unittest compat issue
287-
self.assertFalse(self.sk1 == self.sk2)
282+
self.assertNotEqual(self.sk1, self.sk2)
288283

289284
def test_inequality_on_signing_keys_not_implemented(self):
290-
# use `==` to workaround instrumental <-> unittest compat issue
291-
self.assertFalse(self.sk1 == None)
285+
self.assertNotEqual(self.sk1, None)
292286

293287

294288
# test VerifyingKey.verify()

src/ecdsa/test_pyecdsa.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -653,9 +653,9 @@ def test_public_key_recovery(self):
653653
)
654654

655655
# Test if original vk is the list of recovered keys
656-
self.assertTrue(
657-
vk.pubkey.point
658-
in [recovered_vk.pubkey.point for recovered_vk in recovered_vks]
656+
self.assertIn(
657+
vk.pubkey.point,
658+
[recovered_vk.pubkey.point for recovered_vk in recovered_vks],
659659
)
660660

661661
def test_public_key_recovery_with_custom_hash(self):
@@ -684,9 +684,9 @@ def test_public_key_recovery_with_custom_hash(self):
684684
self.assertEqual(sha256, recovered_vk.default_hashfunc)
685685

686686
# Test if original vk is the list of recovered keys
687-
self.assertTrue(
688-
vk.pubkey.point
689-
in [recovered_vk.pubkey.point for recovered_vk in recovered_vks]
687+
self.assertIn(
688+
vk.pubkey.point,
689+
[recovered_vk.pubkey.point for recovered_vk in recovered_vks],
690690
)
691691

692692
def test_encoding(self):

0 commit comments

Comments
 (0)