Skip to content

Commit dae0cbb

Browse files
committed
pk: Track cipher content internally
mbedtls-3.x
1 parent 3068d97 commit dae0cbb

File tree

3 files changed

+45
-36
lines changed

3 files changed

+45
-36
lines changed

src/mbedtls/pk.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ cdef extern from "mbedtls/pk.h" nogil:
108108

109109
cdef class CipherBase:
110110
cdef mbedtls_pk_context _ctx
111+
cdef object __state
111112

112113

113114
cdef class RSA(CipherBase):

src/mbedtls/pk.pyx

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ def _get_md_alg(digestmod):
155155
raise TypeError("a valid digestmod is required, got %r" % digestmod)
156156

157157

158+
class CipherState(enum.Flag):
159+
UNSET = enum.auto()
160+
PUBLIC = enum.auto()
161+
PRIVATE = enum.auto()
162+
163+
158164
cdef class CipherBase:
159165
"""Base class to RSA and ECC ciphers.
160166
@@ -168,6 +174,7 @@ cdef class CipherBase:
168174
name,
169175
const unsigned char[:] key=None,
170176
const unsigned char[:] password=None):
177+
self.__state = CipherState.UNSET
171178
_exc.check_error(_pk.mbedtls_pk_setup(
172179
&self._ctx,
173180
_pk.mbedtls_pk_info_from_type(
@@ -189,6 +196,11 @@ cdef class CipherBase:
189196
except _exc.TLSError:
190197
_exc.check_error(_pk.mbedtls_pk_parse_public_key(
191198
&self._ctx, &key[0], key.size))
199+
pub = self._public_to_PEM()
200+
if "PUBLIC" in pub:
201+
self.__state |= CipherState.PUBLIC
202+
if _pk.mbedtls_pk_check_pair(&self._ctx, &self._ctx) == 0:
203+
self.__state |= CipherState.PRIVATE
192204

193205
def __cinit__(self):
194206
"""Initialize the context."""
@@ -256,8 +268,10 @@ cdef class CipherBase:
256268
# PEM must be null-terminated.
257269
bkey = bkey + b"\0"
258270
if callable(password):
259-
return cls(key=bkey, password=password())
260-
return cls(key=bkey, password=password)
271+
self = cls(key=bkey, password=password())
272+
else:
273+
self = cls(key=bkey, password=password)
274+
return self
261275

262276
@classmethod
263277
def from_file(cls, path, password=None):
@@ -292,13 +306,19 @@ cdef class CipherBase:
292306
"""Return the size of the key, in bytes."""
293307
return _pk.mbedtls_pk_get_len(&self._ctx)
294308

309+
def _set_private(self):
310+
self.__state |= CipherState.PRIVATE
311+
295312
def _has_private(self):
296313
"""Return `True` if the key contains a valid private half."""
297-
raise NotImplementedError
314+
return CipherState.PRIVATE in self.__state
315+
316+
def _set_public(self):
317+
self.__state |= CipherState.PUBLIC
298318

299319
def _has_public(self):
300320
"""Return `True` if the key contains a valid public half."""
301-
raise NotImplementedError
321+
return CipherState.PUBLIC in self.__state
302322

303323
def sign(self,
304324
const unsigned char[:] message not None,
@@ -418,8 +438,6 @@ cdef class CipherBase:
418438
raise NotImplementedError
419439

420440
def _private_to_DER(self):
421-
if not self._has_private():
422-
return b""
423441
cdef int olen
424442
cdef size_t osize = PRV_DER_MAX_BYTES
425443
cdef unsigned char *output = <unsigned char *>malloc(
@@ -430,12 +448,15 @@ cdef class CipherBase:
430448
olen = _exc.check_error(
431449
_pk.mbedtls_pk_write_key_der(&self._ctx, output, osize))
432450
return output[osize - olen:osize]
451+
except _exc.TLSError as exc:
452+
if exc.err == 0x4080:
453+
# no private key
454+
return b""
455+
raise
433456
finally:
434457
free(output)
435458

436459
def _private_to_PEM(self):
437-
if not self._has_private():
438-
return ""
439460
cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100
440461
cdef unsigned char *output = <unsigned char *>malloc(
441462
osize * sizeof(unsigned char))
@@ -446,6 +467,11 @@ cdef class CipherBase:
446467
_exc.check_error(
447468
_pk.mbedtls_pk_write_key_pem(&self._ctx, output, osize))
448469
return output[0:osize].rstrip(b"\0").decode("ascii")
470+
except _exc.TLSError as exc:
471+
if exc.err == 0x4080:
472+
# no private key
473+
return ""
474+
raise
449475
finally:
450476
free(output)
451477

@@ -459,14 +485,12 @@ cdef class CipherBase:
459485
460486
"""
461487
if format == "DER":
462-
return self._private_to_DER()
488+
return self._private_to_DER() if self._has_private() else b""
463489
if format == "PEM":
464-
return self._private_to_PEM()
490+
return self._private_to_PEM() if self._has_private() else ""
465491
raise ValueError(format)
466492

