|
4 | 4 | from random import choice, random, shuffle |
5 | 5 | from typing import Dict, List, Tuple |
6 | 6 |
|
| 7 | +import torch |
| 8 | + |
7 | 9 | from neat.aggregations import AggregationFunctionSet |
8 | 10 | from neat.config import ConfigParameter, write_pretty_params |
9 | 11 | from neat.graphs import creates_cycle, required_for_output |
10 | 12 |
|
11 | 13 | from attributes import BoolAttribute, FloatAttribute, IntAttribute, StringAttribute |
12 | 14 | from computation_graphs.functions.activation import * |
13 | 15 | from computation_graphs.functions.aggregation import * |
14 | | -from genes import ConnectionGene, NodeGene |
| 16 | +from genes import ConnectionGene, NodeGene, NODE_TYPE_TO_INDEX |
15 | 17 |
|
16 | 18 |
|
17 | 19 | class OptimizerGenomeConfig(object): |
@@ -191,6 +193,8 @@ def configure_crossover(self, genome1, genome2, config): |
191 | 193 | # Homologous gene: combine genes from both parents. |
192 | 194 | self.nodes[key] = ng1.crossover(ng2) |
193 | 195 |
|
| 196 | + |
| 197 | + |
194 | 198 | def mutate(self, config): |
195 | 199 | """Mutates this genome.""" |
196 | 200 |
|
@@ -409,6 +413,42 @@ def get_pruned_copy(self, genome_config): |
409 | 413 | new_genome.connections = used_connection_genes |
410 | 414 | return new_genome |
411 | 415 |
|
| 416 | + def compile_optimizer(self, genome_config): |
| 417 | + """Compile this genome into a TorchScript optimizer.""" |
| 418 | + from graph_builder import rebuild_and_script |
| 419 | + |
| 420 | + if self.graph_dict is None: |
| 421 | + node_ids = sorted(self.nodes.keys()) |
| 422 | + node_types = [] |
| 423 | + node_attributes = [] |
| 424 | + for nid in node_ids: |
| 425 | + node = self.nodes[nid] |
| 426 | + idx = NODE_TYPE_TO_INDEX.get(node.node_type) |
| 427 | + if idx is None: |
| 428 | + raise KeyError(f"Unknown node_type {node.node_type!r}") |
| 429 | + node_types.append(idx) |
| 430 | + node_attributes.append(node.dynamic_attributes) |
| 431 | + node_types = torch.tensor(node_types, dtype=torch.long) |
| 432 | + |
| 433 | + edges = [] |
| 434 | + for (src, dst), conn in self.connections.items(): |
| 435 | + if conn.enabled and src in node_ids and dst in node_ids: |
| 436 | + local_src = node_ids.index(src) |
| 437 | + local_dst = node_ids.index(dst) |
| 438 | + edges.append([local_src, local_dst]) |
| 439 | + if edges: |
| 440 | + edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() |
| 441 | + else: |
| 442 | + edge_index = torch.empty((2, 0), dtype=torch.long) |
| 443 | + self.graph_dict = { |
| 444 | + "node_types": node_types, |
| 445 | + "edge_index": edge_index, |
| 446 | + "node_attributes": node_attributes, |
| 447 | + } |
| 448 | + |
| 449 | + self.optimizer = rebuild_and_script(self.graph_dict, genome_config, key=self.key) |
| 450 | + self.optimizer_path = None |
| 451 | + |
412 | 452 | def add_node(self, node_type: str, activation, aggregation) -> NodeGene: |
413 | 453 | if activation is None and aggregation is None: |
414 | 454 | print("WARNING: node added without any operation") |
|
0 commit comments