@@ -188,12 +188,6 @@ def _socket_for_reads(self, session):
188188 return self .__database .client ._socket_for_reads (
189189 self ._read_preference_for (session ), session )
190190
191- def _socket_for_primary_reads (self , session ):
192- read_pref = ((session and session ._txn_read_preference ())
193- or ReadPreference .PRIMARY )
194- return self .__database .client ._socket_for_reads (
195- read_pref , session ), read_pref
196-
197191 def _socket_for_writes (self , session ):
198192 return self .__database .client ._socket_for_writes (session )
199193
@@ -1572,7 +1566,7 @@ def parallel_scan(self, num_cursors, session=None, **kwargs):
15721566
15731567 def _count (self , cmd , collation = None , session = None ):
15741568 """Internal count helper."""
1575- with self . _socket_for_reads (session ) as ( sock_info , slave_ok ):
1569+ def _cmd (session , server , sock_info , slave_ok ):
15761570 res = self ._command (
15771571 sock_info ,
15781572 cmd ,
@@ -1582,9 +1576,12 @@ def _count(self, cmd, collation=None, session=None):
15821576 read_concern = self .read_concern ,
15831577 collation = collation ,
15841578 session = session )
1585- if res .get ("errmsg" , "" ) == "ns missing" :
1586- return 0
1587- return int (res ["n" ])
1579+ if res .get ("errmsg" , "" ) == "ns missing" :
1580+ return 0
1581+ return int (res ["n" ])
1582+
1583+ return self .__database .client ._retryable_read (
1584+ _cmd , self ._read_preference_for (session ), session )
15881585
15891586 def _aggregate_one_result (
15901587 self , sock_info , slave_ok , cmd , collation = None , session = None ):
@@ -1693,12 +1690,16 @@ def count_documents(self, filter, session=None, **kwargs):
16931690 kwargs ["hint" ] = helpers ._index_document (kwargs ["hint" ])
16941691 collation = validate_collation_or_none (kwargs .pop ('collation' , None ))
16951692 cmd .update (kwargs )
1696- with self ._socket_for_reads (session ) as (sock_info , slave_ok ):
1693+
1694+ def _cmd (session , server , sock_info , slave_ok ):
16971695 result = self ._aggregate_one_result (
16981696 sock_info , slave_ok , cmd , collation , session )
1699- if not result :
1700- return 0
1701- return result ['n' ]
1697+ if not result :
1698+ return 0
1699+ return result ['n' ]
1700+
1701+ return self .__database .client ._retryable_read (
1702+ _cmd , self ._read_preference_for (session ), session )
17021703
17031704 def count (self , filter = None , session = None , ** kwargs ):
17041705 """**DEPRECATED** - Get the number of documents in this collection.
@@ -2149,8 +2150,10 @@ def list_indexes(self, session=None):
21492150 codec_options = CodecOptions (SON )
21502151 coll = self .with_options (codec_options = codec_options ,
21512152 read_preference = ReadPreference .PRIMARY )
2152- sock_ctx , read_pref = self ._socket_for_primary_reads (session )
2153- with sock_ctx as (sock_info , slave_ok ):
2153+ read_pref = ((session and session ._txn_read_preference ())
2154+ or ReadPreference .PRIMARY )
2155+
2156+ def _cmd (session , server , sock_info , slave_ok ):
21542157 cmd = SON ([("listIndexes" , self .__name ), ("cursor" , {})])
21552158 if sock_info .max_wire_version > 2 :
21562159 with self .__database .client ._tmp_session (session , False ) as s :
@@ -2179,6 +2182,9 @@ def list_indexes(self, session=None):
21792182 # will never be a getMore call.
21802183 return CommandCursor (coll , cursor , sock_info .address )
21812184
2185+ return self .__database .client ._retryable_read (
2186+ _cmd , read_pref , session )
2187+
21822188 def index_information (self , session = None ):
21832189 """Get information on this collection's indexes.
21842190
@@ -2275,10 +2281,11 @@ def _aggregate(self, pipeline, cursor_class, first_batch_size, session,
22752281 "useCursor" , kwargs .pop ("useCursor" ))
22762282 batch_size = common .validate_non_negative_integer_or_none (
22772283 "batchSize" , kwargs .pop ("batchSize" , None ))
2284+
2285+ dollar_out = pipeline and '$out' in pipeline [- 1 ]
22782286 # If the server does not support the "cursor" option we
22792287 # ignore useCursor and batchSize.
2280- with self ._socket_for_reads (session ) as (sock_info , slave_ok ):
2281- dollar_out = pipeline and '$out' in pipeline [- 1 ]
2288+ def _cmd (session , server , sock_info , slave_ok ):
22822289 if use_cursor :
22832290 if "cursor" not in kwargs :
22842291 kwargs ["cursor" ] = {}
@@ -2336,6 +2343,10 @@ def _aggregate(self, pipeline, cursor_class, first_batch_size, session,
23362343 max_await_time_ms = max_await_time_ms ,
23372344 session = session , explicit_session = explicit_session )
23382345
2346+ return self .__database .client ._retryable_read (
2347+ _cmd , self ._read_preference_for (session ), session ,
2348+ retryable = not dollar_out )
2349+
23392350 def aggregate (self , pipeline , session = None , ** kwargs ):
23402351 """Perform an aggregation using the aggregation framework on this
23412352 collection.
@@ -2681,12 +2692,53 @@ def distinct(self, key, filter=None, session=None, **kwargs):
26812692 kwargs ["query" ] = filter
26822693 collation = validate_collation_or_none (kwargs .pop ('collation' , None ))
26832694 cmd .update (kwargs )
2684- with self ._socket_for_reads (session ) as (sock_info , slave_ok ):
2685- return self ._command (sock_info , cmd , slave_ok ,
2686- read_concern = self .read_concern ,
2687- collation = collation ,
2688- session = session ,
2689- user_fields = {"values" : 1 })["values" ]
2695+ def _cmd (session , server , sock_info , slave_ok ):
2696+ return self ._command (
2697+ sock_info , cmd , slave_ok , read_concern = self .read_concern ,
2698+ collation = collation , session = session ,
2699+ user_fields = {"values" : 1 })["values" ]
2700+
2701+ return self .__database .client ._retryable_read (
2702+ _cmd , self ._read_preference_for (session ), session )
2703+
2704+ def _map_reduce (self , map , reduce , out , session , read_pref , ** kwargs ):
2705+ """Internal mapReduce helper."""
2706+ cmd = SON ([("mapReduce" , self .__name ),
2707+ ("map" , map ),
2708+ ("reduce" , reduce ),
2709+ ("out" , out )])
2710+ collation = validate_collation_or_none (kwargs .pop ('collation' , None ))
2711+ cmd .update (kwargs )
2712+
2713+ inline = 'inline' in out
2714+
2715+ if inline :
2716+ user_fields = {'results' : 1 }
2717+ else :
2718+ user_fields = None
2719+
2720+ read_pref = ((session and session ._txn_read_preference ())
2721+ or read_pref )
2722+
2723+ with self .__database .client ._socket_for_reads (read_pref , session ) as (
2724+ sock_info , slave_ok ):
2725+ if (sock_info .max_wire_version >= 4 and
2726+ ('readConcern' not in cmd ) and
2727+ inline ):
2728+ read_concern = self .read_concern
2729+ else :
2730+ read_concern = None
2731+ if 'writeConcern' not in cmd and not inline :
2732+ write_concern = self ._write_concern_for (session )
2733+ else :
2734+ write_concern = None
2735+
2736+ return self ._command (
2737+ sock_info , cmd , slave_ok , read_pref ,
2738+ read_concern = read_concern ,
2739+ write_concern = write_concern ,
2740+ collation = collation , session = session ,
2741+ user_fields = user_fields )
26902742
26912743 def map_reduce (self , map , reduce , out , full_response = False , session = None ,
26922744 ** kwargs ):
@@ -2747,36 +2799,8 @@ def map_reduce(self, map, reduce, out, full_response=False, session=None,
27472799 raise TypeError ("'out' must be an instance of "
27482800 "%s or a mapping" % (string_type .__name__ ,))
27492801
2750- cmd = SON ([("mapreduce" , self .__name ),
2751- ("map" , map ),
2752- ("reduce" , reduce ),
2753- ("out" , out )])
2754- collation = validate_collation_or_none (kwargs .pop ('collation' , None ))
2755- cmd .update (kwargs )
2756-
2757- inline = 'inline' in cmd ['out' ]
2758- sock_ctx , read_pref = self ._socket_for_primary_reads (session )
2759- with sock_ctx as (sock_info , slave_ok ):
2760- if (sock_info .max_wire_version >= 4 and 'readConcern' not in cmd and
2761- inline ):
2762- read_concern = self .read_concern
2763- else :
2764- read_concern = None
2765- if 'writeConcern' not in cmd and not inline :
2766- write_concern = self ._write_concern_for (session )
2767- else :
2768- write_concern = None
2769- if inline :
2770- user_fields = {'results' : 1 }
2771- else :
2772- user_fields = None
2773-
2774- response = self ._command (
2775- sock_info , cmd , slave_ok , read_pref ,
2776- read_concern = read_concern ,
2777- write_concern = write_concern ,
2778- collation = collation , session = session ,
2779- user_fields = user_fields )
2802+ response = self ._map_reduce (map , reduce , out , session ,
2803+ ReadPreference .PRIMARY , ** kwargs )
27802804
27812805 if full_response or not response .get ('result' ):
27822806 return response
@@ -2822,23 +2846,8 @@ def inline_map_reduce(self, map, reduce, full_response=False, session=None,
28222846 Added the `collation` option.
28232847
28242848 """
2825- cmd = SON ([("mapreduce" , self .__name ),
2826- ("map" , map ),
2827- ("reduce" , reduce ),
2828- ("out" , {"inline" : 1 })])
2829- user_fields = {'results' : 1 }
2830- collation = validate_collation_or_none (kwargs .pop ('collation' , None ))
2831- cmd .update (kwargs )
2832- with self ._socket_for_reads (session ) as (sock_info , slave_ok ):
2833- if sock_info .max_wire_version >= 4 and 'readConcern' not in cmd :
2834- res = self ._command (sock_info , cmd , slave_ok ,
2835- read_concern = self .read_concern ,
2836- collation = collation , session = session ,
2837- user_fields = user_fields )
2838- else :
2839- res = self ._command (sock_info , cmd , slave_ok ,
2840- collation = collation , session = session ,
2841- user_fields = user_fields )
2849+ res = self ._map_reduce (map , reduce , {"inline" : 1 }, session ,
2850+ self .read_preference , ** kwargs )
28422851
28432852 if full_response :
28442853 return res
0 commit comments