44import redis
55import click
66
7- NODE_COUNT = 0
8- RELATION_COUNT = 0
9- NODE_DICT = {}
10-
# Custom error class for invalid inputs
class CSVError(Exception):
    """Raised when an input CSV file is malformed (wrong column count,
    dangling comma, bad header, or an unresolvable node identifier)."""
@@ -67,13 +63,12 @@ def __init__(self, filename):
6763 self .entity_str = os .path .splitext (os .path .basename (filename ))[0 ].encode ("ascii" )
6864 # Input file handling
6965 self .infile = open (filename , 'rt' )
70- # Initialize CSV reader that ignores leading whitespace in each field and does not modify input quote characters
66+ # Initialize CSV reader that ignores leading whitespace in each field
67+ # and does not modify input quote characters
7168 self .reader = csv .reader (self .infile , skipinitialspace = True , quoting = csv .QUOTE_NONE )
7269
7370 self .prop_offset = 0 # Starting index of properties in row
74- self .expected_col_count = 0
7571 self .prop_count = 0 # Number of properties per entity
76-
7772 self .entity_count = 0 # Total number of entities
7873
7974 self .packed_header = ""
@@ -101,97 +96,98 @@ def to_binary(self):
10196 return self .packed_header + b'' .join (self .entities )
10297
class NodeFile(EntityFile):
    """Parses a CSV of nodes.

    Column 0 is the node identifier; the remaining columns are properties.
    When node_dict is provided (i.e. relations will be built), every
    identifier is mapped to a globally unique node ID starting at
    initial_node_count, so relation files can resolve their endpoints.
    """

    def __init__(self, infile, node_dict, initial_node_count):
        super(NodeFile, self).__init__(infile)
        expected_col_count = self.process_header()
        self.process_entities(node_dict, initial_node_count, expected_col_count)
        self.infile.close()

    def process_header(self):
        """Consume the header row and return the column count every data
        row must match (rectangular CSV assumption)."""
        # Header format:
        # identifier, properties[0..n]
        # (NOTE: previous comment said "source identifier, dest identifier" —
        # that is the relation-file format, not the node-file format.)
        header = next(self.reader)
        expected_col_count = len(header)
        # If the identifier field begins with an underscore, don't add it as a property.
        # startswith is safe on an empty first field, unlike header[0][0].
        if header and header[0].startswith('_'):
            self.prop_offset = 1
        self.packed_header = self.pack_header(header)
        return expected_col_count

    def process_entities(self, node_dict, initial_node_count, expected_col_count):
        """Validate and pack each data row; register identifiers in node_dict.

        Raises CSVError on a row with the wrong column count or a dangling
        trailing comma.
        """
        for row in self.reader:
            # Expect all entities to have the same property count
            if len(row) != expected_col_count:
                raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
                               % (self.infile.name, self.reader.line_num, expected_col_count, len(row), ','.join(row)))
            # Check for dangling comma: the csv module yields an empty final
            # field for a trailing comma, so test for '' — the old test
            # (row[-1] == ',') could never match and was dead code; this also
            # matches RelationFile's check.
            if row[-1] == '':
                raise CSVError("%s:%d Dangling comma in input. ('%s')"
                               % (self.infile.name, self.reader.line_num, ','.join(row)))
            # Add identifier->ID pair to dictionary if we are building relations
            if node_dict is not None:
                if row[0] in node_dict:
                    # Duplicate identifiers are reported but not fatal; the
                    # later occurrence overwrites the earlier mapping.
                    print("Node identifier '%s' was used multiple times - second occurrence at %s:%d"
                          % (row[0], self.infile.name, self.reader.line_num))
                node_dict[row[0]] = initial_node_count + self.entity_count
            self.entity_count += 1
            self.entities.append(self.pack_props(row))
