import redis
import csv
import os
import click

# Global variables (can refactor into arguments later)
max_tokens = 1024 * 1024
graphname = None
redis_client = None

# Custom error class for invalid inputs
class CSVError(Exception):
    pass

# Argument is the container class for the metadata parameters to be sent in a Redis query.
# An object of this type comprises the type of the inserted entities ("NODES" or "RELATIONS"),
# the number of entities, and one Descriptor for each entity type.
# The 'unroll' method generates this sequence as a list to be passed as part of a Redis query
# (followed by the entities themselves).
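# For illustration, a NODES query for two hypothetical labels might unroll to
# (names and counts here are made up):
#   ["NODES", 5, 2, "Person", 3, 2, "name", "age", "City", 2, 1, "name"]
# i.e. 5 pending entities across 2 descriptors, followed by each label's
# name, pending insert count, attribute count, and attribute names.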
class Argument:
    def __init__(self, argtype):
        self.argtype = argtype
        self.descriptors = []
        self.entities_created = 0
        self.queries_processed = 0
        self.total_entities = 0
        self.descriptors_in_query = 0

    def reset_tokens(self):
        for descriptor in self.descriptors:
            descriptor.reset_tokens()

    def remove_descriptors(self, delete_count):
        for descriptor in self.descriptors[:delete_count]:
            print('Finished inserting "%s".' % descriptor.name)
        del self.descriptors[:delete_count]

    def unroll(self):
        ret = [self.argtype, self.pending_inserts(), self.descriptors_in_query]
        for descriptor in self.descriptors[0:self.descriptors_in_query]:
            # Don't include a descriptor unless we are inserting its entities.
            # This can occur if the addition of a descriptor and its first entity caused
            # the token count to exceed the max allowed.
            if descriptor.pending_inserts == 0:
                ret[2] -= 1
            else:
                ret += descriptor.unroll()
        return ret

    def token_count(self):
        # ["NODES"/"RELATIONS", number of entities to insert, number of different entity types, list of type descriptors]
        return 3 + sum(descriptor.token_count() for descriptor in self.descriptors[0:self.descriptors_in_query])

    def pending_inserts(self):
        return sum(desc.pending_inserts for desc in self.descriptors)

    def print_insertion_done(self):
        print("%d %s created in %d queries." % (self.entities_created, self.argtype.lower(), self.queries_processed))
    def batch_insert_descriptors(self):
        entities = []    # Tokens pending insertion from CSV rows
        token_count = 5  # Prefix tokens: "GRAPH.BULK [graphname] [NODES/RELATIONS] [entity count] [descriptor count]"
        for desc in self.descriptors:
            desc.print_insertion_start()
            self.descriptors_in_query += 1
            token_count += desc.token_count()
            for row in desc.reader:
                token_count += desc.attribute_count  # All rows have the same length
                if token_count > max_tokens:
                    # max_tokens has been reached; submit all but the most recent row
                    query_redis(self, entities)
                    desc.entities_created += desc.pending_inserts
                    # Reset values post-query
                    self.reset_tokens()
                    entities = []

                    # Remove descriptors that have no further pending inserts
                    # (all but the current).
                    self.remove_descriptors(self.descriptors_in_query - 1)
                    self.descriptors_in_query = 1
                    if desc.entities_created > 0:
                        desc.print_progress()
                    # Following an insertion, token_count is set to accommodate "GRAPH.BULK", graphname,
                    # all labels and attributes, plus the individual tokens from the uninserted current row
                    token_count = 2 + self.token_count() + desc.attribute_count
                desc.pending_inserts += 1
                entities += row

        # Insert all remaining entities
        query_redis(self, entities)
        self.remove_descriptors(self.descriptors_in_query)
        self.print_insertion_done()

# The Descriptor class holds the name and element counts for an individual relationship or node label.
# After the contents of an Argument instance have been submitted to Redis in a GRAPH.BULK query,
# the contents of each descriptor contained in the query are unrolled.
# The `unroll` methods are unique to the data type.
class Descriptor:
    def __init__(self, csvfile):
        # A Label or Relationship name is set by the CSV file name
        # TODO validate name string
        self.name = os.path.splitext(os.path.basename(csvfile))[0]
        self.csvfile = open(csvfile, "rt")
        self.reader = csv.reader(self.csvfile)
        self.total_entities = 0
        self.entities_created = 0
        self.pending_inserts = 0

    def print_progress(self):
        print('%.2f%% (%d / %d) of "%s" inserted.' % (float(self.entities_created) * 100 / self.total_entities,
                                                      self.entities_created,
                                                      self.total_entities,
                                                      self.name))

    def reset_tokens(self):
        self.pending_inserts = 0

# A LabelDescriptor consists of a name string, an entity count, and a series of
# attributes (derived from the header row of the label CSV).
# As a query string, a LabelDescriptor is printed in the format:
# [label name, entity count, attribute count, attributes[0..n]]
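# For example, a hypothetical Person.csv:
#   name,age
#   Alice,33
#   Bob,28
# would yield the descriptor name "Person" and attributes ["name", "age"],
# and (with both rows pending) unroll to ["Person", 2, 2, "name", "age"].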
class LabelDescriptor(Descriptor):
    def __init__(self, csvfile):
        Descriptor.__init__(self, csvfile)
        # The first row of a label CSV contains its attributes
        self.attributes = next(self.reader)
        self.attribute_count = len(self.attributes)
        self.validate_csv()
        # Reset input CSV, then skip header line
        self.csvfile.seek(0)
        next(self.reader)

    def validate_csv(self):
        # Expect all rows to have the same column count as the header.
        expected_col_count = self.attribute_count
        for row in self.reader:
            # Raise an exception if the wrong number of columns is found
            if len(row) != expected_col_count:
                raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
                               % (self.csvfile.name, self.reader.line_num, expected_col_count, len(row), ','.join(row)))
            if row[-1] == '':
                raise CSVError("%s:%d Dangling comma in input ('%s')"
                               % (self.csvfile.name, self.reader.line_num, ','.join(row)))
        # Subtract 1 from the entity count to compensate for the header row.
        self.total_entities = self.reader.line_num - 1

    def token_count(self):
        # Labels have a token each for name, attribute count, and pending insertion count,
        # plus N tokens for the individual attributes.
        return 3 + self.attribute_count

    def unroll(self):
        return [self.name, self.pending_inserts, self.attribute_count] + self.attributes

    def print_insertion_start(self):
        print('Inserting Label "%s" - %d nodes' % (self.name, self.total_entities))

# A RelationDescriptor consists of a name string and a relationship count.
# As a query string, a RelationDescriptor is printed in the format:
# [relation name, relation count]
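# For example, a hypothetical KNOWS.csv in which each row holds a source and
# destination node ID:
#   0,1
#   1,2
# would yield the descriptor name "KNOWS" and (with both rows pending)
# unroll to ["KNOWS", 2].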
class RelationDescriptor(Descriptor):
    def __init__(self, csvfile):
        Descriptor.__init__(self, csvfile)
        self.attribute_count = 2
        self.validate_csv()
        self.csvfile.seek(0)

    def validate_csv(self):
        for row in self.reader:
            # Each row should have two columns (a source and dest ID)
            if len(row) != 2:
                raise CSVError("%s:%d Expected 2 columns, encountered %d ('%s')"
                               % (self.csvfile.name, self.reader.line_num, len(row), ','.join(row)))
            for elem in row:
                # Raise an exception if an element cannot be read as an integer
                try:
                    int(elem)
                except ValueError:
                    raise CSVError("%s:%d Token '%s' was not a node ID"
                                   % (self.csvfile.name, self.reader.line_num, elem))
        self.total_entities = self.reader.line_num

    def token_count(self):
        # Relations have 2 tokens: name and pending insertion count.
        return 2

    def unroll(self):
        return [self.name, self.pending_inserts]

    def print_insertion_start(self):
        print('Inserting relation "%s" - %d edges' % (self.name, self.total_entities))

# Issue a single Redis query to allocate space for the graph
def allocate_graph(node_count, relation_count):
    cmd = ["GRAPH.BULK", graphname, "BEGIN", node_count, relation_count]
    result = redis_client.execute_command(*cmd)
    print(result)

def finalize_graph():
    cmd = ["GRAPH.BULK", graphname, "END"]
    result = redis_client.execute_command(*cmd)
    print(result)
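# Taken together, a bulk load issues a command sequence along these lines
# (a hypothetical session; names and counts are made up, and "..." stands for
# the entity tokens that follow each metadata prefix):
#   GRAPH.BULK mygraph BEGIN 5 4
#   GRAPH.BULK mygraph NODES 5 2 Person 3 2 name age City 2 1 name ...
#   GRAPH.BULK mygraph RELATIONS 4 1 KNOWS 4 ...
#   GRAPH.BULK mygraph END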

def query_redis(metadata, entities):
    cmd = ["GRAPH.BULK", graphname] + metadata.unroll() + entities
    # Raise an error if the query doesn't contain entities
    if not entities:
        raise Exception("Attempted to insert 0 tokens ('%s')." % " ".join(str(e) for e in cmd))

    # Send query to the Redis client
    result = redis_client.execute_command(*cmd)
    stats = result.split(b', ')
    metadata.entities_created += int(stats[0].split(b' ')[0])
    metadata.entities_created += int(stats[1].split(b' ')[0])
    metadata.queries_processed += 1
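# The parsing above assumes the server replies with two comma-separated stats,
# each beginning with a count, e.g. b"5 nodes created, 0 relations created"
# (a format inferred from the split logic here, not confirmed elsewhere);
# that example reply would add 5 to entities_created.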

def build_descriptors(csvs, argtype):
    # Prepare a container for all labels or relations
    arg = Argument(argtype)

    # Generate a descriptor from each given CSV
    for f in csvs:
        # Better method for this?
        if argtype == "NODES":
            descriptor = LabelDescriptor(f)
        else:
            descriptor = RelationDescriptor(f)
        arg.descriptors.append(descriptor)
    arg.total_entities = sum(desc.total_entities for desc in arg.descriptors)
    return arg

def help():
    pass

# Command-line arguments
@click.command()
@click.argument('graph')
# Redis server connection settings
@click.option('--host', '-h', default='127.0.0.1', help='Redis server host')
@click.option('--port', '-p', default=6379, help='Redis server port')
@click.option('--password', '-P', default=None, help='Redis server password')
@click.option('--ssl', '-s', is_flag=True, default=False, help='Server is SSL-enabled')
# CSV file paths
@click.option('--nodes', '-n', required=True, multiple=True, help='Path to node CSV file')
@click.option('--relations', '-r', multiple=True, help='Path to relation CSV file')
# Debug options
@click.option('--max_buffer_size', '-m', default=1024*1024, help='(DEBUG ONLY) - max token count per Redis query')
def bulk_insert(graph, host, port, password, ssl, nodes, relations, max_buffer_size):
    global graphname
    global redis_client
    global max_tokens
    graphname = graph
    redis_client = redis.StrictRedis(host=host, port=port, password=password, ssl=ssl)
    if max_buffer_size > max_tokens:
        print("Requested buffer size too large, capping queries at %d." % max_tokens)
    else:
        max_tokens = max_buffer_size

    # Iterate over label CSVs to validate inputs and build label descriptors.
    print("Building label descriptors...")
    label_descriptors = build_descriptors(nodes, "NODES")
    relation_count = 0
    if relations:
        relation_descriptors = build_descriptors(relations, "RELATIONS")
        relation_count = relation_descriptors.total_entities

    # Send prefix tokens to RedisGraph.
    # This could also be done as part of the first query,
    # but would somewhat complicate the counting logic.
    allocate_graph(label_descriptors.total_entities, relation_count)

    # Process input CSVs and commit their contents to RedisGraph.
    # Possibly make this a method on Argument
    label_descriptors.batch_insert_descriptors()
    if relations:
        relation_descriptors.batch_insert_descriptors()
    finalize_graph()

if __name__ == '__main__':
    bulk_insert()
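# Example invocation (hypothetical file names, and assuming this script is
# saved as bulk_insert.py), loading two node labels and one relation type
# into a graph called "mygraph":
#   python bulk_insert.py mygraph -n Person.csv -n City.csv -r KNOWS.csv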