22import datetime
33import os
44import unittest
5- from unittest .mock import patch
5+ from unittest .mock import call , patch
66
77import django
88from asgiref .sync import sync_to_async
99from django .contrib .auth .models import User
1010from django .db import connection , transaction
11+ from django .db .backends .utils import CursorDebugWrapper , CursorWrapper
1112from django .db .models import Count
1213from django .db .utils import DatabaseError
1314from django .shortcuts import render
@@ -68,39 +69,59 @@ def test_recording_chunked_cursor(self):
6869 self .assertEqual (len (self .panel ._queries ), 1 )
6970
7071 @patch (
71- "debug_toolbar.panels.sql.tracking.NormalCursorWrapper " ,
72- wraps = sql_tracking .NormalCursorWrapper ,
72+ "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin " ,
73+ wraps = sql_tracking .patch_cursor_wrapper_with_mixin ,
7374 )
74- def test_cursor_wrapper_singleton (self , mock_wrapper ):
75+ def test_cursor_wrapper_singleton (self , mock_patch_cursor_wrapper ):
7576 sql_call ()
76-
7777 # ensure that cursor wrapping is applied only once
78- self .assertEqual (mock_wrapper .call_count , 1 )
78+ self .assertIn (
79+ mock_patch_cursor_wrapper .mock_calls ,
80+ [
81+ [call (CursorWrapper , sql_tracking .NormalCursorMixin )],
82+ # CursorDebugWrapper is used if the test is called with `--debug-sql`
83+ [call (CursorDebugWrapper , sql_tracking .NormalCursorMixin )],
84+ ],
85+ )
7986
8087 @patch (
81- "debug_toolbar.panels.sql.tracking.NormalCursorWrapper " ,
82- wraps = sql_tracking .NormalCursorWrapper ,
88+ "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin " ,
89+ wraps = sql_tracking .patch_cursor_wrapper_with_mixin ,
8390 )
84- def test_chunked_cursor_wrapper_singleton (self , mock_wrapper ):
91+ def test_chunked_cursor_wrapper_singleton (self , mock_patch_cursor_wrapper ):
8592 sql_call (use_iterator = True )
8693
8794 # ensure that cursor wrapping is applied only once
88- self .assertEqual (mock_wrapper .call_count , 1 )
95+ self .assertIn (
96+ mock_patch_cursor_wrapper .mock_calls ,
97+ [
98+ [call (CursorWrapper , sql_tracking .NormalCursorMixin )],
99+ # CursorDebugWrapper is used if the test is called with `--debug-sql`
100+ [call (CursorDebugWrapper , sql_tracking .NormalCursorMixin )],
101+ ],
102+ )
89103
90104 @patch (
91- "debug_toolbar.panels.sql.tracking.NormalCursorWrapper " ,
92- wraps = sql_tracking .NormalCursorWrapper ,
105+ "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin " ,
106+ wraps = sql_tracking .patch_cursor_wrapper_with_mixin ,
93107 )
94- async def test_cursor_wrapper_async (self , mock_wrapper ):
108+ async def test_cursor_wrapper_async (self , mock_patch_cursor_wrapper ):
95109 await sync_to_async (sql_call )()
96110
97- self .assertEqual (mock_wrapper .call_count , 1 )
111+ self .assertIn (
112+ mock_patch_cursor_wrapper .mock_calls ,
113+ [
114+ [call (CursorWrapper , sql_tracking .NormalCursorMixin )],
115+ # CursorDebugWrapper is used if the test is called with `--debug-sql`
116+ [call (CursorDebugWrapper , sql_tracking .NormalCursorMixin )],
117+ ],
118+ )
98119
99120 @patch (
100- "debug_toolbar.panels.sql.tracking.NormalCursorWrapper " ,
101- wraps = sql_tracking .NormalCursorWrapper ,
121+ "debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin " ,
122+ wraps = sql_tracking .patch_cursor_wrapper_with_mixin ,
102123 )
103- async def test_cursor_wrapper_asyncio_ctx (self , mock_wrapper ):
124+ async def test_cursor_wrapper_asyncio_ctx (self , mock_patch_cursor_wrapper ):
104125 self .assertTrue (sql_tracking .allow_sql .get ())
105126 await sync_to_async (sql_call )()
106127
@@ -116,7 +137,21 @@ async def task():
116137 await asyncio .create_task (task ())
117138 # Because it was called in another context, it should not have affected ours
118139 self .assertTrue (sql_tracking .allow_sql .get ())
119- self .assertEqual (mock_wrapper .call_count , 1 )
140+
141+ self .assertIn (
142+ mock_patch_cursor_wrapper .mock_calls ,
143+ [
144+ [
145+ call (CursorWrapper , sql_tracking .NormalCursorMixin ),
146+ call (CursorWrapper , sql_tracking .ExceptionCursorMixin ),
147+ ],
148+ # CursorDebugWrapper is used if the test is called with `--debug-sql`
149+ [
150+ call (CursorDebugWrapper , sql_tracking .NormalCursorMixin ),
151+ call (CursorDebugWrapper , sql_tracking .ExceptionCursorMixin ),
152+ ],
153+ ],
154+ )
120155
121156 def test_generate_server_timing (self ):
122157 self .assertEqual (len (self .panel ._queries ), 0 )
0 commit comments