467493
def _public_to_DER(self):
468-
if not self._has_public():
469-
return b""
470494
cdef int olen
471495
cdef size_t osize = PRV_DER_MAX_BYTES
472496
cdef unsigned char *output = <unsigned char *>malloc(
@@ -481,8 +505,6 @@ cdef class CipherBase:
481505
free(output)
482506

483507
def _public_to_PEM(self):
484-
if not self._has_public():
485-
return ""
486508
cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100
487509
cdef unsigned char *output = <unsigned char *>malloc(
488510
osize * sizeof(unsigned char))
@@ -506,9 +528,9 @@ cdef class CipherBase:
506528
507529
"""
508530
if format == "DER":
509-
return self._public_to_DER()
531+
return self._public_to_DER() if self._has_public() else b""
510532
if format == "PEM":
511-
return self._public_to_PEM()
533+
return self._public_to_PEM() if self._has_public() else ""
512534
raise ValueError(format)
513535

514536

@@ -521,16 +543,6 @@ cdef class RSA(CipherBase):
521543
const unsigned char[:] password=None):
522544
super().__init__(b"RSA", key, password)
523545

524-
def _has_private(self):
525-
"""Return `True` if the key contains a valid private half."""
526-
return _rsa.mbedtls_rsa_check_privkey(
527-
_pk.mbedtls_pk_rsa(self._ctx)
528-
) == 0
529-
530-
def _has_public(self):
531-
"""Return `True` if the key contains a valid public half."""
532-
return _rsa.mbedtls_rsa_check_pubkey(_pk.mbedtls_pk_rsa(self._ctx)) == 0
533-
534546
def generate(self, unsigned int key_size=2048, int exponent=65537):
535547
"""Generate an RSA keypair.
536548
@@ -545,6 +557,8 @@ cdef class RSA(CipherBase):
545557
_exc.check_error(_rsa.mbedtls_rsa_gen_key(
546558
_pk.mbedtls_pk_rsa(self._ctx), &_rnd.mbedtls_ctr_drbg_random,
547559
&__rng._ctx, key_size, exponent))
560+
self._set_public()
561+
self._set_private()
548562
return self.export_key("DER")
549563

550564

@@ -643,16 +657,6 @@ cdef class ECC(CipherBase):
643657
def curve(self):
644658
return self._curve
645659

646-
def _has_private(self):
647-
"""Return `True` if the key contains a valid private half."""
648-
cdef const _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self._ctx)
649-
return _mpi.mbedtls_mpi_cmp_mpi(&ecp.d, &_mpi.MPI()._ctx) != 0
650-
651-
def _has_public(self):
652-
"""Return `True` if the key contains a valid public half."""
653-
cdef _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self._ctx)
654-
return not _ecp.mbedtls_ecp_is_zero(&ecp.Q)
655-
656660
def sign(self,
657661
const unsigned char[:] message not None,
658662
digestmod=None):
@@ -678,6 +682,8 @@ cdef class ECC(CipherBase):
678682
if self.curve in (Curve.CURVE25519, Curve.CURVE448)
679683
else "DER"
680684
)
685+
self._set_public()
686+
self._set_private()
681687
return self.export_key(format)
682688

683689
def _private_to_num(self):

tests/test_pk.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def test_import_private_key(
193193

194194
other = copy(cipher)
195195
other = type(cipher).from_buffer(key)
196+
assert bytes(other) == key
197+
196198
assert other == cipher
197199
assert other.export_key() == cipher.export_key() == key
198200
assert other.export_public_key() == cipher.export_public_key()

0 commit comments

Comments
 (0)