Skip to content

Commit 279fe0a

Browse files
authored
Merge pull request #26 from TimeDelta/codex/fix-typeerror-on-second-generation
Fix crossover offspring optimizer compilation
2 parents ea294a7 + ccacf87 commit 279fe0a

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

genome.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
from random import choice, random, shuffle
55
from typing import Dict, List, Tuple
66

7+
import torch
8+
79
from neat.aggregations import AggregationFunctionSet
810
from neat.config import ConfigParameter, write_pretty_params
911
from neat.graphs import creates_cycle, required_for_output
1012

1113
from attributes import BoolAttribute, FloatAttribute, IntAttribute, StringAttribute
1214
from computation_graphs.functions.activation import *
1315
from computation_graphs.functions.aggregation import *
14-
from genes import ConnectionGene, NodeGene
16+
from genes import ConnectionGene, NodeGene, NODE_TYPE_TO_INDEX
1517

1618

1719
class OptimizerGenomeConfig(object):
@@ -191,6 +193,8 @@ def configure_crossover(self, genome1, genome2, config):
191193
# Homologous gene: combine genes from both parents.
192194
self.nodes[key] = ng1.crossover(ng2)
193195

196+
197+
194198
def mutate(self, config):
195199
"""Mutates this genome."""
196200

@@ -409,6 +413,42 @@ def get_pruned_copy(self, genome_config):
409413
new_genome.connections = used_connection_genes
410414
return new_genome
411415

416+
def compile_optimizer(self, genome_config):
417+
"""Compile this genome into a TorchScript optimizer."""
418+
from graph_builder import rebuild_and_script
419+
420+
if self.graph_dict is None:
421+
node_ids = sorted(self.nodes.keys())
422+
node_types = []
423+
node_attributes = []
424+
for nid in node_ids:
425+
node = self.nodes[nid]
426+
idx = NODE_TYPE_TO_INDEX.get(node.node_type)
427+
if idx is None:
428+
raise KeyError(f"Unknown node_type {node.node_type!r}")
429+
node_types.append(idx)
430+
node_attributes.append(node.dynamic_attributes)
431+
node_types = torch.tensor(node_types, dtype=torch.long)
432+
433+
edges = []
434+
for (src, dst), conn in self.connections.items():
435+
if conn.enabled and src in node_ids and dst in node_ids:
436+
local_src = node_ids.index(src)
437+
local_dst = node_ids.index(dst)
438+
edges.append([local_src, local_dst])
439+
if edges:
440+
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
441+
else:
442+
edge_index = torch.empty((2, 0), dtype=torch.long)
443+
self.graph_dict = {
444+
"node_types": node_types,
445+
"edge_index": edge_index,
446+
"node_attributes": node_attributes,
447+
}
448+
449+
self.optimizer = rebuild_and_script(self.graph_dict, genome_config, key=self.key)
450+
self.optimizer_path = None
451+
412452
def add_node(self, node_type: str, activation, aggregation) -> NodeGene:
413453
if activation is None and aggregation is None:
414454
print("WARNING: node added without any operation")

reproduction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def reproduce(self, config, species, pop_size, generation, task):
103103
child = config.genome_type(cid)
104104
child.configure_crossover(p1, p2, config.genome_config)
105105
child.mutate(config.genome_config)
106+
if hasattr(child, "compile_optimizer"):
107+
child.compile_optimizer(config.genome_config)
106108
new_population[cid] = child
107109
self.ancestors[cid] = (p1_id, p2_id)
108110

0 commit comments

Comments
 (0)