Skip to content

Commit ccacf87

Browse files
committed
Fix compile_optimizer to rebuild graph
1 parent 5ddb3c3 commit ccacf87

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

genome.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from attributes import BoolAttribute, FloatAttribute, IntAttribute, StringAttribute
1414
from computation_graphs.functions.activation import *
1515
from computation_graphs.functions.aggregation import *
16-
from genes import ConnectionGene, NodeGene
16+
from genes import ConnectionGene, NodeGene, NODE_TYPE_TO_INDEX
1717

1818

1919
class OptimizerGenomeConfig(object):
@@ -415,15 +415,38 @@ def get_pruned_copy(self, genome_config):
415415

416416
def compile_optimizer(self, genome_config):
417417
"""Compile this genome into a TorchScript optimizer."""
418-
from graph_builder import DynamicOptimizerModule
419-
420-
if not self.connections:
421-
self.optimizer = None
422-
else:
423-
module = DynamicOptimizerModule(
424-
self, genome_config.input_keys, genome_config.output_keys, self.graph_dict
425-
)
426-
self.optimizer = torch.jit.script(module)
418+
from graph_builder import rebuild_and_script
419+
420+
if self.graph_dict is None:
421+
node_ids = sorted(self.nodes.keys())
422+
node_types = []
423+
node_attributes = []
424+
for nid in node_ids:
425+
node = self.nodes[nid]
426+
idx = NODE_TYPE_TO_INDEX.get(node.node_type)
427+
if idx is None:
428+
raise KeyError(f"Unknown node_type {node.node_type!r}")
429+
node_types.append(idx)
430+
node_attributes.append(node.dynamic_attributes)
431+
node_types = torch.tensor(node_types, dtype=torch.long)
432+
433+
edges = []
434+
for (src, dst), conn in self.connections.items():
435+
if conn.enabled and src in node_ids and dst in node_ids:
436+
local_src = node_ids.index(src)
437+
local_dst = node_ids.index(dst)
438+
edges.append([local_src, local_dst])
439+
if edges:
440+
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
441+
else:
442+
edge_index = torch.empty((2, 0), dtype=torch.long)
443+
self.graph_dict = {
444+
"node_types": node_types,
445+
"edge_index": edge_index,
446+
"node_attributes": node_attributes,
447+
}
448+
449+
self.optimizer = rebuild_and_script(self.graph_dict, genome_config, key=self.key)
427450
self.optimizer_path = None
428451

429452
def add_node(self, node_type: str, activation, aggregation) -> NodeGene:

0 commit comments

Comments
 (0)