22import os
33
44from django .core .exceptions import ImproperlyConfigured
5+ from django .db import DEFAULT_DB_ALIAS
56from django .db .backends .base .base import BaseDatabaseWrapper
7+ from django .db .backends .utils import debug_transaction
68from django .utils .asyncio import async_unsafe
79from django .utils .functional import cached_property
810from pymongo .collection import Collection
@@ -32,6 +34,17 @@ def __exit__(self, exception_type, exception_value, exception_traceback):
3234 pass
3335
3436
37+ def requires_transaction_support (func ):
38+ """Make a method a no-op if transactions aren't supported."""
39+
40+ def wrapper (self , * args , ** kwargs ):
41+ if not self .features ._supports_transactions :
42+ return
43+ func (self , * args , ** kwargs )
44+
45+ return wrapper
46+
47+
3548class DatabaseWrapper (BaseDatabaseWrapper ):
3649 data_types = {
3750 "AutoField" : "int" ,
@@ -142,6 +155,10 @@ def _isnull_operator(a, b):
142155 ops_class = DatabaseOperations
143156 validation_class = DatabaseValidation
144157
158+ def __init__ (self , settings_dict , alias = DEFAULT_DB_ALIAS ):
159+ super ().__init__ (settings_dict , alias = alias )
160+ self .session = None
161+
145162 def get_collection (self , name , ** kwargs ):
146163 collection = Collection (self .database , name , ** kwargs )
147164 if self .queries_logged :
@@ -212,6 +229,10 @@ def close(self):
212229
213230 def close_pool (self ):
214231 """Close the MongoClient."""
232+ # Clear commit hooks and session.
233+ self .run_on_commit = []
234+ if self .session :
235+ self ._end_session ()
215236 connection = self .connection
216237 if connection is None :
217238 return
@@ -227,6 +248,56 @@ def close_pool(self):
227248 def cursor (self ):
228249 return Cursor ()
229250
251+ @requires_transaction_support
252+ def validate_no_broken_transaction (self ):
253+ super ().validate_no_broken_transaction ()
254+
230255 def get_database_version (self ):
231256 """Return a tuple of the database's version."""
232257 return tuple (self .connection .server_info ()["versionArray" ])
258+
259+ @requires_transaction_support
260+ def _start_transaction (self , autocommit , force_begin_transaction_with_broken_autocommit = False ):
261+ # Besides @transaction.atomic() (which uses
262+ # _start_transaction_under_autocommit(), disabling autocommit is
263+ # another way to start a transaction.
264+ # if not autocommit:
265+ # self._start_transaction()
266+ # def _start_transaction(self):
267+ # Private API, specific to this backend.
268+ if self .session is None :
269+ self .session = self .connection .start_session ()
270+ with debug_transaction (self , "session.start_transaction()" ):
271+ self .session .start_transaction ()
272+
273+ @requires_transaction_support
274+ def _commit_transaction (self ):
275+ self .validate_thread_sharing ()
276+ self .validate_no_atomic_block ()
277+ if self .session :
278+ with debug_transaction (self , "session.commit_transaction()" ):
279+ self .session .commit_transaction ()
280+ self ._end_session ()
281+ # A successful commit means that the database connection works.
282+ self .errors_occurred = False
283+ self .run_commit_hooks_on_set_autocommit_on = True
284+
285+ @async_unsafe
286+ @requires_transaction_support
287+ def _rollback_transaction (self ):
288+ """Roll back a MongoDB transaction and reset the dirty flag."""
289+ self .validate_thread_sharing ()
290+ self .validate_no_atomic_block ()
291+ if self .session :
292+ with debug_transaction (self , "session.abort_transaction()" ):
293+ self .session .abort_transaction ()
294+ self ._end_session ()
295+ # A successful rollback means that the database connection works.
296+ self .errors_occurred = False
297+ self .needs_rollback = False
298+ self .run_on_commit = []
299+
300+ def _end_session (self ):
301+ # Private API, specific to this backend.
302+ self .session .end_session ()
303+ self .session = None
0 commit comments