Skip to content

Commit 117239f

Browse files
committed
Remove globals
1 parent 7150e88 commit 117239f

File tree

1 file changed

+52
-47
lines changed

1 file changed

+52
-47
lines changed

bulk_insert.py

Lines changed: 52 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,6 @@
44
import redis
55
import click
66

7-
NODE_COUNT = 0
8-
RELATION_COUNT = 0
9-
NODE_DICT = {}
10-
117
# Custom error class for invalid inputs
128
class CSVError(Exception):
139
pass
@@ -67,13 +63,12 @@ def __init__(self, filename):
6763
self.entity_str = os.path.splitext(os.path.basename(filename))[0].encode("ascii")
6864
# Input file handling
6965
self.infile = open(filename, 'rt')
70-
# Initialize CSV reader that ignores leading whitespace in each field and does not modify input quote characters
66+
# Initialize CSV reader that ignores leading whitespace in each field
67+
# and does not modify input quote characters
7168
self.reader = csv.reader(self.infile, skipinitialspace=True, quoting=csv.QUOTE_NONE)
7269

7370
self.prop_offset = 0 # Starting index of properties in row
74-
self.expected_col_count = 0
7571
self.prop_count = 0 # Number of properties per entity
76-
7772
self.entity_count = 0 # Total number of entities
7873

7974
self.packed_header = ""
@@ -101,97 +96,98 @@ def to_binary(self):
10196
return self.packed_header + b''.join(self.entities)
10297

10398
class NodeFile(EntityFile):
104-
def __init__(self, infile):
99+
def __init__(self, infile, node_dict, initial_node_count):
105100
super(NodeFile, self).__init__(infile)
106-
self.process_header()
107-
self.process_entities()
101+
expected_col_count = self.process_header()
102+
self.process_entities(node_dict, initial_node_count, expected_col_count)
108103
self.infile.close()
109104

110105
def process_header(self):
111106
# Header format:
112107
# source identifier, dest identifier, properties[0..n]
113108
header = next(self.reader)
114-
self.expected_col_count = len(header)
109+
expected_col_count = len(header)
115110
# If identifier field begins with an underscore, don't add it as a property.
116111
if header[0][0] == '_':
117112
self.prop_offset = 1
118113
self.packed_header = self.pack_header(header)
114+
return expected_col_count
119115

120-
def process_entities(self):
121-
global NODE_DICT
122-
global NODE_COUNT
123-
116+
def process_entities(self, node_dict, initial_node_count, expected_col_count):
124117
for row in self.reader:
125118
# Expect all entities to have the same property count
126-
if len(row) != self.expected_col_count:
119+
if len(row) != expected_col_count:
127120
raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
128-
% (self.infile.name, self.reader.line_num, self.expected_col_count, len(row), ','.join(row)))
121+
% (self.infile.name, self.reader.line_num, expected_col_count, len(row), ','.join(row)))
129122
# Check for dangling comma
130123
if row[-1] == ',':
131124
raise CSVError("%s:%d Dangling comma in input. ('%s')"
132125
% (self.infile.name, self.reader.line_num, ','.join(row)))
133-
# Add identifier->NodeID pair to dictionary
134-
if row[0] in NODE_DICT:
135-
print("Node identifier '%s' was used multiple times - second occurrence at %s:%d" % (row[0], self.infile.name, self.reader.line_num))
136-
NODE_DICT[row[0]] = NODE_COUNT
137-
NODE_COUNT += 1
126+
# Add identifier->ID pair to dictionary if we are building relations
127+
if node_dict is not None:
128+
if row[0] in node_dict:
129+
print("Node identifier '%s' was used multiple times - second occurrence at %s:%d"
130+
% (row[0], self.infile.name, self.reader.line_num))
131+
node_dict[row[0]] = initial_node_count + self.entity_count
138132
self.entity_count += 1
139133
self.entities.append(self.pack_props(row))
140134

141135

142136
class RelationFile(EntityFile):
143-
def __init__(self, infile):
137+
def __init__(self, infile, node_dict):
144138
super(RelationFile, self).__init__(infile)
145-
self.process_header()
146-
self.process_entities()
139+
expected_col_count = self.process_header()
140+
self.process_entities(node_dict, expected_col_count)
147141
self.infile.close()
148142

