Commit 0386c8a

Build Redis arguments instead of intermediate files
1 parent 0374c05 commit 0386c8a

File tree

1 file changed (+133, -140 lines)

bulk_insert.py

Lines changed: 133 additions & 140 deletions
@@ -1,10 +1,13 @@
 import csv
 import os
-import errno
 import struct
 import redis
 import click

+NODE_COUNT = 0
+RELATION_COUNT = 0
+NODE_DICT = {}
+
 # Custom error class for invalid inputs
 class CSVError(Exception):
     pass
@@ -58,141 +61,138 @@ def __init__(self, prop_str):
     def to_binary(self):
         return struct.pack(self.format_str, *[self.type] + self.pack_args)

-class Relation:
-    def __init__(self, line):
-        self.src = NODE_DICT[line[0]]
-        self.dest = NODE_DICT[line[1]]
-        # self.props = []
-        # for field in line[2:]:
-        #     self.props.append(Property(field))
+class EntityFile(object):
+    def __init__(self, filename):
+        # The label or relation type string is the basename of the file
+        self.entity_str = os.path.splitext(os.path.basename(filename))[0].encode("ascii")
+        # Input file handling
+        self.infile = open(filename, 'rt')
+        # Initialize CSV reader that ignores leading whitespace in each field and does not modify input quote characters
+        self.reader = csv.reader(self.infile, skipinitialspace=True, quoting=csv.QUOTE_NONE)
+
+        self.prop_offset = 0 # Starting index of properties in row
+        self.expected_col_count = 0
+        self.prop_count = 0 # Number of properties per entity
+
+        self.entity_count = 0 # Total number of entities
+
+        self.packed_header = ""
+        self.entities = []
+
+    # entity_string refers to label or relation type string
+    def pack_header(self, header):
+        prop_count = len(header) - self.prop_offset
+        # String format
+        # Is == length, string
+        fmt = "=I%dsI" % len(self.entity_str) # Unaligned native, entity_string length, entity_string string, count of properties
+        args = [len(self.entity_str), self.entity_str, prop_count]
+        for prop in header[self.prop_offset:]:
+            fmt += "I%ds" % len(prop)
+            args += [len(prop), prop]
+        return struct.pack(fmt, *args)
+
+    def pack_props(self, line):
+        props = []
+        for field in line[self.prop_offset:]:
+            props.append(Property(field))
+
+        return "".join(p.to_binary() for p in props)

     def to_binary(self):
-        fmt = "=QQ" # 8-byte unsigned ints for src and dest
-        return struct.pack(fmt, self.src, self.dest)
-
-
-WORKING_DIR = "working_"
-NODE_COUNT = 0
-RELATION_COUNT = 0
-NODE_DICT = {}
-NODEFILES = []
-RELFILES = []
-
-# This function applies to both node and relation files
-def pack_header(label, line):
-    prop_count = len(line)
-    # String format
-    # Is == length, string
-    fmt = "=I%dsI" % len(label) # Unaligned native, label length, label string, count of properties
-    args = [len(label), label, prop_count]
-    for prop in line:
-        fmt += "I%ds" % len(prop)
-        args += [len(prop), prop]
-    return struct.pack(fmt, *args)
-
-def pack_props(line):
-    # prop_count = len(line)
-    props = []
-    # struct.pack()
-    for field in line:
-        props.append(Property(field))
-
-    return props
+        return self.packed_header + "".join(self.entities)
+
+class NodeFile(EntityFile):
+    def __init__(self, infile):
+        super(NodeFile, self).__init__(infile)
+        self.process_header()
+
+    def process_header(self):
+        # Header format:
+        # identifier, properties[0..n]
+        header = next(self.reader)
+        self.expected_col_count = len(header)
+        # If identifier field begins with an underscore, don't add it as a property.
+        if header[0][0] == '_':
+            self.prop_offset = 1
+        self.packed_header = self.pack_header(header)
+
+    def process_entities(self):
+        global NODE_DICT
+        global NODE_COUNT
+
+        for row in self.reader:
+            # Expect all entities to have the same property count
+            if len(row) != self.expected_col_count:
+                raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
+                               % (self.infile.name, self.reader.line_num, self.expected_col_count, len(row), ','.join(row)))
+            # Check for dangling comma
+            if row[-1] == ',':
+                raise CSVError("%s:%d Dangling comma in input. ('%s')"
+                               % (self.infile.name, self.reader.line_num, ','.join(row)))
+            # Add identifier->ID pair to dictionary
+            if row[0] in NODE_DICT:
+                print("Node identifier '%s' was used multiple times - second occurrence at %s:%d" % (row[0], self.infile.name, self.reader.line_num))
+            NODE_DICT[row[0]] = NODE_COUNT
+            NODE_COUNT += 1
+            self.entity_count += 1
+            self.entities.append(self.pack_props(row))
+        self.infile.close()
+
+
+class RelationFile(EntityFile):
+    def __init__(self, infile):
+        super(RelationFile, self).__init__(infile)
+        self.process_header()
+
+    def process_header(self):
+        # Header format:
+        # source identifier, dest identifier, properties[0..n]
+        header = next(self.reader)
+        # Assume rectangular CSVs
+        self.expected_col_count = len(header)
+        self.prop_count = self.expected_col_count - 2
+        if self.prop_count < 0:
+            print("Relation file '%s' should have at least 2 elements in header line." % (self.infile.name))
+            return # TODO return?
+
+        self.prop_offset = 2
+        self.packed_header = self.pack_header(header) # skip src and dest identifiers
+
+    def process_entities(self):
+        global RELATION_COUNT
+        for row in self.reader:
+            # Each row should have the same number of fields
+            if len(row) != self.expected_col_count:
+                raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
+                               % (self.infile.name, self.reader.line_num, self.expected_col_count, len(row), ','.join(row)))
+            # Check for dangling comma
+            if row[-1] == '':
+                raise CSVError("%s:%d Dangling comma in input. ('%s')"
+                               % (self.infile.name, self.reader.line_num, ','.join(row)))
+
+            src = NODE_DICT[row[0]]
+            dest = NODE_DICT[row[1]]
+            fmt = "=QQ" # 8-byte unsigned ints for src and dest
+            self.entity_count += 1
+            RELATION_COUNT += 1
+            self.entities.append(struct.pack(fmt, src, dest) + self.pack_props(row))
+        self.infile.close()

 def process_node_csvs(csvs):
-    global NODE_COUNT
-    global NODEFILES
-    global NODE_DICT
-    # A Label or Relationship name is set by the CSV file name
-    # TODO validate name string
+    nodefiles = []
     for in_csv in csvs:
-        filename = os.path.basename(in_csv)
-        label = os.path.splitext(filename)[0].encode("ascii")
-
-        with open(os.path.join(WORKING_DIR, label + ".dat"), 'wb') as outfile, open(in_csv, 'rt') as infile:
-            NODEFILES.append(os.path.join(os.getcwd(), outfile.name))
-            # Initialize CSV reader that ignores leading whitespace in each field and does not modify input quote characters
-            reader = csv.reader(infile, skipinitialspace=True, quoting=csv.QUOTE_NONE)
-            # Header format:
-            # identifier, properties[0..n]
-            header = next(reader)
-            # Assume rectangular CSVs
-            expected_col_count = len(header) # id field + prop_count
-            # TODO verify that header is not empty
-
-            properties_start = 0
-            # If identifier field begins with an underscore, don't add it as a property.
-            if header[0][0] == '_':
-                properties_start = 1
-
-            out = pack_header(label, header[properties_start:])
-            outfile.write(out)
-
-            for row in reader:
-                # Expect all entities to have the same property count
-                if len(row) != expected_col_count:
-                    raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
-                                   % (filename, reader.line_num, expected_col_count, len(row), ','.join(row)))
-                # Check for dangling comma
-                if row[-1] == ',':
-                    raise CSVError("%s:%d Dangling comma in input. ('%s')"
-                                   % (filename, reader.line_num, ','.join(row)))
-                # Add identifier->ID pair to dictionary
-                if row[0] in NODE_DICT:
-                    print("Node identifier '%s' was used multiple times - second occurrence at %s:%d" % (row[0], filename, reader.line_num))
-                NODE_DICT[row[0]] = NODE_COUNT
-                NODE_COUNT += 1
-                props = pack_props(row[properties_start:])
-                for prop in props:
-                    outfile.write(prop.to_binary())
-
-    return NODE_DICT
+        nodefile = NodeFile(in_csv)
+        nodefile.process_entities()
+        nodefiles.append(nodefile)
+    return nodefiles

 def process_relation_csvs(csvs):
-    global RELATION_COUNT
-    global RELFILES
-    # A Label or Relationship name is set by the CSV file name
-    # TODO validate name string
+    relfiles = []
     for in_csv in csvs:
-        filename = os.path.basename(in_csv)
-        relation = os.path.splitext(filename)[0].encode("ascii")
-
-        with open(os.path.join(WORKING_DIR, relation + ".dat"), 'wb') as outfile, open(in_csv, 'rt') as infile:
-            RELFILES.append(os.path.join(os.getcwd(), outfile.name))
-            # Initialize CSV reader that ignores leading whitespace in each field and does not modify input quote characters
-            reader = csv.reader(infile, skipinitialspace=True, quoting=csv.QUOTE_NONE)
-            # Header format:
-            # source identifier, dest identifier, properties[0..n]
-            header = next(reader)
-            # Assume rectangular CSVs
-            expected_col_count = len(header) # src + dest + prop_count
-
-            relations_have_properties = False
-            if expected_col_count < 2:
-                print("Relation file '%s' should have at least 2 elements in header line." % (filename))
-                return
-            elif expected_col_count > 2:
-                relations_have_properties = True
-
-            out = pack_header(relation, header[2:])
-            outfile.write(out)
-
-            for row in reader:
-                # Each row should have the same number of fields
-                if len(row) != expected_col_count:
-                    raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
-                                   % (filename, reader.line_num, expected_col_count, len(row), ','.join(row)))
-                # Check for dangling comma
-                if row[-1] == '':
-                    raise CSVError("%s:%d Dangling comma in input. ('%s')"
-                                   % (infile.name, reader.line_num, ','.join(row)))
-                rel = Relation(row)
-                RELATION_COUNT += 1
-                outfile.write(rel.to_binary())
-                if relations_have_properties:
-                    props = pack_props(row[2:])
-                    for prop in props:
-                        outfile.write(prop.to_binary())
+        relfile = RelationFile(in_csv)
+        relfile.process_entities()
+        relfiles.append(relfile)
+    return relfiles


 def help():
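
For reference, the header blob that pack_header() builds above is the entity string and each property name, all length-prefixed, in unaligned native byte order. A minimal sketch of that layout, using a hypothetical "person" label and property names that are not taken from this commit:

import struct

label = b"person"              # entity_str: the basename of the CSV file
props = [b"name", b"age"]      # property names from the CSV header row

fmt = "=I%dsI" % len(label)    # unaligned native: label length, label, property count
args = [len(label), label, len(props)]
for prop in props:
    fmt += "I%ds" % len(prop)  # each property name is also length-prefixed
    args += [len(prop), prop]

packed = struct.pack(fmt, *args)
assert len(packed) == 4 + 6 + 4 + (4 + 4) + (4 + 3)  # 29 bytes total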
@@ -210,27 +210,20 @@ def help():
 @click.option('--relations', '-r', multiple=True, help='path to relation csv file')

 def bulk_insert(graph, host, port, password, nodes, relations):
-    global WORKING_DIR
-    WORKING_DIR += graph
-    try:
-        os.mkdir(WORKING_DIR)
-    except OSError as e:
-        if e.errno != errno.EEXIST:
-            raise
-    process_node_csvs(nodes)
+    nodefiles = process_node_csvs(nodes)

     if relations:
-        process_relation_csvs(relations)
+        relfiles = process_relation_csvs(relations)
+
+    args = [graph, NODE_COUNT, RELATION_COUNT, "NODES"] + [e.to_binary() for e in nodefiles]

-    args = [graph, NODE_COUNT, RELATION_COUNT, "NODES"] + NODEFILES
     if RELATION_COUNT > 0:
-        args += ["RELATIONS"] + RELFILES
+        args += ["RELATIONS"] + [e.to_binary() for e in relfiles]

     redis_client = redis.StrictRedis(host=host, port=port, password=password)
     # print(args)
     result = redis_client.execute_command("GRAPH.BULK", *args)
     print(result)
-    # TODO Delete working dir

 if __name__ == '__main__':
     bulk_insert()
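
Each relation row above is serialized as two 8-byte unsigned node IDs (the offsets assigned to the source and destination identifiers while NODE_DICT was built) followed by the row's packed properties. A quick sketch of that framing, with hypothetical IDs:

import struct

src, dest = 0, 3                      # IDs assigned in NodeFile.process_entities()
edge = struct.pack("=QQ", src, dest)  # 16 bytes: a pair of unaligned native uint64s
assert struct.unpack("=QQ", edge) == (0, 3)

These blobs are passed straight to GRAPH.BULK as command arguments of the shape [graph, NODE_COUNT, RELATION_COUNT, "NODES", node blobs, "RELATIONS", relation blobs], which is what lets this commit drop the intermediate .dat files and the working directory.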
