Skip to content

Commit 79d7d90

Browse files
Sechenov Paveltomato42
authored andcommitted
Added __eq__ method to classes to improve comparison abilities (#161)
* Added __eq__ method to classes to improve comparison abilities
1 parent c3136e4 commit 79d7d90

File tree

7 files changed

+269
-78
lines changed

7 files changed

+269
-78
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ t/
4343
.project
4444
.pydevproject
4545

46+
#vscode
47+
.vscode
48+
4649
# Backup files
4750
*.swp
4851
*~

src/ecdsa/ecdsa.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ def __init__(self, generator, point):
119119
if point.x() < 0 or n <= point.x() or point.y() < 0 or n <= point.y():
120120
raise RuntimeError("Generator point has x or y out of range.")
121121

122+
def __eq__(self, other):
123+
if isinstance(other, Public_key):
124+
"""Return True if the points are identical, False otherwise."""
125+
return self.curve == other.curve \
126+
and self.point == other.point
127+
return NotImplemented
128+
122129
def verifies(self, hash, signature):
123130
"""Verify that signature is a valid signature of hash.
124131
Return True if the signature is valid.
@@ -153,6 +160,13 @@ def __init__(self, public_key, secret_multiplier):
153160

154161
self.public_key = public_key
155162
self.secret_multiplier = secret_multiplier
163+
164+
def __eq__(self, other):
165+
if isinstance(other, Private_key):
166+
"""Return True if the points are identical, False otherwise."""
167+
return self.public_key == other.public_key \
168+
and self.secret_multiplier == other.secret_multiplier
169+
return NotImplemented
156170

157171
def sign(self, hash, random_k):
158172
"""Return a signature for the provided hash, using the provided

src/ecdsa/ellipticcurve.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def __init__(self, p, a, b):
4545
self.__p = p
4646
self.__a = a
4747
self.__b = b
48+
49+
def __eq__(self, other):
50+
if isinstance(other, CurveFp):
51+
"""Return True if the curves are identical, False otherwise."""
52+
return self.__p == other.__p \
53+
and self.__a == other.__a \
54+
and self.__b == other.__b
55+
return NotImplemented
4856

4957
def p(self):
5058
return self.__p
@@ -79,12 +87,11 @@ def __init__(self, curve, x, y, order=None):
7987

8088
def __eq__(self, other):
8189
"""Return True if the points are identical, False otherwise."""
82-
if self.__curve == other.__curve \
83-
and self.__x == other.__x \
84-
and self.__y == other.__y:
85-
return True
86-
else:
87-
return False
90+
if isinstance(other, Point):
91+
return self.__curve == other.__curve \
92+
and self.__x == other.__x \
93+
and self.__y == other.__y
94+
return NotImplemented
8895

8996
def __neg__(self):
9097
return Point(self.__curve, self.__x, self.__curve.p() - self.__y)
@@ -195,4 +202,3 @@ def order(self):
195202

196203
# This one point is the Point At Infinity for all purposes:
197204
INFINITY = Point(None, None, None)
198-

src/ecdsa/keys.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ def __repr__(self):
139139
pub_key = self.to_string("compressed")
140140
return "VerifyingKey.from_string({0!r}, {1!r}, {2})".format(
141141
pub_key, self.curve, self.default_hashfunc().name)
142+
143+
def __eq__(self, other):
144+
"""Return True if the points are identical, False otherwise."""
145+
if isinstance(other, VerifyingKey):
146+
return self.curve == other.curve \
147+
and self.pubkey == other.pubkey
148+
return NotImplemented
142149

143150
@classmethod
144151
def from_public_point(cls, point, curve=NIST192p, hashfunc=sha1):
@@ -649,6 +656,14 @@ def __init__(self, _error__please_use_generate=None):
649656
self.baselen = None
650657
self.verifying_key = None
651658
self.privkey = None
659+
660+
def __eq__(self, other):
661+
"""Return True if the points are identical, False otherwise."""
662+
if isinstance(other, SigningKey):
663+
return self.curve == other.curve \
664+
and self.verifying_key == other.verifying_key \
665+
and self.privkey == other.privkey
666+
return NotImplemented
652667

653668
@classmethod
654669
def generate(cls, curve=NIST192p, entropy=None, hashfunc=sha1):

src/ecdsa/test_ecdsa.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,65 @@ def test_rejection(self):
6060
assert not self.pubk.verifies(self.msg - 1, self.sig)
6161

6262

63+
class TestPublicKey(unittest.TestCase):
64+
65+
def test_equality_public_keys(self):
66+
gen = generator_192
67+
x = 0xc58d61f88d905293bcd4cd0080bcb1b7f811f2ffa41979f6
68+
y = 0x8804dc7a7c4c7f8b5d437f5156f3312ca7d6de8a0e11867f
69+
point = ellipticcurve.Point(gen.curve(), x, y)
70+
pub_key1 = Public_key(gen, point)
71+
pub_key2 = Public_key(gen, point)
72+
self.assertEqual(pub_key1, pub_key2)
73+
74+
def test_inequality_public_key(self):
75+
gen = generator_192
76+
x1 = 0xc58d61f88d905293bcd4cd0080bcb1b7f811f2ffa41979f6
77+
y1 = 0x8804dc7a7c4c7f8b5d437f5156f3312ca7d6de8a0e11867f
78+
point1 = ellipticcurve.Point(gen.curve(), x1, y1)
79+
80+
x2 = 0x6a223d00bd22c52833409a163e057e5b5da1def2a197dd15
81+
y2 = 0x7b482604199367f1f303f9ef627f922f97023e90eae08abf
82+
point2 = ellipticcurve.Point(gen.curve(), x2, y2)
83+
84+
pub_key1 = Public_key(gen, point1)
85+
pub_key2 = Public_key(gen, point2)
86+
self.assertNotEqual(pub_key1, pub_key2)
87+
88+
def test_inequality_public_key_not_implemented(self):
89+
gen = generator_192
90+
x = 0xc58d61f88d905293bcd4cd0080bcb1b7f811f2ffa41979f6
91+
y = 0x8804dc7a7c4c7f8b5d437f5156f3312ca7d6de8a0e11867f
92+
point = ellipticcurve.Point(gen.curve(), x, y)
93+
pub_key = Public_key(gen, point)
94+
self.assertNotEqual(pub_key, None)
95+
96+
97+
class TestPrivateKey(unittest.TestCase):
98+
99+
@classmethod
100+
def setUpClass(cls):
101+
gen = generator_192
102+
x = 0xc58d61f88d905293bcd4cd0080bcb1b7f811f2ffa41979f6
103+
y = 0x8804dc7a7c4c7f8b5d437f5156f3312ca7d6de8a0e11867f
104+
point = ellipticcurve.Point(gen.curve(), x, y)
105+
cls.pub_key = Public_key(gen, point)
106+
107+
def test_equality_private_keys(self):
108+
pr_key1 = Private_key(self.pub_key, 100)
109+
pr_key2 = Private_key(self.pub_key, 100)
110+
self.assertEqual(pr_key1, pr_key2)
111+
112+
def test_inequality_private_keys(self):
113+
pr_key1 = Private_key(self.pub_key, 100)
114+
pr_key2 = Private_key(self.pub_key, 200)
115+
self.assertNotEqual(pr_key1, pr_key2)
116+
117+
def test_inequality_private_keys_not_implemented(self):
118+
pr_key = Private_key(self.pub_key, 100)
119+
self.assertNotEqual(pr_key, None)
120+
121+
63122
# Testing point validity, as per ECDSAVS.pdf B.2.2:
64123
P192_POINTS = [
65124
(generator_192,

src/ecdsa/test_ellipticcurve.py

Lines changed: 111 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,11 @@
3030
Gx = 0x188da80eb03090f67cbf20eb43a18800f4ff0afd82ff1012
3131
Gy = 0x07192b95ffc8da78631011ed6b24cdd573f977a11e794811
3232

33-
3433
c192 = CurveFp(p, -3, b)
3534
p192 = Point(c192, Gx, Gy, r)
3635

37-
38-
def test_p192():
39-
# Checking against some sample computations presented
40-
# in X9.62:
41-
d = 651056770906015076056810763456358567190100156695615665659
42-
Q = d * p192
43-
assert Q.x() == 0x62B12D60690CDCF330BABAB6E69763B471F994DD702D16A5
44-
45-
k = 6140507067065001063065065565667405560006161556565665656654
46-
R = k * p192
47-
assert R.x() == 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD \
48-
and R.y() == 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835
49-
50-
u1 = 2563697409189434185194736134579731015366492496392189760599
51-
u2 = 6266643813348617967186477710235785849136406323338782220568
52-
temp = u1 * p192 + u2 * Q
53-
assert temp.x() == 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD \
54-
and temp.y() == 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835
36+
c_23 = CurveFp(23, 1, 1)
37+
g_23 = Point(c_23, 13, 7, 7)
5538

5639

5740
@settings(**HYP_SETTINGS)
@@ -61,63 +44,16 @@ def test_p192_mult_tests(multiple):
6144

6245
p1 = p192 * multiple
6346
assert p1 * inv_m == p192
64-
65-
47+
48+
6649
def add_n_times(point, n):
6750
ret = INFINITY
6851
i = 0
6952
while i <= n:
7053
yield ret
7154
ret = ret + point
7255
i += 1
73-
74-
75-
c_23 = CurveFp(23, 1, 1)
76-
77-
78-
g_23 = Point(c_23, 13, 7, 7)
79-
80-
81-
# Trivial tests from X9.62 B.3:
82-
@pytest.mark.parametrize(
83-
"c,x1,y1,x2,y2,x3,y3",
84-
[(c_23, 3, 10, 9, 7, 17, 20),
85-
(c_23, 3, 10, 3, 10, 7, 12)],
86-
ids=["real add", "double"])
87-
def test_add(c, x1, y1, x2, y2, x3, y3):
88-
"""We expect that on curve c, (x1,y1) + (x2, y2 ) = (x3, y3)."""
89-
p1 = Point(c, x1, y1)
90-
p2 = Point(c, x2, y2)
91-
p3 = p1 + p2
92-
assert p3.x() == x3 and p3.y() == y3
93-
94-
95-
@pytest.mark.parametrize(
96-
"c, x1, y1, x3, y3",
97-
[(c_23, 3, 10, 7, 12)],
98-
ids=["real add"])
99-
def test_double(c, x1, y1, x3, y3):
100-
p1 = Point(c, x1, y1)
101-
p3 = p1.double()
102-
assert p3.x() == x3 and p3.y() == y3
103-
104-
105-
def test_double_infinity():
106-
p1 = INFINITY
107-
p3 = p1.double()
108-
assert p1 == p3
109-
assert p3.x() == p1.x() and p3.y() == p3.y()
110-
111-
112-
@pytest.mark.parametrize(
113-
"c, x1, y1, m, x3, y3",
114-
[(c_23, 3, 10, 2, 7, 12)],
115-
ids=["multiply by 2"])
116-
def test_multiply(c, x1, y1, m, x3, y3):
117-
p1 = Point(c, x1, y1)
118-
p3 = p1 * m
119-
assert p3.x() == x3 and p3.y() == y3
120-
56+
12157

12258
# From X9.62 I.1 (p. 96):
12359
@pytest.mark.parametrize(
@@ -126,3 +62,109 @@ def test_multiply(c, x1, y1, m, x3, y3):
12662
ids=["g_23 test with mult {0}".format(i) for i in range(9)])
12763
def test_add_and_mult_equivalence(p, m, check):
12864
assert p * m == check
65+
66+
67+
class TestCurve(unittest.TestCase):
68+
69+
@classmethod
70+
def setUpClass(cls):
71+
cls.c_23 = CurveFp(23, 1, 1)
72+
73+
def test_equality_curves(self):
74+
self.assertEqual(self.c_23, CurveFp(23, 1, 1))
75+
76+
def test_inequality_curves(self):
77+
c192 = CurveFp(p, -3, b)
78+
self.assertNotEqual(self.c_23, c192)
79+
80+
81+
class TestPoint(unittest.TestCase):
82+
83+
@classmethod
84+
def setUpClass(cls):
85+
cls.c_23 = CurveFp(23, 1, 1)
86+
cls.g_23 = Point(cls.c_23, 13, 7, 7)
87+
88+
p = 6277101735386680763835789423207666416083908700390324961279
89+
r = 6277101735386680763835789423176059013767194773182842284081
90+
# s = 0x3045ae6fc8422f64ed579528d38120eae12196d5
91+
# c = 0x3099d2bbbfcb2538542dcd5fb078b6ef5f3d6fe2c745de65
92+
b = 0x64210519e59c80e70fa7e9ab72243049feb8deecc146b9b1
93+
Gx = 0x188da80eb03090f67cbf20eb43a18800f4ff0afd82ff1012
94+
Gy = 0x07192b95ffc8da78631011ed6b24cdd573f977a11e794811
95+
96+
cls.c192 = CurveFp(p, -3, b)
97+
cls.p192 = Point(cls.c192, Gx, Gy, r)
98+
99+
def test_p192(self):
100+
# Checking against some sample computations presented
101+
# in X9.62:
102+
d = 651056770906015076056810763456358567190100156695615665659
103+
Q = d * self.p192
104+
self.assertEqual(Q.x(), 0x62B12D60690CDCF330BABAB6E69763B471F994DD702D16A5)
105+
106+
k = 6140507067065001063065065565667405560006161556565665656654
107+
R = k * self.p192
108+
self.assertEqual(R.x(), 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD)
109+
self.assertEqual(R.y(), 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835)
110+
111+
u1 = 2563697409189434185194736134579731015366492496392189760599
112+
u2 = 6266643813348617967186477710235785849136406323338782220568
113+
temp = u1 * self.p192 + u2 * Q
114+
self.assertEqual(temp.x(), 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD)
115+
self.assertEqual(temp.y(), 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835)
116+
117+
def test_double_infinity(self):
118+
p1 = INFINITY
119+
p3 = p1.double()
120+
self.assertEqual(p1, p3)
121+
self.assertEqual(p3.x(), p1.x())
122+
self.assertEqual(p3.y(), p3.y())
123+
124+
def test_double(self):
125+
x1, y1, x3, y3 = (3, 10, 7, 12)
126+
127+
p1 = Point(self.c_23, x1, y1)
128+
p3 = p1.double()
129+
self.assertEqual(p3.x(), x3)
130+
self.assertEqual(p3.y(), y3)
131+
132+
def test_multiply(self):
133+
x1, y1, m, x3, y3 = (3, 10, 2, 7, 12)
134+
p1 = Point(self.c_23, x1, y1)
135+
p3 = p1 * m
136+
self.assertEqual(p3.x(), x3)
137+
self.assertEqual(p3.y(), y3)
138+
139+
# Trivial tests from X9.62 B.3:
140+
def test_add(self):
141+
"""We expect that on curve c, (x1,y1) + (x2, y2 ) = (x3, y3)."""
142+
143+
x1, y1, x2, y2, x3, y3 = (3, 10, 9, 7, 17, 20)
144+
p1 = Point(self.c_23, x1, y1)
145+
p2 = Point(self.c_23, x2, y2)
146+
p3 = p1 + p2
147+
self.assertEqual(p3.x(), x3)
148+
self.assertEqual(p3.y(), y3)
149+
150+
def test_add_as_double(self):
151+
"""We expect that on curve c, (x1,y1) + (x2, y2 ) = (x3, y3)."""
152+
153+
x1, y1, x2, y2, x3, y3 = (3, 10, 3, 10, 7, 12)
154+
p1 = Point(self.c_23, x1, y1)
155+
p2 = Point(self.c_23, x2, y2)
156+
p3 = p1 + p2
157+
self.assertEqual(p3.x(), x3)
158+
self.assertEqual(p3.y(), y3)
159+
160+
def test_equality_points(self):
161+
self.assertEqual(self.g_23, Point(self.c_23, 13, 7, 7))
162+
163+
def test_inequality_points(self):
164+
c = CurveFp(100, -3, 100)
165+
p = Point(c, 100, 100, 100)
166+
self.assertNotEqual(self.g_23, p)
167+
168+
def test_inaquality_points_diff_types(self):
169+
c = CurveFp(100, -3, 100)
170+
self.assertNotEqual(self.g_23, c)

0 commit comments

Comments
 (0)