|
28 | 28 | Transaction, |
29 | 29 | unit_of_work, |
30 | 30 | ) |
| 31 | +from neo4j.io import IOPool |
31 | 32 |
|
32 | 33 | from ._fake_connection import FakeConnection |
33 | 34 |
|
34 | 35 |
|
35 | | -@pytest.fixture() |
| 36 | +@pytest.fixture |
36 | 37 | def pool(mocker): |
37 | | - pool = mocker.MagicMock() |
38 | | - pool.acquire = mocker.MagicMock(side_effect=iter(FakeConnection, 0)) |
| 38 | + pool = mocker.Mock(spec=IOPool) |
| 39 | + assert not hasattr(pool, "acquired_connection_mocks") |
| 40 | + pool.acquired_connection_mocks = [] |
| 41 | + |
| 42 | + def acquire_side_effect(*_, **__): |
| 43 | + connection = FakeConnection() |
| 44 | + pool.acquired_connection_mocks.append(connection) |
| 45 | + return connection |
| 46 | + |
| 47 | + pool.acquire.side_effect = acquire_side_effect |
39 | 48 | return pool |
40 | 49 |
|
41 | 50 |
|
@@ -252,3 +261,49 @@ def work(tx): |
252 | 261 | session.write_transaction(work) |
253 | 262 | else: |
254 | 263 | raise ValueError(run_type) |
| 264 | + |
| 265 | + |
| 266 | +@pytest.mark.parametrize( |
| 267 | + ("params", "kw_params", "expected_params"), |
| 268 | + ( |
| 269 | + ({"x": 1}, {}, {"x": 1}), |
| 270 | + ({}, {"x": 1}, {"x": 1}), |
| 271 | + ({"x": 1}, {"y": 2}, {"x": 1, "y": 2}), |
| 272 | + ({"x": 1}, {"x": 2}, {"x": 2}), |
| 273 | + ({"x": 1}, {"x": 2}, {"x": 2}), |
| 274 | + ({"x": 1, "y": 3}, {"x": 2}, {"x": 2, "y": 3}), |
| 275 | + ({"x": 1}, {"x": 2, "y": 3}, {"x": 2, "y": 3}), |
| 276 | + # potentially internally used keyword arguments |
| 277 | + ({}, {"timeout": 2}, {"timeout": 2}), |
| 278 | + ({"timeout": 2}, {}, {"timeout": 2}), |
| 279 | + ({}, {"imp_user": "hans"}, {"imp_user": "hans"}), |
| 280 | + ({"imp_user": "hans"}, {}, {"imp_user": "hans"}), |
| 281 | + ({}, {"db": "neo4j"}, {"db": "neo4j"}), |
| 282 | + ({"db": "neo4j"}, {}, {"db": "neo4j"}), |
| 283 | + ({}, {"database": "neo4j"}, {"database": "neo4j"}), |
| 284 | + ({"database": "neo4j"}, {}, {"database": "neo4j"}), |
| 285 | + ) |
| 286 | +) |
| 287 | +@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) |
| 288 | +def test_session_run_parameter_precedence( |
| 289 | + pool, params, kw_params, expected_params, run_type |
| 290 | +): |
| 291 | + with Session(pool, SessionConfig()) as session: |
| 292 | + if run_type == "auto": |
| 293 | + session.run("RETURN $x", params, **kw_params) |
| 294 | + elif run_type == "unmanaged": |
| 295 | + tx = session.begin_transaction() |
| 296 | + tx.run("RETURN $x", params, **kw_params) |
| 297 | + elif run_type == "managed": |
| 298 | + def work(tx): |
| 299 | + tx.run("RETURN $x", params, **kw_params) |
| 300 | + session.write_transaction(work) |
| 301 | + else: |
| 302 | + raise ValueError(run_type) |
| 303 | + |
| 304 | + assert len(pool.acquired_connection_mocks) == 1 |
| 305 | + connection_mock = pool.acquired_connection_mocks[0] |
| 306 | + connection_mock.run.assert_called_once() |
| 307 | + call_args, call_kwargs = connection_mock.run.call_args |
| 308 | + assert call_args[0] == "RETURN $x" |
| 309 | + assert call_kwargs["parameters"] == expected_params |
0 commit comments