@@ -254,9 +254,22 @@ def _parse_scram_response(response):
254254 return dict (item .split (b"=" , 1 ) for item in response .split (b"," ))
255255
256256
257+ def _authenticate_scram_start (credentials , mechanism ):
258+ username = credentials .username
259+ user = username .encode ("utf-8" ).replace (b"=" , b"=3D" ).replace (b"," , b"=2C" )
260+ nonce = standard_b64encode (os .urandom (32 ))
261+ first_bare = b"n=" + user + b",r=" + nonce
262+
263+ cmd = SON ([('saslStart' , 1 ),
264+ ('mechanism' , mechanism ),
265+ ('payload' , Binary (b"n,," + first_bare )),
266+ ('autoAuthorize' , 1 ),
267+ ('options' , {'skipEmptyExchange' : True })])
268+ return nonce , first_bare , cmd
269+
270+
257271def _authenticate_scram (credentials , sock_info , mechanism ):
258272 """Authenticate using SCRAM."""
259-
260273 username = credentials .username
261274 if mechanism == 'SCRAM-SHA-256' :
262275 digest = "sha256"
@@ -272,16 +285,14 @@ def _authenticate_scram(credentials, sock_info, mechanism):
272285 # Make local
273286 _hmac = hmac .HMAC
274287
275- user = username .encode ("utf-8" ).replace (b"=" , b"=3D" ).replace (b"," , b"=2C" )
276- nonce = standard_b64encode (os .urandom (32 ))
277- first_bare = b"n=" + user + b",r=" + nonce
278-
279- cmd = SON ([('saslStart' , 1 ),
280- ('mechanism' , mechanism ),
281- ('payload' , Binary (b"n,," + first_bare )),
282- ('autoAuthorize' , 1 ),
283- ('options' , {'skipEmptyExchange' : True })])
284- res = sock_info .command (source , cmd )
288+ ctx = sock_info .auth_ctx .get (credentials )
289+ if ctx and ctx .speculate_succeeded ():
290+ nonce , first_bare = ctx .scram_data
291+ res = ctx .speculative_authenticate
292+ else :
293+ nonce , first_bare , cmd = _authenticate_scram_start (
294+ credentials , mechanism )
295+ res = sock_info .command (source , cmd )
285296
286297 server_first = res ['payload' ]
287298 parsed = _parse_scram_response (server_first )
@@ -516,15 +527,17 @@ def _authenticate_cram_md5(credentials, sock_info):
516527def _authenticate_x509 (credentials , sock_info ):
517528 """Authenticate using MONGODB-X509.
518529 """
519- query = SON ([('authenticate' , 1 ),
520- ('mechanism' , 'MONGODB-X509' )])
521- if credentials .username is not None :
522- query ['user' ] = credentials .username
523- elif sock_info .max_wire_version < 5 :
530+ ctx = sock_info .auth_ctx .get (credentials )
531+ if ctx and ctx .speculate_succeeded ():
532+ # MONGODB-X509 is done after the speculative auth step.
533+ return
534+
535+ cmd = _X509Context (credentials ).speculate_command ()
536+ if credentials .username is None and sock_info .max_wire_version < 5 :
524537 raise ConfigurationError (
525538 "A username is required for MONGODB-X509 authentication "
526539 "when connected to MongoDB versions older than 3.4." )
527- sock_info .command ('$external' , query )
540+ sock_info .command ('$external' , cmd )
528541
529542
530543def _authenticate_aws (credentials , sock_info ):
@@ -597,6 +610,62 @@ def _authenticate_default(credentials, sock_info):
597610}
598611
599612
613+ class _AuthContext (object ):
614+ def __init__ (self , credentials ):
615+ self .credentials = credentials
616+ self .speculative_authenticate = None
617+
618+ @staticmethod
619+ def from_credentials (creds ):
620+ spec_cls = _SPECULATIVE_AUTH_MAP .get (creds .mechanism )
621+ if spec_cls :
622+ return spec_cls (creds )
623+ return None
624+
625+ def speculate_command (self ):
626+ raise NotImplementedError
627+
628+ def parse_response (self , ismaster ):
629+ self .speculative_authenticate = ismaster .speculative_authenticate
630+
631+ def speculate_succeeded (self ):
632+ return bool (self .speculative_authenticate )
633+
634+
635+ class _ScramContext (_AuthContext ):
636+ def __init__ (self , credentials , mechanism ):
637+ super (_ScramContext , self ).__init__ (credentials )
638+ self .scram_data = None
639+ self .mechanism = mechanism
640+
641+ def speculate_command (self ):
642+ nonce , first_bare , cmd = _authenticate_scram_start (
643+ self .credentials , self .mechanism )
644+ # The 'db' field is included only on the speculative command.
645+ cmd ['db' ] = self .credentials .source
646+ # Save for later use.
647+ self .scram_data = (nonce , first_bare )
648+ return cmd
649+
650+
651+ class _X509Context (_AuthContext ):
652+ def speculate_command (self ):
653+ cmd = SON ([('authenticate' , 1 ),
654+ ('mechanism' , 'MONGODB-X509' )])
655+ if self .credentials .username is not None :
656+ cmd ['user' ] = self .credentials .username
657+ return cmd
658+
659+
660+ _SPECULATIVE_AUTH_MAP = {
661+ 'MONGODB-X509' : _X509Context ,
662+ 'SCRAM-SHA-1' : functools .partial (_ScramContext , mechanism = 'SCRAM-SHA-1' ),
663+ 'SCRAM-SHA-256' : functools .partial (_ScramContext ,
664+ mechanism = 'SCRAM-SHA-256' ),
665+ 'DEFAULT' : functools .partial (_ScramContext , mechanism = 'SCRAM-SHA-256' ),
666+ }
667+
668+
600669def authenticate (credentials , sock_info ):
601670 """Authenticate sock_info."""
602671 mechanism = credentials .mechanism
0 commit comments