Skip to content

Commit 5ddb3c3

Browse files
committed
Compile optimizer after crossover
1 parent ea294a7 commit 5ddb3c3

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

genome.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from random import choice, random, shuffle
55
from typing import Dict, List, Tuple
66

7+
import torch
8+
79
from neat.aggregations import AggregationFunctionSet
810
from neat.config import ConfigParameter, write_pretty_params
911
from neat.graphs import creates_cycle, required_for_output
@@ -191,6 +193,8 @@ def configure_crossover(self, genome1, genome2, config):
191193
# Homologous gene: combine genes from both parents.
192194
self.nodes[key] = ng1.crossover(ng2)
193195

196+
197+
194198
def mutate(self, config):
195199
"""Mutates this genome."""
196200

@@ -409,6 +413,19 @@ def get_pruned_copy(self, genome_config):
409413
new_genome.connections = used_connection_genes
410414
return new_genome
411415

416+
def compile_optimizer(self, genome_config):
417+
"""Compile this genome into a TorchScript optimizer."""
418+
from graph_builder import DynamicOptimizerModule
419+
420+
if not self.connections:
421+
self.optimizer = None
422+
else:
423+
module = DynamicOptimizerModule(
424+
self, genome_config.input_keys, genome_config.output_keys, self.graph_dict
425+
)
426+
self.optimizer = torch.jit.script(module)
427+
self.optimizer_path = None
428+
412429
def add_node(self, node_type: str, activation, aggregation) -> NodeGene:
413430
if activation is None and aggregation is None:
414431
print("WARNING: node added without any operation")

reproduction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def reproduce(self, config, species, pop_size, generation, task):
103103
child = config.genome_type(cid)
104104
child.configure_crossover(p1, p2, config.genome_config)
105105
child.mutate(config.genome_config)
106+
if hasattr(child, "compile_optimizer"):
107+
child.compile_optimizer(config.genome_config)
106108
new_population[cid] = child
107109
self.ancestors[cid] = (p1_id, p2_id)
108110

0 commit comments

Comments
 (0)