1414
1515"""Test the collection module."""
1616
17- import json
1817import os
1918import sys
2019
2120sys .path [0 :0 ] = ["" ]
2221
23- from bson .py3compat import iteritems
24- from pymongo import operations , WriteConcern
25- from pymongo .command_cursor import CommandCursor
26- from pymongo .cursor import Cursor
27- from pymongo .errors import PyMongoError
28- from pymongo .read_concern import ReadConcern
29- from pymongo .results import _WriteResult , BulkWriteResult
22+ from test import unittest
23+ from test .utils import TestCreator
24+ from test .utils_spec_runner import SpecRunner
3025
31- from test import unittest , client_context , IntegrationTest
32- from test .utils import (camel_to_snake , camel_to_upper_camel ,
33- camel_to_snake_args , drop_collections ,
34- parse_collection_options , rs_client ,
35- OvertCommandListener , TestCreator )
3626
3727# Location of JSON test specifications.
3828_TEST_PATH = os .path .join (
4333TEST_COLLECTION = 'testcollection'
4434
4535
46- class TestAllScenarios (IntegrationTest ):
47- def run_operation (self , collection , test ):
48- # Iterate over all operations.
49- for opdef in test ['operations' ]:
50- # Convert command from CamelCase to pymongo.collection method.
51- operation = camel_to_snake (opdef ['name' ])
36+ class TestSpec (SpecRunner ):
37+ def get_scenario_db_name (self , scenario_def ):
38+ """Crud spec says database_name is optional."""
39+ return scenario_def .get ('database_name' , TEST_DB )
5240
53- # Get command handle on target entity (collection/database).
54- target_object = opdef .get ('object' , 'collection' )
55- if target_object == 'database' :
56- cmd = getattr (collection .database , operation )
57- elif target_object == 'collection' :
58- collection = collection .with_options (** dict (
59- parse_collection_options (opdef .get (
60- 'collectionOptions' , {}))))
61- cmd = getattr (collection , operation )
62- else :
63- self .fail ("Unknown object name %s" % (target_object ,))
41+ def get_scenario_coll_name (self , scenario_def ):
42+ """Crud spec says collection_name is optional."""
43+ return scenario_def .get ('collection_name' , TEST_COLLECTION )
6444
65- # Convert arguments to snake_case and handle special cases.
66- arguments = opdef [ 'arguments' ]
67- options = arguments . pop ( "options" , {} )
45+ def get_object_name ( self , op ):
46+ """Crud spec says object is optional and defaults to 'collection'."""
47+ return op . get ( 'object' , 'collection' )
6848
69- for option_name in options :
70- arguments [camel_to_snake (option_name )] = options [option_name ]
71-
72- if operation == "bulk_write" :
73- # Parse each request into a bulk write model.
74- requests = []
75- for request in arguments ["requests" ]:
76- bulk_model = camel_to_upper_camel (request ["name" ])
77- bulk_class = getattr (operations , bulk_model )
78- bulk_arguments = camel_to_snake_args (request ["arguments" ])
79- requests .append (bulk_class (** bulk_arguments ))
80- arguments ["requests" ] = requests
81- else :
82- for arg_name in list (arguments ):
83- c2s = camel_to_snake (arg_name )
84- # PyMongo accepts sort as list of tuples.
85- if arg_name == "sort" :
86- sort_dict = arguments [arg_name ]
87- arguments [arg_name ] = list (iteritems (sort_dict ))
88- # Named "key" instead not fieldName.
89- if arg_name == "fieldName" :
90- arguments ["key" ] = arguments .pop (arg_name )
91- # Aggregate uses "batchSize", while find uses batch_size.
92- elif arg_name == "batchSize" and operation == "aggregate" :
93- continue
94- # Requires boolean returnDocument.
95- elif arg_name == "returnDocument" :
96- arguments [c2s ] = arguments .pop (arg_name ) == "After"
97- else :
98- arguments [c2s ] = arguments .pop (arg_name )
99-
100- if opdef .get ('error' ) is True :
101- with self .assertRaises (PyMongoError ):
102- cmd (** arguments )
103- else :
104- result = cmd (** arguments )
105- self .check_result (opdef .get ('result' ), result )
106-
107- def check_result (self , expected_result , result ):
108- if expected_result is None :
109- return True
110-
111- if isinstance (result , Cursor ) or isinstance (result , CommandCursor ):
112- return list (result ) == expected_result
113-
114- elif isinstance (result , _WriteResult ):
115- for res in expected_result :
116- prop = camel_to_snake (res )
117- # SPEC-869: Only BulkWriteResult has upserted_count.
118- if (prop == "upserted_count" and
119- not isinstance (result , BulkWriteResult )):
120- if result .upserted_id is not None :
121- upserted_count = 1
122- else :
123- upserted_count = 0
124- if upserted_count != expected_result [res ]:
125- return False
126- elif prop == "inserted_ids" :
127- # BulkWriteResult does not have inserted_ids.
128- if isinstance (result , BulkWriteResult ):
129- if len (expected_result [res ]) != result .inserted_count :
130- return False
131- else :
132- # InsertManyResult may be compared to [id1] from the
133- # crud spec or {"0": id1} from the retryable write spec.
134- ids = expected_result [res ]
135- if isinstance (ids , dict ):
136- ids = [ids [str (i )] for i in range (len (ids ))]
137- if ids != result .inserted_ids :
138- return False
139- elif prop == "upserted_ids" :
140- # Convert indexes from strings to integers.
141- ids = expected_result [res ]
142- expected_ids = {}
143- for str_index in ids :
144- expected_ids [int (str_index )] = ids [str_index ]
145- if expected_ids != result .upserted_ids :
146- return False
147- elif getattr (result , prop ) != expected_result [res ]:
148- return False
149- return True
150- else :
151- if not expected_result :
152- return result is None
153- else :
154- return result == expected_result
155-
156- def check_events (self , expected_events , listener ):
157- res = listener .results
158- if not len (expected_events ):
159- return
160-
161- # Expectations only have CommandStartedEvents.
162- self .assertEqual (len (res ['started' ]), len (expected_events ))
163- for i , expectation in enumerate (expected_events ):
164- event_type = next (iter (expectation ))
165- event = res ['started' ][i ]
166-
167- # The tests substitute 42 for any number other than 0.
168- if (event .command_name == 'getMore'
169- and event .command ['getMore' ]):
170- event .command ['getMore' ] = 42
171- elif event .command_name == 'killCursors' :
172- event .command ['cursors' ] = [42 ]
173- # Add upsert and multi fields back into expectations.
174- elif event .command_name == 'update' :
175- updates = expectation [event_type ]['command' ]['updates' ]
176- for update in updates :
177- update .setdefault ('upsert' , False )
178- update .setdefault ('multi' , False )
179-
180- # Replace afterClusterTime: 42 with actual afterClusterTime.
181- expected_cmd = expectation [event_type ]['command' ]
182- expected_read_concern = expected_cmd .get ('readConcern' )
183- if expected_read_concern is not None :
184- time = expected_read_concern .get ('afterClusterTime' )
185- if time == 42 :
186- actual_time = event .command .get (
187- 'readConcern' , {}).get ('afterClusterTime' )
188- if actual_time is not None :
189- expected_read_concern ['afterClusterTime' ] = actual_time
190-
191- for attr , expected in expectation [event_type ].items ():
192- actual = getattr (event , attr )
193- if isinstance (expected , dict ):
194- for key , val in expected .items ():
195- if val is None :
196- if key in actual :
197- self .fail ("Unexpected key [%s] in %r" % (
198- key , actual ))
199- elif key not in actual :
200- self .fail ("Expected key [%s] in %r" % (
201- key , actual ))
202- else :
203- self .assertEqual (val , actual [key ],
204- "Key [%s] in %s" % (key , actual ))
205- else :
206- self .assertEqual (actual , expected )
49+ def get_outcome_coll_name (self , outcome , collection ):
50+ """Crud spec says outcome has an optional 'collection.name'."""
51+ return outcome ['collection' ].get ('name' , collection .name )
20752
20853
20954def create_test (scenario_def , test , name ):
21055 def run_scenario (self ):
211- listener = OvertCommandListener ()
212- # New client, to avoid interference from pooled sessions.
213- # Convert test['clientOptions'] to dict to avoid a Jython bug using "**"
214- # with ScenarioDict.
215- client = rs_client (event_listeners = [listener ],
216- ** dict (test .get ('clientOptions' , {})))
217- # Close the client explicitly to avoid having too many threads open.
218- self .addCleanup (client .close )
219-
220- # Get database and collection objects.
221- database = getattr (
222- client , scenario_def .get ('database_name' , TEST_DB ))
223- drop_collections (database )
224- collection = getattr (
225- database , scenario_def .get ('collection_name' , TEST_COLLECTION ))
226-
227- # Populate collection with data and run test.
228- collection .with_options (
229- write_concern = WriteConcern (w = "majority" )).insert_many (
230- scenario_def .get ('data' , []))
231- listener .results .clear ()
232- self .run_operation (collection , test )
233-
234- # Assert expected events.
235- self .check_events (test .get ('expectations' , {}), listener )
236-
237- # Assert final state is expected.
238- expected_outcome = test .get ('outcome' , {}).get ('collection' )
239- if expected_outcome is not None :
240- collname = expected_outcome .get ('name' )
241- if collname is not None :
242- o_collection = getattr (database , collname )
243- else :
244- o_collection = collection
245- o_collection = o_collection .with_options (
246- read_concern = ReadConcern (level = "local" ))
247- self .assertEqual (list (o_collection .find ()),
248- expected_outcome ['data' ])
56+ self .run_scenario (scenario_def , test )
24957
25058 return run_scenario
25159
25260
253- test_creator = TestCreator (create_test , TestAllScenarios , _TEST_PATH )
61+ test_creator = TestCreator (create_test , TestSpec , _TEST_PATH )
25462test_creator .create_tests ()
25563
25664
25765if __name__ == "__main__" :
258- unittest .main ()
66+ unittest .main ()
0 commit comments