Skip to content

Commit b12b641

Browse files
committed
Fix SesstionTransaction._connection_for_bind call
1 parent ea5c5ee commit b12b641

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

pytest_flask_sqlalchemy/fixtures.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import os
21
import contextlib
2+
import os
33

44
import pytest
55
import sqlalchemy as sa
@@ -144,6 +144,12 @@ def raw_connection():
144144

145145
engine.raw_connection = raw_connection
146146

147+
# Fix SessionTransaction._connection_for_bind caching
148+
@sa.event.listens_for(session, 'after_begin')
149+
def after_begin(session, transaction, conn):
150+
if engine not in transaction._connections:
151+
transaction._connections[engine] = transaction._connections[conn]
152+
147153
for mocked_engine in pytestconfig._mocked_engines:
148154
mocker.patch(mocked_engine, new=engine)
149155

tests/test_fixtures.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,32 @@ def test_delete_message(account_address, db_session):
336336

337337
result = db_testdir.runpytest()
338338
result.assert_outcomes(passed=1)
339+
340+
341+
def test_rollback_nested(db_testdir):
342+
'''
343+
Test that creating objects and emitting SQL in the ORM won't bleed into
344+
other tests.
345+
'''
346+
# Load tests from file
347+
db_testdir.makepyfile("""
348+
def test_rollback_nested(person, db_session, caplog):
349+
assert db_session.query(person).count() == 0
350+
n1 = db_session.begin_nested()
351+
db_session.add(person())
352+
assert db_session.query(person).count() == 1
353+
354+
n2 = db_session.begin_nested()
355+
db_session.add(person())
356+
assert db_session.query(person).count() == 2
357+
358+
n2.rollback()
359+
print(db_session.bind.mock_calls)
360+
assert db_session.query(person).count() == 1
361+
n1.rollback()
362+
assert db_session.query(person).count() == 0
363+
""")
364+
365+
# Run tests
366+
result = db_testdir.runpytest()
367+
result.assert_outcomes(passed=1)

0 commit comments

Comments
 (0)