88from hashlib import md5 as hashlib_md5 # for MD5 authentication
99
1010
11+ include " scram.pyx"
12+
13+
1114cdef class CoreProtocol:
1215
1316 def __init__ (self , con_params ):
@@ -21,6 +24,8 @@ cdef class CoreProtocol:
2124 self .state = PROTOCOL_IDLE
2225 self .xact_status = PQTRANS_IDLE
2326 self .encoding = ' utf-8'
27+ # type of `scram` is `SCRAMAuthentcation`
28+ self .scram = None
2429
2530 # executemany support data
2631 self ._execute_iter = None
@@ -528,6 +533,8 @@ cdef class CoreProtocol:
528533 cdef:
529534 int32_t status
530535 bytes md5_salt
536+ list sasl_auth_methods
537+ list unsupported_sasl_auth_methods
531538
532539 status = self .buffer.read_int32()
533540
@@ -546,6 +553,58 @@ cdef class CoreProtocol:
546553 md5_salt = self .buffer.read_bytes(4 )
547554 self .auth_msg = self ._auth_password_message_md5(md5_salt)
548555
556+ elif status == AUTH_REQUIRED_SASL:
557+ # AuthenticationSASL
558+ # This requires making additional requests to the server in order
559+ # to follow the SCRAM protocol defined in RFC 5802.
560+ # get the SASL authentication methods that the server is providing
561+ sasl_auth_methods = []
562+ unsupported_sasl_auth_methods = []
563+ # determine if the advertised authentication methods are supported,
564+ # and if so, add them to the list
565+ auth_method = self .buffer.read_null_str()
566+ while auth_method:
567+ if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS:
568+ sasl_auth_methods.append(auth_method)
569+ else :
570+ unsupported_sasl_auth_methods.append(auth_method)
571+ auth_method = self .buffer.read_null_str()
572+
573+ # if none of the advertised authentication methods are supported,
574+ # raise an error
575+ # otherwise, initialize the SASL authentication exchange
576+ if not sasl_auth_methods:
577+ unsupported_sasl_auth_methods = [m.decode(" ascii" )
578+ for m in unsupported_sasl_auth_methods]
579+ self .result_type = RESULT_FAILED
580+ self .result = apg_exc.InterfaceError(
581+ ' unsupported SASL Authentication methods requested by the '
582+ ' server: {!r}' .format(
583+ " , " .join(unsupported_sasl_auth_methods)))
584+ else :
585+ self .auth_msg = self ._auth_password_message_sasl_initial(
586+ sasl_auth_methods)
587+
588+ elif status == AUTH_SASL_CONTINUE:
589+ # AUTH_SASL_CONTINUE
590+ # this requeires sending the second part of the SASL exchange, where
591+ # the client parses information back from the server and determines
592+ # if this is valid.
593+ # The client builds a challenge response to the server
594+ server_response = self .buffer.consume_message()
595+ self .auth_msg = self ._auth_password_message_sasl_continue(
596+ server_response)
597+
598+ elif status == AUTH_SASL_FINAL:
599+ # AUTH_SASL_FINAL
600+ server_response = self .buffer.consume_message()
601+ if not self .scram.verify_server_final_message(server_response):
602+ self .result_type = RESULT_FAILED
603+ self .result = apg_exc.InterfaceError(
604+ ' could not verify server signature for '
605+ ' SCRAM authentciation: scram-sha-256' ,
606+ )
607+
549608 elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
550609 AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
551610 AUTH_REQUIRED_SSPI):
@@ -560,7 +619,8 @@ cdef class CoreProtocol:
560619 ' unsupported authentication method requested by the '
561620 ' server: {}' .format(status))
562621
563- self .buffer.discard_message()
622+ if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
623+ self .buffer.discard_message()
564624
565625 cdef _auth_password_message_cleartext(self ):
566626 cdef:
@@ -588,6 +648,34 @@ cdef class CoreProtocol:
588648
589649 return msg
590650
651+ cdef _auth_password_message_sasl_initial(self , list sasl_auth_methods):
652+ cdef:
653+ WriteBuffer msg
654+
655+ # use the first supported advertized mechanism
656+ self .scram = SCRAMAuthentication(sasl_auth_methods[0 ])
657+ # this involves a call and response with the server
658+ msg = WriteBuffer.new_message(b' p' )
659+ msg.write_bytes(self .scram.create_client_first_message(self .user or ' ' ))
660+ msg.end_message()
661+
662+ return msg
663+
664+ cdef _auth_password_message_sasl_continue(self , bytes server_response):
665+ cdef:
666+ WriteBuffer msg
667+
668+ # determine if there is a valid server response
669+ self .scram.parse_server_first_message(server_response)
670+ # this involves a call and response with the server
671+ msg = WriteBuffer.new_message(b' p' )
672+ client_final_message = self .scram.create_client_final_message(
673+ self .password or ' ' )
674+ msg.write_bytes(client_final_message)
675+ msg.end_message()
676+
677+ return msg
678+
591679 cdef _parse_msg_ready_for_query(self ):
592680 cdef char status = self .buffer.read_byte()
593681
0 commit comments