@@ -55,16 +55,14 @@ class Py_StatefulActionNode final : public StatefulActionNode
5555 }
5656};
5757
58- template <typename T>
59- py::object Py_getInput (const T& node, const std::string& name)
58+ py::object Py_getInput (const TreeNode& node, const std::string& name)
6059{
6160 py::object obj;
6261 node.getInput (name, obj);
6362 return obj;
6463}
6564
66- template <typename T>
67- void Py_setOutput (T& node, const std::string& name, const py::object& value)
65+ void Py_setOutput (TreeNode& node, const std::string& name, const py::object& value)
6866{
6967 node.setOutput (name, value);
7068}
@@ -88,21 +86,49 @@ inline py::object convertFromString(StringView str)
8886 }
8987}
9088
91- PYBIND11_MODULE (btpy_cpp, m )
89+ PortsList extractPortsList ( const py::type& type )
9290{
93- py::class_<PortInfo>(m, " PortInfo" );
94- m.def (" input_port" ,
95- [](const std::string& name) { return InputPort<py::object>(name); });
96- m.def (" output_port" ,
97- [](const std::string& name) { return OutputPort<py::object>(name); });
98-
99- m.def (
100- " ports2" ,
101- [](const py::list& inputs, const py::list& outputs) -> auto {
102- return [](py::type type) -> auto { return type; };
103- },
104- py::kw_only (), py::arg (" inputs" ) = py::none (), py::arg (" outputs" ) = py::none ());
91+ PortsList ports;
92+
93+ const auto input_ports = type.attr (" input_ports" ).cast <py::list>();
94+ for (const auto & name : input_ports)
95+ {
96+ ports.insert (InputPort<py::object>(name.cast <std::string>()));
97+ }
10598
99+ const auto output_ports = type.attr (" output_ports" ).cast <py::list>();
100+ for (const auto & name : output_ports)
101+ {
102+ ports.insert (OutputPort<py::object>(name.cast <std::string>()));
103+ }
104+
105+ return ports;
106+ }
107+
108+ NodeBuilder makeTreeNodeBuilderFn (const py::type& type)
109+ {
110+ return [type](const auto & name, const auto & config) -> auto {
111+ py::object obj = type (name, config);
112+
113+ // TODO: Increment the object's reference count or else it
114+ // will be GC'd at the end of this scope. The downside is
115+ // that, unless we can decrement the ref when the unique_ptr
116+ // is destroyed, then the object will live forever.
117+ obj.inc_ref ();
118+
119+ if (py::isinstance<ActionNodeBase>(obj))
120+ {
121+ return std::unique_ptr<TreeNode>(obj.cast <ActionNodeBase*>());
122+ }
123+ else
124+ {
125+ throw std::runtime_error (" invalid node type of " + name);
126+ }
127+ };
128+ }
129+
130+ PYBIND11_MODULE (btpy_cpp, m)
131+ {
106132 py::class_<BehaviorTreeFactory>(m, " BehaviorTreeFactory" )
107133 .def (py::init ())
108134 .def (" register" ,
@@ -112,45 +138,10 @@ PYBIND11_MODULE(btpy_cpp, m)
112138 TreeNodeManifest manifest;
113139 manifest.type = NodeType::ACTION;
114140 manifest.registration_ID = name;
115- manifest.ports = {} ;
141+ manifest.ports = extractPortsList (type) ;
116142 manifest.description = " " ;
117143
118- const auto input_ports = type.attr (" input_ports" ).cast <py::list>();
119- for (const auto & name : input_ports)
120- {
121- manifest.ports .insert (InputPort<py::object>(name.cast <std::string>()));
122- }
123-
124- const auto output_ports = type.attr (" output_ports" ).cast <py::list>();
125- for (const auto & name : output_ports)
126- {
127- manifest.ports .insert (OutputPort<py::object>(name.cast <std::string>()));
128- }
129-
130- factory.registerBuilder (
131- manifest,
132- [type](const std::string& name,
133- const NodeConfig& config) -> std::unique_ptr<TreeNode> {
134- py::object obj = type (name, config);
135- // TODO: Increment the object's reference count or else it
136- // will be GC'd at the end of this scope. The downside is
137- // that, unless we can decrement the ref when the unique_ptr
138- // is destroyed, then the object will live forever.
139- obj.inc_ref ();
140-
141- if (py::isinstance<Py_SyncActionNode>(obj))
142- {
143- return std::unique_ptr<TreeNode>(obj.cast <Py_SyncActionNode*>());
144- }
145- else if (py::isinstance<Py_StatefulActionNode>(obj))
146- {
147- return std::unique_ptr<TreeNode>(obj.cast <Py_StatefulActionNode*>());
148- }
149- else
150- {
151- throw std::runtime_error (" invalid node type of " + name);
152- }
153- });
144+ factory.registerBuilder (manifest, makeTreeNodeBuilderFn (type));
154145 })
155146 .def (" create_tree_from_text" ,
156147 [](BehaviorTreeFactory& factory, const std::string& text) -> Tree {
@@ -173,16 +164,23 @@ PYBIND11_MODULE(btpy_cpp, m)
173164
174165 py::class_<NodeConfig>(m, " NodeConfig" );
175166
176- py::class_<Py_SyncActionNode>(m, " SyncActionNode" )
167+ // Register the C++ type hierarchy so that we can refer to Python subclasses
168+ // by their superclass ptr types in generic C++ code.
169+ py::class_<TreeNode>(m, " _TreeNode" );
170+ py::class_<ActionNodeBase, TreeNode>(m, " _ActionNodeBase" );
171+ py::class_<SyncActionNode, ActionNodeBase>(m, " _SyncActionNode" );
172+ py::class_<StatefulActionNode, ActionNodeBase>(m, " _StatefulActionNode" );
173+
174+ py::class_<Py_SyncActionNode, SyncActionNode>(m, " SyncActionNode" )
177175 .def (py::init<const std::string&, const NodeConfig&>())
178- .def (" get_input" , &Py_getInput<Py_SyncActionNode> )
179- .def (" set_output" , &Py_setOutput<Py_SyncActionNode> )
176+ .def (" get_input" , &Py_getInput)
177+ .def (" set_output" , &Py_setOutput)
180178 .def (" tick" , &Py_SyncActionNode::tick);
181179
182- py::class_<Py_StatefulActionNode>(m, " StatefulActionNode" )
180+ py::class_<Py_StatefulActionNode, StatefulActionNode >(m, " StatefulActionNode" )
183181 .def (py::init<const std::string&, const NodeConfig&>())
184- .def (" get_input" , &Py_getInput<Py_StatefulActionNode> )
185- .def (" set_output" , &Py_setOutput<Py_StatefulActionNode> )
182+ .def (" get_input" , &Py_getInput)
183+ .def (" set_output" , &Py_setOutput)
186184 .def (" on_start" , &Py_StatefulActionNode::onStart)
187185 .def (" on_running" , &Py_StatefulActionNode::onRunning)
188186 .def (" on_halted" , &Py_StatefulActionNode::onHalted);
0 commit comments