149143
def process_header(self):
150144
# Header format:
151145
# source identifier, dest identifier, properties[0..n]
152146
header = next(self.reader)
153147
# Assume rectangular CSVs
154-
self.expected_col_count = len(header)
155-
self.prop_count = self.expected_col_count - 2
148+
expected_col_count = len(header)
149+
self.prop_count = expected_col_count - 2
156150
if self.prop_count < 0:
157-
print("Relation file '%s' should have at least 2 elements in header line." % (self.infile.name))
158-
return # TODO return?
151+
raise CSVError("Relation file '%s' should have at least 2 elements in header line."
152+
% (self.infile.name))
159153

160154
self.prop_offset = 2
161155
self.packed_header = self.pack_header(header) # skip src and dest identifiers
156+
return expected_col_count
162157

163-
def process_entities(self):
164-
global RELATION_COUNT
158+
def process_entities(self, node_dict, expected_col_count):
165159
for row in self.reader:
166160
# Each row should have the same number of fields
167-
if len(row) != self.expected_col_count:
161+
if len(row) != expected_col_count:
168162
raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
169-
% (self.infile.name, self.reader.line_num, self.expected_col_count, len(row), ','.join(row)))
163+
% (self.infile.name, self.reader.line_num, expected_col_count, len(row), ','.join(row)))
170164
# Check for dangling comma
171165
if row[-1] == '':
172166
raise CSVError("%s:%d Dangling comma in input. ('%s')"
173167
% (self.infile.name, self.reader.line_num, ','.join(row)))
174168

175-
src = NODE_DICT[row[0]]
176-
dest = NODE_DICT[row[1]]
169+
src = node_dict[row[0]]
170+
dest = node_dict[row[1]]
177171
fmt = "=QQ" # 8-byte unsigned ints for src and dest
178172
self.entity_count += 1
179-
RELATION_COUNT += 1
180173
self.entities.append(struct.pack(fmt, src, dest) + self.pack_props(row))
181174

182-
def process_node_csvs(csvs):
175+
def process_node_csvs(csvs, node_dict):
183176
nodefiles = []
177+
node_count = 0
184178
for in_csv in csvs:
185-
nodefiles.append(NodeFile(in_csv))
186-
return nodefiles
179+
nodefile = NodeFile(in_csv, node_dict, node_count)
180+
node_count += nodefile.entity_count
181+
nodefiles.append(nodefile)
182+
return nodefiles, node_count
187183

188-
def process_relation_csvs(csvs):
184+
def process_relation_csvs(csvs, node_dict):
189185
relfiles = []
190186
for in_csv in csvs:
191-
relfiles.append(RelationFile(in_csv))
187+
rel = RelationFile(in_csv, node_dict)
188+
relfiles.append(rel)
192189
return relfiles
193190

194-
195191
def help():
196192
pass
197193

@@ -209,15 +205,24 @@ def help():
209205
@click.option('--max_token_count', '-t', default=1024, help='max number of tokens to send per query (default 1024)')
210206
@click.option('--max_buffer_size', '-b', default=4086, help='maximum buffer size in megabytes (default 4086)')
211207

212-
def bulk_insert(graph, host, port, password, nodes, relations, max_token_count, max_buffer_size):
213-
nodefiles = process_node_csvs(nodes)
208+
def bulk_insert(graph, host, port, password, nodes, relations):
209+
# Create a node dictionary if we're building relations and as such require unique identifiers
210+
if relations:
211+
node_dict = {}
212+
else:
213+
node_dict = None
214+
node_count = 0
215+
nodefiles, node_count = process_node_csvs(nodes, node_dict)
214216

215217
if relations:
216-
relfiles = process_relation_csvs(relations)
218+
relfiles = process_relation_csvs(relations, node_dict)
219+
relation_count = sum(e.entity_count for e in relfiles)
220+
else:
221+
relation_count = 0
217222

218-
args = [graph, "BEGIN", NODE_COUNT, RELATION_COUNT, "NODES"] + [e.to_binary() for e in nodefiles]
223+
args = [graph, "BEGIN", node_count, relation_count, "NODES"] + [e.to_binary() for e in nodefiles]
219224

220-
if RELATION_COUNT > 0:
225+
if relation_count > 0:
221226
args += ["RELATIONS"] + [e.to_binary() for e in relfiles]
222227

223228
redis_client = redis.StrictRedis(host=host, port=port, password=password)

0 commit comments

Comments
 (0)