Skip to content

Commit 9ddae6c

Browse files
committed
Added unit test for new 4.1 metadata
1 parent 343b168 commit 9ddae6c

File tree

2 files changed

+269
-0
lines changed

2 files changed

+269
-0
lines changed

tests/unit/io/conftest.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@
1919
# limitations under the License.
2020

2121

22+
from collections import deque
23+
from struct import unpack as struct_unpack
24+
2225
import pytest
2326

2427
from neo4j.io._common import MessageInbox
28+
from neo4j.packstream import UnpackableBuffer, Unpacker
2529

2630

2731
class FakeSocket:
@@ -54,6 +58,83 @@ def pop_message(self):
5458
return self.messages.pop()
5559

5660

61+
class FakeSocket2:
62+
63+
def __init__(self, address=None, on_send=None):
64+
self.address = address
65+
self.recv_buffer = bytearray()
66+
self._messages = MessageInbox(self, on_error=print)
67+
self.on_send = on_send
68+
69+
def getsockname(self):
70+
return "127.0.0.1", 0xFFFF
71+
72+
def getpeername(self):
73+
return self.address
74+
75+
def recv_into(self, buffer, nbytes):
76+
data = self.recv_buffer[:nbytes]
77+
actual = len(data)
78+
buffer[:actual] = data
79+
self.recv_buffer = self.recv_buffer[actual:]
80+
return actual
81+
82+
def sendall(self, data):
83+
if callable(self.on_send):
84+
self.on_send(data)
85+
86+
def close(self):
87+
return
88+
89+
def inject(self, data):
90+
self.recv_buffer += data
91+
92+
def pop_chunk(self):
93+
chunk_size, = struct_unpack(">H", self.recv_buffer[:2])
94+
print("CHUNK SIZE %r" % chunk_size)
95+
end = 2 + chunk_size
96+
chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:]
97+
return chunk_data
98+
99+
def pop_message(self):
100+
data = bytearray()
101+
while True:
102+
chunk = self.pop_chunk()
103+
print("CHUNK %r" % chunk)
104+
if chunk:
105+
data.extend(chunk)
106+
elif data:
107+
break # end of message
108+
else:
109+
continue # NOOP
110+
header = data[0]
111+
n_fields = header % 0x10
112+
tag = data[1]
113+
buffer = UnpackableBuffer(data[2:])
114+
unpacker = Unpacker(buffer)
115+
fields = [unpacker.unpack() for _ in range(n_fields)]
116+
return tag, fields
117+
118+
119+
class FakeSocketPair:
120+
121+
def __init__(self, address):
122+
self.client = FakeSocket2(address)
123+
self.server = FakeSocket2()
124+
self.client.on_send = self.server.inject
125+
self.server.on_send = self.client.inject
126+
127+
57128
@pytest.fixture
58129
def fake_socket():
59130
return FakeSocket
131+
132+
133+
@pytest.fixture
134+
def fake_socket_2():
135+
return FakeSocket2
136+
137+
138+
@pytest.fixture
139+
def fake_socket_pair():
140+
return FakeSocketPair
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2020 "Neo4j,"
5+
# Neo4j Sweden AB [http://neo4j.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
import pytest
23+
24+
from neo4j.io._bolt4 import Bolt4x1
25+
from neo4j.conf import PoolConfig
26+
27+
28+
def test_conn_timed_out(fake_socket):
29+
address = ("127.0.0.1", 7687)
30+
max_connection_lifetime = 0
31+
connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime)
32+
assert connection.timedout() is True
33+
34+
35+
def test_conn_not_timed_out_if_not_enabled(fake_socket):
36+
address = ("127.0.0.1", 7687)
37+
max_connection_lifetime = -1
38+
connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime)
39+
assert connection.timedout() is False
40+
41+
42+
def test_conn_not_timed_out(fake_socket):
43+
address = ("127.0.0.1", 7687)
44+
max_connection_lifetime = 999999999
45+
connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime)
46+
assert connection.timedout() is False
47+
48+
49+
def test_db_extra_in_begin(fake_socket):
50+
address = ("127.0.0.1", 7687)
51+
socket = fake_socket(address)
52+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
53+
connection.begin(db="something")
54+
connection.send_all()
55+
tag, fields = socket.pop_message()
56+
assert tag == b"\x11"
57+
assert len(fields) == 1
58+
assert fields[0] == {"db": "something"}
59+
60+
61+
def test_db_extra_in_run(fake_socket):
62+
address = ("127.0.0.1", 7687)
63+
socket = fake_socket(address)
64+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
65+
connection.run("", {}, db="something")
66+
connection.send_all()
67+
tag, fields = socket.pop_message()
68+
assert tag == b"\x10"
69+
assert len(fields) == 3
70+
assert fields[0] == ""
71+
assert fields[1] == {}
72+
assert fields[2] == {"db": "something"}
73+
74+
75+
def test_n_extra_in_discard(fake_socket):
76+
address = ("127.0.0.1", 7687)
77+
socket = fake_socket(address)
78+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
79+
connection.discard(n=666)
80+
connection.send_all()
81+
tag, fields = socket.pop_message()
82+
assert tag == b"\x2F"
83+
assert len(fields) == 1
84+
assert fields[0] == {"n": 666}
85+
86+
87+
@pytest.mark.parametrize(
88+
"test_input, expected",
89+
[
90+
(666, {"n": -1, "qid": 666}),
91+
(-1, {"n": -1}),
92+
]
93+
)
94+
def test_qid_extra_in_discard(fake_socket, test_input, expected):
95+
address = ("127.0.0.1", 7687)
96+
socket = fake_socket(address)
97+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
98+
connection.discard(qid=test_input)
99+
connection.send_all()
100+
tag, fields = socket.pop_message()
101+
assert tag == b"\x2F"
102+
assert len(fields) == 1
103+
assert fields[0] == expected
104+
105+
106+
@pytest.mark.parametrize(
107+
"test_input, expected",
108+
[
109+
(777, {"n": 666, "qid": 777}),
110+
(-1, {"n": 666}),
111+
]
112+
)
113+
def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
114+
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
115+
address = ("127.0.0.1", 7687)
116+
socket = fake_socket(address)
117+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
118+
connection.discard(n=666, qid=test_input)
119+
connection.send_all()
120+
tag, fields = socket.pop_message()
121+
assert tag == b"\x2F"
122+
assert len(fields) == 1
123+
assert fields[0] == expected
124+
125+
126+
@pytest.mark.parametrize(
127+
"test_input, expected",
128+
[
129+
(666, {"n": 666}),
130+
(-1, {"n": -1}),
131+
]
132+
)
133+
def test_n_extra_in_pull(fake_socket, test_input, expected):
134+
address = ("127.0.0.1", 7687)
135+
socket = fake_socket(address)
136+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
137+
connection.pull(n=test_input)
138+
connection.send_all()
139+
tag, fields = socket.pop_message()
140+
assert tag == b"\x3F"
141+
assert len(fields) == 1
142+
assert fields[0] == expected
143+
144+
145+
@pytest.mark.parametrize(
146+
"test_input, expected",
147+
[
148+
(777, {"n": -1, "qid": 777}),
149+
(-1, {"n": -1}),
150+
]
151+
)
152+
def test_qid_extra_in_pull(fake_socket, test_input, expected):
153+
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
154+
address = ("127.0.0.1", 7687)
155+
socket = fake_socket(address)
156+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
157+
connection.pull(qid=test_input)
158+
connection.send_all()
159+
tag, fields = socket.pop_message()
160+
assert tag == b"\x3F"
161+
assert len(fields) == 1
162+
assert fields[0] == expected
163+
164+
165+
def test_n_and_qid_extras_in_pull(fake_socket):
166+
address = ("127.0.0.1", 7687)
167+
socket = fake_socket(address)
168+
connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime)
169+
connection.pull(n=666, qid=777)
170+
connection.send_all()
171+
tag, fields = socket.pop_message()
172+
assert tag == b"\x3F"
173+
assert len(fields) == 1
174+
assert fields[0] == {"n": 666, "qid": 777}
175+
176+
177+
def test_hello_passes_routing_metadata(fake_socket_pair):
178+
address = ("127.0.0.1", 7687)
179+
sockets = fake_socket_pair(address)
180+
# TODO helper method for encoding messages
181+
sockets.server.sendall(b"\x00\x03\xB1\x70\xA0\x00\x00")
182+
connection = Bolt4x1(address, sockets.client, PoolConfig.max_connection_lifetime,
183+
routing_context={"foo": "bar"})
184+
connection.hello()
185+
tag, fields = sockets.server.pop_message()
186+
assert tag == 0x01
187+
assert len(fields) == 1
188+
assert fields[0]["routing"] == {"foo": "bar"}

0 commit comments

Comments
 (0)