|
4 | 4 | from random import choice, random, shuffle |
5 | 5 | from typing import Dict, List, Tuple |
6 | 6 |
|
| 7 | +import torch |
| 8 | + |
7 | 9 | from neat.aggregations import AggregationFunctionSet |
8 | 10 | from neat.config import ConfigParameter, write_pretty_params |
9 | 11 | from neat.graphs import creates_cycle, required_for_output |
@@ -191,6 +193,8 @@ def configure_crossover(self, genome1, genome2, config): |
191 | 193 | # Homologous gene: combine genes from both parents. |
192 | 194 | self.nodes[key] = ng1.crossover(ng2) |
193 | 195 |
|
| 196 | + |
| 197 | + |
194 | 198 | def mutate(self, config): |
195 | 199 | """Mutates this genome.""" |
196 | 200 |
|
@@ -409,6 +413,19 @@ def get_pruned_copy(self, genome_config): |
409 | 413 | new_genome.connections = used_connection_genes |
410 | 414 | return new_genome |
411 | 415 |
|
    def compile_optimizer(self, genome_config):
        """Compile this genome into a TorchScript optimizer.

        Builds a ``DynamicOptimizerModule`` from this genome's topology
        (``self.graph_dict``) together with the config's input/output keys,
        scripts it with ``torch.jit.script``, and stores the result on
        ``self.optimizer``.  A genome with no connections gets
        ``self.optimizer = None`` instead.  ``self.optimizer_path`` is reset
        to ``None`` afterwards.
        """
        # Imported lazily here rather than at module top — presumably to
        # avoid a circular import with graph_builder; TODO confirm.
        from graph_builder import DynamicOptimizerModule

        if not self.connections:
            # A connection-less genome has no computation graph to compile.
            self.optimizer = None
        else:
            module = DynamicOptimizerModule(
                self, genome_config.input_keys, genome_config.output_keys, self.graph_dict
            )
            # Script the module so it can execute independently of the
            # Python interpreter (TorchScript).
            self.optimizer = torch.jit.script(module)
        # NOTE(review): placed at method level per the diff rendering; the
        # original indentation is ambiguous — confirm this reset is not
        # meant to live inside the else branch only.
        self.optimizer_path = None
412 | 429 | def add_node(self, node_type: str, activation, aggregation) -> NodeGene: |
413 | 430 | if activation is None and aggregation is None: |
414 | 431 | print("WARNING: node added without any operation") |
|
0 commit comments