|
13 | 13 | from attributes import BoolAttribute, FloatAttribute, IntAttribute, StringAttribute |
14 | 14 | from computation_graphs.functions.activation import * |
15 | 15 | from computation_graphs.functions.aggregation import * |
16 | | -from genes import ConnectionGene, NodeGene |
| 16 | +from genes import ConnectionGene, NodeGene, NODE_TYPE_TO_INDEX |
17 | 17 |
|
18 | 18 |
|
19 | 19 | class OptimizerGenomeConfig(object): |
@@ -415,15 +415,38 @@ def get_pruned_copy(self, genome_config): |
415 | 415 |
|
416 | 416 | def compile_optimizer(self, genome_config): |
417 | 417 | """Compile this genome into a TorchScript optimizer.""" |
418 | | - from graph_builder import DynamicOptimizerModule |
419 | | - |
420 | | - if not self.connections: |
421 | | - self.optimizer = None |
422 | | - else: |
423 | | - module = DynamicOptimizerModule( |
424 | | - self, genome_config.input_keys, genome_config.output_keys, self.graph_dict |
425 | | - ) |
426 | | - self.optimizer = torch.jit.script(module) |
| 418 | + from graph_builder import rebuild_and_script |
| 419 | + |
| 420 | + if self.graph_dict is None: |
| 421 | + node_ids = sorted(self.nodes.keys()) |
| 422 | + node_types = [] |
| 423 | + node_attributes = [] |
| 424 | + for nid in node_ids: |
| 425 | + node = self.nodes[nid] |
| 426 | + idx = NODE_TYPE_TO_INDEX.get(node.node_type) |
| 427 | + if idx is None: |
| 428 | + raise KeyError(f"Unknown node_type {node.node_type!r}") |
| 429 | + node_types.append(idx) |
| 430 | + node_attributes.append(node.dynamic_attributes) |
| 431 | + node_types = torch.tensor(node_types, dtype=torch.long) |
| 432 | + |
| 433 | + edges = [] |
| 434 | + for (src, dst), conn in self.connections.items(): |
| 435 | + if conn.enabled and src in node_ids and dst in node_ids: |
| 436 | + local_src = node_ids.index(src) |
| 437 | + local_dst = node_ids.index(dst) |
| 438 | + edges.append([local_src, local_dst]) |
| 439 | + if edges: |
| 440 | + edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() |
| 441 | + else: |
| 442 | + edge_index = torch.empty((2, 0), dtype=torch.long) |
| 443 | + self.graph_dict = { |
| 444 | + "node_types": node_types, |
| 445 | + "edge_index": edge_index, |
| 446 | + "node_attributes": node_attributes, |
| 447 | + } |
| 448 | + |
| 449 | + self.optimizer = rebuild_and_script(self.graph_dict, genome_config, key=self.key) |
427 | 450 | self.optimizer_path = None |
428 | 451 |
|
429 | 452 | def add_node(self, node_type: str, activation, aggregation) -> NodeGene: |
|
0 commit comments