Skip to content

Commit ab61bc1

Browse files
author
Stefan Plantikow
committed
Merge pull request #21 from boggle/new-results-api
New results api
2 parents 7f1a4c5 + aec108d commit ab61bc1

File tree

4 files changed

+154
-41
lines changed

4 files changed

+154
-41
lines changed

docs/source/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ Session API
2424
.. autoclass:: neo4j.v1.Record
2525
:members:
2626

27+
.. autofunction:: neo4j.v1.record
28+
2729
.. autoclass:: neo4j.v1.Result
2830
:members:
2931

neo4j/v1/connection.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,7 @@ def fetch_next(self):
262262
# Unpack from the raw byte stream and call the relevant message handler(s)
263263
raw.seek(0)
264264
response = self.responses[0]
265-
for message in unpack():
266-
signature = message.signature
267-
fields = tuple(message)
265+
for signature, fields in unpack():
268266
if __debug__:
269267
log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields)))
270268
handler_name = "on_%s" % message_names[signature].lower()

neo4j/v1/session.py

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -473,48 +473,95 @@ def close(self):
473473
self.closed = True
474474
self.session.transaction = None
475475

476-
477476
class Record(object):
478-
""" Record object for storing result values along with field names.
477+
""" Record is an ordered collection of fields.
478+
479+
A Record object is used for storing result values along with field names.
479480
Fields can be accessed by numeric or named index (``record[0]`` or
480-
``record["field"]``) or by attribute (``record.field``).
481+
``record["field"]``).
481482
"""
482483

483484
def __init__(self, keys, values):
484-
self.__keys__ = keys
485-
self.__values__ = values
485+
self._keys = tuple(keys)
486+
self._values = tuple(values)
487+
488+
def keys(self):
489+
""" Return the keys (key names) of the record
490+
"""
491+
return self._keys
492+
493+
def values(self):
494+
""" Return the values of the record
495+
"""
496+
return self._values
497+
498+
def items(self):
499+
""" Return the fields of the record as a list of key and value tuples
500+
"""
501+
return zip(self._keys, self._values)
502+
503+
def index(self, key):
504+
""" Return the index of the given key
505+
"""
506+
try:
507+
return self._keys.index(key)
508+
except ValueError:
509+
raise KeyError(key)
510+
511+
def __record__(self):
512+
return self
513+
514+
def __contains__(self, key):
515+
return self._keys.__contains__(key)
516+
517+
def __iter__(self):
518+
return iter(self._keys)
519+
520+
def copy(self):
521+
return Record(self._keys, self._values)
522+
523+
def __getitem__(self, item):
524+
if isinstance(item, string):
525+
return self._values[self.index(item)]
526+
elif isinstance(item, integer):
527+
return self._values[item]
528+
else:
529+
raise TypeError(item)
530+
531+
def __len__(self):
532+
return len(self._keys)
486533

487534
def __repr__(self):
488-
values = self.__values__
535+
values = self._values
489536
s = []
490-
for i, field in enumerate(self.__keys__):
537+
for i, field in enumerate(self._keys):
491538
s.append("%s=%r" % (field, values[i]))
492539
return "<Record %s>" % " ".join(s)
493540

541+
def __hash__(self):
542+
return hash(self._keys) ^ hash(self._values)
543+
494544
def __eq__(self, other):
495545
try:
496-
return vars(self) == vars(other)
497-
except TypeError:
498-
return tuple(self) == tuple(other)
546+
return self._keys == tuple(other.keys()) and self._values == tuple(other.values())
547+
except AttributeError:
548+
return False
499549

500550
def __ne__(self, other):
501551
return not self.__eq__(other)
502552

503-
def __len__(self):
504-
return self.__keys__.__len__()
553+
def record(obj):
554+
""" Obtain an immutable record for the given object
555+
(either by calling obj.__record__() or by copying out the record data)
556+
"""
557+
try:
558+
return obj.__record__()
559+
except AttributeError:
560+
keys = obj.keys()
561+
values = []
562+
for key in keys:
563+
values.append(obj[key])
564+
return Record(keys, values)
565+
505566

506-
def __getitem__(self, item):
507-
if isinstance(item, string):
508-
return getattr(self, item)
509-
elif isinstance(item, integer):
510-
return getattr(self, self.__keys__[item])
511-
else:
512-
raise TypeError(item)
513567

514-
def __getattr__(self, item):
515-
try:
516-
i = self.__keys__.index(item)
517-
except ValueError:
518-
raise AttributeError("No key %r" % item)
519-
else:
520-
return self.__values__[i]

test/test_session.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from unittest import TestCase
2323

24-
from neo4j.v1.session import GraphDatabase, CypherError
24+
from neo4j.v1.session import GraphDatabase, CypherError, Record, record
2525
from neo4j.v1.typesystem import Node, Relationship, Path
2626

2727

@@ -36,11 +36,11 @@ def test_can_run_simple_statement(self):
3636
for record in session.run("RETURN 1 AS n"):
3737
assert record[0] == 1
3838
assert record["n"] == 1
39-
with self.assertRaises(AttributeError):
39+
with self.assertRaises(KeyError):
40+
_ = record["x"]
41+
assert record["n"] == 1
42+
with self.assertRaises(KeyError):
4043
_ = record["x"]
41-
assert record.n == 1
42-
with self.assertRaises(AttributeError):
43-
_ = record.x
4444
with self.assertRaises(TypeError):
4545
_ = record[object()]
4646
assert repr(record)
@@ -77,7 +77,6 @@ def test_can_run_simple_statement_from_bytes_string(self):
7777
for record in session.run(b"RETURN 1 AS n"):
7878
assert record[0] == 1
7979
assert record["n"] == 1
80-
assert record.n == 1
8180
assert repr(record)
8281
assert len(record) == 1
8382
count += 1
@@ -138,12 +137,6 @@ def test_can_handle_cypher_error(self):
138137
with self.assertRaises(CypherError):
139138
session.run("X")
140139

141-
def test_record_equality(self):
142-
with GraphDatabase.driver("bolt://localhost").session() as session:
143-
result = session.run("unwind([1, 1]) AS a RETURN a")
144-
assert result[0] == result[1]
145-
assert result[0] != "this is not a record"
146-
147140
def test_can_obtain_summary_info(self):
148141
with GraphDatabase.driver("bolt://localhost").session() as session:
149142
result = session.run("CREATE (n) RETURN n")
@@ -211,6 +204,79 @@ def test_can_obtain_notification_info(self):
211204
assert position.column == 1
212205

213206

207+
class RecordTestCase(TestCase):
208+
def test_record_equality(self):
209+
record1 = Record(["name","empire"], ["Nigel", "The British Empire"])
210+
record2 = Record(["name","empire"], ["Nigel", "The British Empire"])
211+
record3 = Record(["name","empire"], ["Stefan", "Das Deutschland"])
212+
assert record1 == record2
213+
assert record1 != record3
214+
assert record2 != record3
215+
216+
def test_record_hashing(self):
217+
record1 = Record(["name","empire"], ["Nigel", "The British Empire"])
218+
record2 = Record(["name","empire"], ["Nigel", "The British Empire"])
219+
record3 = Record(["name","empire"], ["Stefan", "Das Deutschland"])
220+
assert hash(record1) == hash(record2)
221+
assert hash(record1) != hash(record3)
222+
assert hash(record2) != hash(record3)
223+
224+
def test_record_keys(self):
225+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
226+
assert list(aRecord.keys()) == ["name", "empire"]
227+
228+
def test_record_values(self):
229+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
230+
assert list(aRecord.values()) == ["Nigel", "The British Empire"]
231+
232+
def test_record_items(self):
233+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
234+
assert list(aRecord.items()) == [("name", "Nigel"), ("empire", "The British Empire")]
235+
236+
def test_record_index(self):
237+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
238+
assert aRecord.index("name") == 0
239+
assert aRecord.index("empire") == 1
240+
with self.assertRaises(KeyError):
241+
aRecord.index("crap")
242+
243+
def test_record_contains(self):
244+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
245+
assert "name" in aRecord
246+
assert "empire" in aRecord
247+
assert "Germans" not in aRecord
248+
249+
def test_record_iter(self):
250+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
251+
assert list(aRecord.__iter__()) == ["name", "empire"]
252+
253+
def test_record_record(self):
254+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
255+
assert record(aRecord) is aRecord
256+
257+
def test_record_copy(self):
258+
original = Record(["name","empire"], ["Nigel", "The British Empire"])
259+
duplicate = original.copy()
260+
assert dict(original) == dict(duplicate)
261+
assert original.keys() == duplicate.keys()
262+
assert original is not duplicate
263+
264+
def test_record_as_dict(self):
265+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
266+
assert dict(aRecord) == { "name": "Nigel", "empire": "The British Empire" }
267+
268+
def test_record_as_list(self):
269+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
270+
assert list(aRecord) == ["name", "empire"]
271+
272+
def test_record_len(self):
273+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
274+
assert len(aRecord) == 2
275+
276+
def test_record_repr(self):
277+
aRecord = Record(["name","empire"], ["Nigel", "The British Empire"])
278+
assert repr(aRecord) == "<Record name='Nigel' empire='The British Empire'>"
279+
214280
class TransactionTestCase(TestCase):
215281
def test_can_commit_transaction(self):
216282
with GraphDatabase.driver("bolt://localhost").session() as session:

0 commit comments

Comments
 (0)