141135
class RelationFile(EntityFile):
    """Parses a CSV of relations.

    Columns 0 and 1 are the source and destination node identifiers; the
    remaining columns are properties. Endpoint identifiers are resolved to
    node IDs through node_dict (built by the NodeFile pass).
    """

    def __init__(self, infile, node_dict):
        super(RelationFile, self).__init__(infile)
        expected_col_count = self.process_header()
        self.process_entities(node_dict, expected_col_count)
        self.infile.close()

    def process_header(self):
        """Consume the header row and return the expected column count.

        Raises CSVError if the header has fewer than the 2 mandatory
        identifier columns.
        """
        # Header format:
        # source identifier, dest identifier, properties[0..n]
        header = next(self.reader)
        # Assume rectangular CSVs
        expected_col_count = len(header)
        self.prop_count = expected_col_count - 2
        if self.prop_count < 0:
            raise CSVError("Relation file '%s' should have at least 2 elements in header line."
                           % (self.infile.name))

        self.prop_offset = 2
        self.packed_header = self.pack_header(header)  # skip src and dest identifiers
        return expected_col_count

    def process_entities(self, node_dict, expected_col_count):
        """Validate and pack each relation row as (src_id, dest_id, props).

        Raises CSVError on a bad column count, a dangling comma, or an
        endpoint identifier that was never defined in a node file.
        """
        fmt = "=QQ"  # 8-byte unsigned ints for src and dest (loop-invariant, hoisted)
        for row in self.reader:
            # Each row should have the same number of fields
            if len(row) != expected_col_count:
                raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
                               % (self.infile.name, self.reader.line_num, expected_col_count, len(row), ','.join(row)))
            # Check for dangling comma (csv yields an empty final field)
            if row[-1] == '':
                raise CSVError("%s:%d Dangling comma in input. ('%s')"
                               % (self.infile.name, self.reader.line_num, ','.join(row)))

            # Resolve endpoints; report an unknown identifier as a diagnostic
            # CSVError (with file and line) instead of a raw KeyError.
            try:
                src = node_dict[row[0]]
                dest = node_dict[row[1]]
            except KeyError as e:
                raise CSVError("%s:%d Relation references unknown node identifier %s"
                               % (self.infile.name, self.reader.line_num, e))
            self.entity_count += 1
            self.entities.append(struct.pack(fmt, src, dest) + self.pack_props(row))
def process_node_csvs(csvs, node_dict):
    """Parse every node CSV, threading a running node count through so that
    node IDs are unique across files.

    Returns a (nodefiles, total_node_count) tuple.
    """
    nodefiles = []
    running_total = 0
    for path in csvs:
        parsed = NodeFile(path, node_dict, running_total)
        running_total += parsed.entity_count
        nodefiles.append(parsed)
    return nodefiles, running_total
187183
def process_relation_csvs(csvs, node_dict):
    """Parse every relation CSV, resolving endpoint identifiers via node_dict.

    Returns the list of RelationFile objects.
    """
    return [RelationFile(in_csv, node_dict) for in_csv in csvs]
193190
194-
def help():
    """Placeholder for CLI help output; intentionally does nothing."""
197193
@@ -209,15 +205,24 @@ def help():
209205@click .option ('--max_token_count' , '-t' , default = 1024 , help = 'max number of tokens to send per query (default 1024)' )
210206@click .option ('--max_buffer_size' , '-b' , default = 4086 , help = 'maximum buffer size in megabytes (default 4086)' )
211207
212- def bulk_insert (graph , host , port , password , nodes , relations , max_token_count , max_buffer_size ):
213- nodefiles = process_node_csvs (nodes )
208+ def bulk_insert (graph , host , port , password , nodes , relations ):
209+ # Create a node dictionary if we're building relations and as such require unique identifiers
210+ if relations :
211+ node_dict = {}
212+ else :
213+ node_dict = None
214+ node_count = 0
215+ nodefiles , node_count = process_node_csvs (nodes , node_dict )
214216
215217 if relations :
216- relfiles = process_relation_csvs (relations )
218+ relfiles = process_relation_csvs (relations , node_dict )
219+ relation_count = sum (e .entity_count for e in relfiles )
220+ else :
221+ relation_count = 0
217222
218- args = [graph , "BEGIN" , NODE_COUNT , RELATION_COUNT , "NODES" ] + [e .to_binary () for e in nodefiles ]
223+ args = [graph , "BEGIN" , node_count , relation_count , "NODES" ] + [e .to_binary () for e in nodefiles ]
219224
220- if RELATION_COUNT > 0 :
225+ if relation_count > 0 :
221226 args += ["RELATIONS" ] + [e .to_binary () for e in relfiles ]
222227
223228 redis_client = redis .StrictRedis (host = host , port = port , password = password )
0 commit comments