@@ -6,13 +6,16 @@
 from timeit import default_timer as timer
 import redis
 import click
+import json

 # Global variables
 CONFIGS = None # thresholds for batching Redis queries
 NODE_DICT = {} # global node dictionary
 TOP_NODE_ID = 0 # next ID to assign to a node
 QUERY_BUF = None # Buffer for query being constructed

+FIELD_TYPES = None # optional per-label map of explicit column types, parsed from --field-types
+
 # Custom error class for invalid inputs
 class CSVError(Exception):
     pass
@@ -155,8 +158,13 @@ def pack_header(self, header): |
     # Convert a list of properties into a binary string
     def pack_props(self, line):
         props = []
-        for field in line[self.prop_offset:]:
-            props.append(prop_to_binary(field))
+        for num, field in enumerate(line[self.prop_offset:]):
+            try:
+                field_type = FIELD_TYPES[self.entity_str][num]
+            except (TypeError, KeyError, IndexError):
+                # No explicit type for this column; let prop_to_binary infer one
+                field_type = None
+            props.append(prop_to_binary(field, field_type))

         return b''.join(p for p in props)

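A note on the lookup above: FIELD_TYPES is the dictionary parsed from the new --field-types option, keyed by label, with one type code per property column. A minimal sketch of the shape this code expects (the "Person" label and its column types are hypothetical):

    # Hypothetical result of json.loads on a --field-types payload:
    # the "Person" label declares three property columns, typed as
    # string (3), numeric (2), and bool (1) per the CLI help below.
    FIELD_TYPES = {"Person": [3, 2, 1]}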
@@ -278,31 +286,39 @@ def process_entities(self, expected_col_count): |

 # Convert a single CSV property field into a binary stream.
 # Supported property types are string, numeric, boolean, and NULL.
-def prop_to_binary(prop_str):
+# prop_type is one of Type.NUMERIC, Type.BOOL, or Type.STRING, and coerces the
+# field to that type if possible; None means infer the type from the value.
+def prop_to_binary(prop_str, prop_type):
     # All format strings start with an unsigned char to represent our Type enum
     format_str = "=B"
     if not prop_str:
         # An empty field indicates a NULL property
         return struct.pack(format_str, Type.NULL)

-    # If field can be cast to a float, allow it
-    try:
-        numeric_prop = float(prop_str)
-        return struct.pack(format_str + "d", Type.NUMERIC, numeric_prop)
-    except:
-        pass
-
-    # If field is 'false' or 'true', it is a boolean
-    if prop_str.lower() == 'false':
-        return struct.pack(format_str + '?', Type.BOOL, False)
-    elif prop_str.lower() == 'true':
-        return struct.pack(format_str + '?', Type.BOOL, True)
-
-    # If we've reached this point, the property is a string
-    encoded_str = str.encode(prop_str) # struct.pack requires bytes objects as arguments
-    # Encoding len+1 adds a null terminator to the string
-    format_str += "%ds" % (len(encoded_str) + 1)
-    return struct.pack(format_str, Type.STRING, encoded_str)
+    if prop_type is None or prop_type == Type.NUMERIC:
+        # If the field can be cast to a float, allow it
+        try:
+            numeric_prop = float(prop_str)
+            return struct.pack(format_str + "d", Type.NUMERIC, numeric_prop)
+        except ValueError:
+            pass
+
+    if prop_type is None or prop_type == Type.BOOL:
+        # If the field is 'false' or 'true', it is a boolean
+        if prop_str.lower() == 'false':
+            return struct.pack(format_str + '?', Type.BOOL, False)
+        elif prop_str.lower() == 'true':
+            return struct.pack(format_str + '?', Type.BOOL, True)
+
+    if prop_type is None or prop_type == Type.STRING:
+        # If we've reached this point, the property is a string
+        encoded_str = str.encode(prop_str) # struct.pack requires bytes objects as arguments
+        # Encoding len+1 adds a null terminator to the string
+        format_str += "%ds" % (len(encoded_str) + 1)
+        return struct.pack(format_str, Type.STRING, encoded_str)
+
+    # Reaching this point means the field cannot be coerced to the requested type
+    raise CSVError("unable to parse [%s] with type [%r]" % (prop_str, prop_type))

 # For each node input file, validate contents and convert to binary format.
 # If any buffer limits have been reached, flush all enqueued inserts to Redis.
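For intuition about the byte layout prop_to_binary emits, here is a small standalone sketch using the same struct format strings, with the type codes (0=null, 1=bool, 2=numeric, 3=string) taken from the --field-types help text below:

    import struct

    NUMERIC, STRING = 2, 3

    # A numeric property: one unsigned type byte followed by an 8-byte double.
    packed_num = struct.pack("=Bd", NUMERIC, 5.5)
    assert len(packed_num) == 9

    # A string property: struct zero-pads the extra byte reserved by the
    # len+1 encoding, which serves as the null terminator.
    encoded = "hello".encode()
    packed_str = struct.pack("=B%ds" % (len(encoded) + 1), STRING, encoded)
    assert packed_str[-1] == 0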
@@ -336,25 +352,30 @@ def process_entity_csvs(cls, csvs, separator): |
 @click.option('--max-token-count', '-c', default=1024, help='max number of processed CSVs to send per query (default 1024)')
 @click.option('--max-buffer-size', '-b', default=2048, help='max buffer size in megabytes (default 2048)')
 @click.option('--max-token-size', '-t', default=500, help='max size of each token in megabytes (default 500, max 512)')
-@click.option('--quote-minimal/--no-quote-minimal', '-q/-d', default=False, help='only quote those fields which contain special characters such as delimiter, quotechar or any of the characters in lineterminator')
+@click.option('--quote', '-q', default=3, help='the CSV quoting format: QUOTE_MINIMAL=0, QUOTE_ALL=1, QUOTE_NONNUMERIC=2, QUOTE_NONE=3 (default 3)')
+@click.option('--field-types', '-f', default=None, help='JSON map of explicit types for each field, in the format {<label>: [<col1 type>, <col2 type>, ...]} where a type is 0 (null), 1 (bool), 2 (numeric), or 3 (string)')
 @click.option('--skip-invalid-nodes', '-s', default=False, is_flag=True, help='ignore nodes that use previously defined IDs')
 @click.option('--skip-invalid-edges', '-e', default=False, is_flag=True, help='ignore invalid edges, print an error message and continue loading (True), or stop loading after an edge loading failure (False)')


-def bulk_insert(graph, host, port, password, nodes, relations, separator, max_token_count, max_buffer_size, max_token_size, quote_minimal, skip_invalid_nodes, skip_invalid_edges):
+def bulk_insert(graph, host, port, password, nodes, relations, separator, max_token_count, max_buffer_size, max_token_size, quote, field_types, skip_invalid_nodes, skip_invalid_edges):
     global CONFIGS
     global NODE_DICT
     global TOP_NODE_ID
     global QUERY_BUF
     global QUOTING
+    global FIELD_TYPES

     if sys.version_info[0] < 3:
         raise Exception("Python 3 is required for the RedisGraph bulk loader.")

-    if quote_minimal:
-        QUOTING=csv.QUOTE_MINIMAL
-    else:
-        QUOTING=csv.QUOTE_NONE
+    if field_types is not None:
+        try:
+            FIELD_TYPES = json.loads(field_types)
+        except ValueError:
+            raise Exception("Problem parsing field-types. Use the format {<label>: [<col1 type>, <col2 type>, ...]} where a type is 0 (null), 1 (bool), 2 (numeric), or 3 (string)")
+
+    QUOTING = int(quote)

     TOP_NODE_ID = 0 # reset global ID variable (in case we are calling bulk_insert from unit tests)
     CONFIGS = Configs(max_token_count, max_buffer_size, max_token_size, skip_invalid_nodes, skip_invalid_edges)
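Taken together, a usage sketch of the new flags (script, graph, and file names are illustrative):

    # Load Person nodes with three property columns explicitly typed as
    # string, numeric, and bool, using QUOTE_NONE (the default) for parsing.
    python bulk_insert.py MyGraph -n Person.csv -f '{"Person": [3, 2, 1]}' -q 3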