Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit c555112

Browse files
committed
ScheduleTree: introduce virtual "clone" method
The original implementation of schedule trees with "element" classes was relying on a manual macro-based dynamic dispatch to copy schedule tree nodes. Arguably, this is not desirable in the C++ code where the same behavior can be achieved through language mechanisms (virtual functions or CRTP). Since copy constructors cannot be overloaded, introduce an abstract "clone" method to ScheduleTree and override it in all subclasses to produce copies of the given node wrapped in a unique pointer to the base class. Note that children of the given node are not copied to remain consistent with the existing uses. Replace macro-based dispatch with a call to the virtual function. Using language featuers for dispatch is less error-prone and converts run-time errors (e.g., unhandled type in macro-based if/else sequence) into compile-time errors (e.g., attempting to instantiate an abstract class because the abstract function was not implemented). It also localizes the changes to be made to the ScheduleTree API when introducing or removing a new node type.
1 parent c4e612c commit c555112

File tree

3 files changed

+32
-26
lines changed

3 files changed

+32
-26
lines changed

tc/core/polyhedral/schedule_tree.cc

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -122,31 +122,6 @@ vector<ScheduleTree*> ancestorsInSubTree(
122122
}
123123
return res;
124124
}
125-
126-
static std::unique_ptr<ScheduleTree> makeElem(const ScheduleTree& st) {
127-
#define ELEM_MAKE_CASE_CSTR(CLASS) \
128-
else if (st.type_ == CLASS::NodeType) { \
129-
return CLASS::make(static_cast<const CLASS*>(&st)); \
130-
}
131-
132-
if (st.type_ == detail::ScheduleTreeType::None) {
133-
LOG(FATAL) << "Hit Error node!";
134-
}
135-
ELEM_MAKE_CASE_CSTR(ScheduleTreeBand)
136-
ELEM_MAKE_CASE_CSTR(ScheduleTreeContext)
137-
ELEM_MAKE_CASE_CSTR(ScheduleTreeDomain)
138-
ELEM_MAKE_CASE_CSTR(ScheduleTreeExtension)
139-
ELEM_MAKE_CASE_CSTR(ScheduleTreeFilter)
140-
ELEM_MAKE_CASE_CSTR(ScheduleTreeMapping)
141-
ELEM_MAKE_CASE_CSTR(ScheduleTreeSequence)
142-
ELEM_MAKE_CASE_CSTR(ScheduleTreeSet)
143-
ELEM_MAKE_CASE_CSTR(ScheduleTreeThreadSpecificMarker)
144-
145-
#undef ELEM_MAKE_CASE_CSTR
146-
147-
LOG(FATAL) << "NYI: ScheduleTree from type: " << static_cast<int>(st.type_);
148-
return nullptr;
149-
}
150125
} // namespace
151126

152127
////////////////////////////////////////////////////////////////////////////////
@@ -163,7 +138,7 @@ ScheduleTree::ScheduleTree(const ScheduleTree& st)
163138
}
164139

165140
ScheduleTreeUPtr ScheduleTree::makeScheduleTree(const ScheduleTree& tree) {
166-
return makeElem(tree);
141+
return tree.clone();
167142
}
168143

169144
ScheduleTree* ScheduleTree::child(const vector<size_t>& positions) {

tc/core/polyhedral/schedule_tree.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,10 @@ struct ScheduleTree {
465465
// Note that this function does _not_ output the child trees.
466466
virtual std::ostream& write(std::ostream& os) const = 0;
467467

468+
// Clone the current node.
469+
// Note that this function does _not_ clone the child trees.
470+
virtual ScheduleTreeUPtr clone() const = 0;
471+
468472
//
469473
// Data members
470474
//

tc/core/polyhedral/schedule_tree_elem.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ struct ScheduleTreeContext : public ScheduleTree {
5858
}
5959

6060
virtual std::ostream& write(std::ostream& os) const override;
61+
virtual ScheduleTreeUPtr clone() const override {
62+
return make(this);
63+
}
6164

6265
public:
6366
isl::set context_;
@@ -91,6 +94,9 @@ struct ScheduleTreeDomain : public ScheduleTree {
9194
}
9295

9396
virtual std::ostream& write(std::ostream& os) const override;
97+
virtual ScheduleTreeUPtr clone() const override {
98+
return make(this);
99+
}
94100

95101
public:
96102
isl::union_set domain_;
@@ -124,6 +130,9 @@ struct ScheduleTreeExtension : public ScheduleTree {
124130
}
125131

126132
virtual std::ostream& write(std::ostream& os) const override;
133+
virtual ScheduleTreeUPtr clone() const override {
134+
return make(this);
135+
}
127136

128137
public:
129138
isl::union_map extension_;
@@ -157,6 +166,9 @@ struct ScheduleTreeFilter : public ScheduleTree {
157166
std::vector<ScheduleTreeUPtr>&& children = {});
158167

159168
virtual std::ostream& write(std::ostream& os) const override;
169+
virtual ScheduleTreeUPtr clone() const override {
170+
return make(this);
171+
}
160172

161173
public:
162174
isl::union_set filter_;
@@ -195,6 +207,9 @@ struct ScheduleTreeMapping : public ScheduleTree {
195207
std::vector<ScheduleTreeUPtr>&& children = {});
196208

197209
virtual std::ostream& write(std::ostream& os) const override;
210+
virtual ScheduleTreeUPtr clone() const override {
211+
return make(this);
212+
}
198213

199214
public:
200215
// Mapping from identifiers to affine functions on domain elements.
@@ -229,6 +244,9 @@ struct ScheduleTreeSequence : public ScheduleTree {
229244
std::vector<ScheduleTreeUPtr>&& children = {});
230245

231246
virtual std::ostream& write(std::ostream& os) const override;
247+
virtual ScheduleTreeUPtr clone() const override {
248+
return make(this);
249+
}
232250
};
233251

234252
struct ScheduleTreeSet : public ScheduleTree {
@@ -256,6 +274,9 @@ struct ScheduleTreeSet : public ScheduleTree {
256274
std::vector<ScheduleTreeUPtr>&& children = {});
257275

258276
virtual std::ostream& write(std::ostream& os) const override;
277+
virtual ScheduleTreeUPtr clone() const override {
278+
return make(this);
279+
}
259280
};
260281

261282
struct ScheduleTreeBand : public ScheduleTree {
@@ -280,6 +301,9 @@ struct ScheduleTreeBand : public ScheduleTree {
280301
}
281302

282303
virtual std::ostream& write(std::ostream& os) const override;
304+
virtual ScheduleTreeUPtr clone() const override {
305+
return make(this);
306+
}
283307

284308
// Make a schedule node band from partial schedule.
285309
// Replace "mupa" by its greatest integer part to ensure that the
@@ -353,6 +377,9 @@ struct ScheduleTreeThreadSpecificMarker : public ScheduleTree {
353377
std::vector<ScheduleTreeUPtr>&& children = {});
354378

355379
virtual std::ostream& write(std::ostream& os) const override;
380+
virtual ScheduleTreeUPtr clone() const override {
381+
return make(this);
382+
}
356383
};
357384

358385
bool elemEquals(

0 commit comments

Comments
 (0)