@@ -255,6 +255,10 @@ def _handle_placeholders(self, spec: dict, current: dict, path: str) -> Any:
255255 raise ValueError (f"Could not find a placeholder value for { path } " )
256256 return PLACEHOLDER_MAP [path ]
257257
258+ # Distinguish between temp and non-temp aws credentials.
259+ if path .endswith ("/kmsProviders/aws" ) and "sessionToken" in current :
260+ path = path .replace ("aws" , "aws_temp" )
261+
258262 for key in list (current ):
259263 value = current [key ]
260264 if isinstance (value , dict ):
@@ -275,10 +279,8 @@ async def _create_entity(self, entity_spec, uri=None):
275279 if "autoEncryptOpts" in spec :
276280 auto_encrypt_opts = spec ["autoEncryptOpts" ].copy ()
277281 auto_encrypt_kwargs : dict = dict (kms_tls_options = DEFAULT_KMS_TLS )
278- kms_providers = ALL_KMS_PROVIDERS .copy ()
282+ kms_providers = auto_encrypt_opts . pop ( "kmsProviders" , ALL_KMS_PROVIDERS .copy () )
279283 key_vault_namespace = auto_encrypt_opts .pop ("keyVaultNamespace" )
280- for provider_name , provider_value in auto_encrypt_opts .pop ("kmsProviders" ).items ():
281- kms_providers [provider_name ].update (provider_value )
282284 extra_opts = auto_encrypt_opts .pop ("extraOptions" , {})
283285 for key , value in extra_opts .items ():
284286 auto_encrypt_kwargs [camel_to_snake (key )] = value
@@ -552,22 +554,25 @@ async def asyncSetUp(self):
552554
553555 def maybe_skip_test (self , spec ):
554556 # add any special-casing for skipping tests here
555- if "Client side error in command starting transaction" in spec ["description" ]:
557+ class_name = self .__class__ .__name__ .lower ()
558+ description = spec ["description" ].lower ()
559+
560+ if "client side error in command starting transaction" in description :
556561 self .skipTest ("Implement PYTHON-1894" )
557- if "timeoutMS applied to entire download" in spec ["description" ]:
562+ if "type=symbol" in description :
563+ self .skipTest ("PyMongo does not support the symbol type" )
564+ if "timeoutms applied to entire download" in description :
558565 self .skipTest ("PyMongo's open_download_stream does not cap the stream's lifetime" )
559566 if any (
560- x in spec [ " description" ]
567+ x in description
561568 for x in [
562- "First insertOne is never committed" ,
563- "Second updateOne is never committed" ,
564- "Third updateOne is never committed" ,
569+ "first insertone is never committed" ,
570+ "second updateone is never committed" ,
571+ "third updateone is never committed" ,
565572 ]
566573 ):
567574 self .skipTest ("Implement PYTHON-4597" )
568575
569- class_name = self .__class__ .__name__ .lower ()
570- description = spec ["description" ].lower ()
571576 if "csot" in class_name :
572577 # Skip tests that are too slow to run on a given platform.
573578 slow_macos = [
@@ -785,6 +790,38 @@ async def _databaseOperation_createCommandCursor(self, target, **kwargs):
785790
786791 return cursor
787792
793+ async def _collectionOperation_assertIndexExists (self , target , ** kwargs ):
794+ collection = self .client [kwargs ["database_name" ]][kwargs ["collection_name" ]]
795+ index_names = [idx ["name" ] async for idx in await collection .list_indexes ()]
796+ self .assertIn (kwargs ["index_name" ], index_names )
797+
798+ async def _collectionOperation_assertIndexNotExists (self , target , ** kwargs ):
799+ collection = self .client [kwargs ["database_name" ]][kwargs ["collection_name" ]]
800+ async for index in await collection .list_indexes ():
801+ self .assertNotEqual (kwargs ["indexName" ], index ["name" ])
802+
803+ async def _collectionOperation_assertCollectionExists (self , target , ** kwargs ):
804+ database_name = kwargs ["database_name" ]
805+ collection_name = kwargs ["collection_name" ]
806+ collection_name_list = await self .client .get_database (database_name ).list_collection_names ()
807+ self .assertIn (collection_name , collection_name_list )
808+
809+ async def _databaseOperation_assertIndexExists (self , target , ** kwargs ):
810+ collection = self .client [kwargs ["database_name" ]][kwargs ["collection_name" ]]
811+ index_names = [idx ["name" ] async for idx in await collection .list_indexes ()]
812+ self .assertIn (kwargs ["index_name" ], index_names )
813+
814+ async def _databaseOperation_assertIndexNotExists (self , target , ** kwargs ):
815+ collection = self .client [kwargs ["database_name" ]][kwargs ["collection_name" ]]
816+ async for index in await collection .list_indexes ():
817+ self .assertNotEqual (kwargs ["indexName" ], index ["name" ])
818+
819+ async def _databaseOperation_assertCollectionExists (self , target , ** kwargs ):
820+ database_name = kwargs ["database_name" ]
821+ collection_name = kwargs ["collection_name" ]
822+ collection_name_list = await self .client .get_database (database_name ).list_collection_names ()
823+ self .assertIn (collection_name , collection_name_list )
824+
788825 async def kill_all_sessions (self ):
789826 if getattr (self , "client" , None ) is None :
790827 return
0 commit comments