@@ -155,6 +155,12 @@ def _get_md_alg(digestmod):
155155 raise TypeError (" a valid digestmod is required, got %r " % digestmod)
156156
157157
158+ class CipherState (enum .Flag ):
159+ UNSET = enum .auto ()
160+ PUBLIC = enum .auto ()
161+ PRIVATE = enum .auto ()
162+
163+
158164cdef class CipherBase:
159165 """ Base class to RSA and ECC ciphers.
160166
@@ -168,6 +174,7 @@ cdef class CipherBase:
168174 name ,
169175 const unsigned char[:] key = None ,
170176 const unsigned char[:] password = None ):
177+ self .__state = CipherState.UNSET
171178 _exc.check_error(_pk.mbedtls_pk_setup(
172179 & self ._ctx,
173180 _pk.mbedtls_pk_info_from_type(
@@ -189,6 +196,11 @@ cdef class CipherBase:
189196 except _exc.TLSError:
190197 _exc.check_error(_pk.mbedtls_pk_parse_public_key(
191198 & self ._ctx, & key[0 ], key.size))
199+ pub = self ._public_to_PEM()
200+ if " PUBLIC" in pub:
201+ self .__state |= CipherState.PUBLIC
202+ if _pk.mbedtls_pk_check_pair(& self ._ctx, & self ._ctx) == 0 :
203+ self .__state |= CipherState.PRIVATE
192204
193205 def __cinit__ (self ):
194206 """ Initialize the context."""
@@ -256,8 +268,10 @@ cdef class CipherBase:
256268 # PEM must be null-terminated.
257269 bkey = bkey + b" \0"
258270 if callable (password):
259- return cls (key = bkey, password = password())
260- return cls (key = bkey, password = password)
271+ self = cls (key = bkey, password = password())
272+ else :
273+ self = cls (key = bkey, password = password)
274+ return self
261275
262276 @classmethod
263277 def from_file (cls , path , password = None ):
@@ -292,13 +306,19 @@ cdef class CipherBase:
292306 """ Return the size of the key, in bytes."""
293307 return _pk.mbedtls_pk_get_len(& self ._ctx)
294308
309+ def _set_private (self ):
310+ self .__state |= CipherState.PRIVATE
311+
295312 def _has_private (self ):
296313 """ Return `True` if the key contains a valid private half."""
297- raise NotImplementedError
314+ return CipherState.PRIVATE in self .__state
315+
316+ def _set_public (self ):
317+ self .__state |= CipherState.PUBLIC
298318
299319 def _has_public (self ):
300320 """ Return `True` if the key contains a valid public half."""
301- raise NotImplementedError
321+ return CipherState.PUBLIC in self .__state
302322
303323 def sign (self ,
304324 const unsigned char[:] message not None ,
@@ -418,8 +438,6 @@ cdef class CipherBase:
418438 raise NotImplementedError
419439
420440 def _private_to_DER (self ):
421- if not self ._has_private():
422- return b" "
423441 cdef int olen
424442 cdef size_t osize = PRV_DER_MAX_BYTES
425443 cdef unsigned char * output = < unsigned char * > malloc(
@@ -430,12 +448,15 @@ cdef class CipherBase:
430448 olen = _exc.check_error(
431449 _pk.mbedtls_pk_write_key_der(& self ._ctx, output, osize))
432450 return output[osize - olen:osize]
451+ except _exc.TLSError as exc:
452+ if exc.err == 0x4080 :
453+ # no private key
454+ return b" "
455+ raise
433456 finally :
434457 free(output)
435458
436459 def _private_to_PEM (self ):
437- if not self ._has_private():
438- return " "
439460 cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100
440461 cdef unsigned char * output = < unsigned char * > malloc(
441462 osize * sizeof(unsigned char ))
@@ -446,6 +467,11 @@ cdef class CipherBase:
446467 _exc.check_error(
447468 _pk.mbedtls_pk_write_key_pem(& self ._ctx, output, osize))
448469 return output[0 :osize].rstrip(b" \0" ).decode(" ascii" )
470+ except _exc.TLSError as exc:
471+ if exc.err == 0x4080 :
472+ # no private key
473+ return " "
474+ raise
449475 finally :
450476 free(output)
451477
@@ -459,14 +485,12 @@ cdef class CipherBase:
459485
460486 """
461487 if format == " DER" :
462- return self ._private_to_DER()
488+ return self ._private_to_DER() if self ._has_private() else b " "
463489 if format == " PEM" :
464- return self ._private_to_PEM()
490+ return self ._private_to_PEM() if self ._has_private() else " "
465491 raise ValueError (format)
466492
467493 def _public_to_DER (self ):
468- if not self ._has_public():
469- return b" "
470494 cdef int olen
471495 cdef size_t osize = PRV_DER_MAX_BYTES
472496 cdef unsigned char * output = < unsigned char * > malloc(
@@ -481,8 +505,6 @@ cdef class CipherBase:
481505 free(output)
482506
483507 def _public_to_PEM (self ):
484- if not self ._has_public():
485- return " "
486508 cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100
487509 cdef unsigned char * output = < unsigned char * > malloc(
488510 osize * sizeof(unsigned char ))
@@ -506,9 +528,9 @@ cdef class CipherBase:
506528
507529 """
508530 if format == " DER" :
509- return self ._public_to_DER()
531+ return self ._public_to_DER() if self ._has_public() else b " "
510532 if format == " PEM" :
511- return self ._public_to_PEM()
533+ return self ._public_to_PEM() if self ._has_public() else " "
512534 raise ValueError (format)
513535
514536
@@ -521,16 +543,6 @@ cdef class RSA(CipherBase):
521543 const unsigned char[:] password = None ):
522544 super ().__init__(b" RSA" , key, password)
523545
524- def _has_private (self ):
525- """ Return `True` if the key contains a valid private half."""
526- return _rsa.mbedtls_rsa_check_privkey(
527- _pk.mbedtls_pk_rsa(self ._ctx)
528- ) == 0
529-
530- def _has_public (self ):
531- """ Return `True` if the key contains a valid public half."""
532- return _rsa.mbedtls_rsa_check_pubkey(_pk.mbedtls_pk_rsa(self ._ctx)) == 0
533-
534546 def generate (self , unsigned int key_size = 2048 , int exponent = 65537 ):
535547 """ Generate an RSA keypair.
536548
@@ -545,6 +557,8 @@ cdef class RSA(CipherBase):
545557 _exc.check_error(_rsa.mbedtls_rsa_gen_key(
546558 _pk.mbedtls_pk_rsa(self ._ctx), & _rnd.mbedtls_ctr_drbg_random,
547559 & __rng._ctx, key_size, exponent))
560+ self ._set_public()
561+ self ._set_private()
548562 return self .export_key(" DER" )
549563
550564
@@ -643,16 +657,6 @@ cdef class ECC(CipherBase):
643657 def curve (self ):
644658 return self ._curve
645659
646- def _has_private (self ):
647- """ Return `True` if the key contains a valid private half."""
648- cdef const _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self ._ctx)
649- return _mpi.mbedtls_mpi_cmp_mpi(& ecp.d, & _mpi.MPI()._ctx) != 0
650-
651- def _has_public (self ):
652- """ Return `True` if the key contains a valid public half."""
653- cdef _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self ._ctx)
654- return not _ecp.mbedtls_ecp_is_zero(& ecp.Q)
655-
656660 def sign (self ,
657661 const unsigned char[:] message not None ,
658662 digestmod = None ):
@@ -678,6 +682,8 @@ cdef class ECC(CipherBase):
678682 if self .curve in (Curve.CURVE25519, Curve.CURVE448)
679683 else " DER"
680684 )
685+ self ._set_public()
686+ self ._set_private()
681687 return self .export_key(format)
682688
683689 def _private_to_num (self ):
0 commit comments