@@ -134,56 +134,47 @@ isl::id statementId(const Scop& scop, const Halide::Internal::Stmt& stmt) {
134134
135135} // namespace
136136
137- std::pair<isl::union_set, std::vector< isl::id> > reductionInitsUpdates (
137+ std::pair<isl::union_set, isl::union_set > reductionInitsUpdates (
138138 isl::union_set domain,
139139 const Scop& scop) {
140140 auto initUnion = isl::union_set::empty (domain.get_space ());
141- std::vector<isl::id> update;
141+ auto update = initUnion ;
142142 std::unordered_set<isl::id, isl::IslIdIslHash> init;
143143 std::vector<isl::set> nonUpdate;
144- // First collect all the update statement identifiers ,
144+ // First collect all the update statements ,
145145 // the corresponding init statement and all non-update statements.
146146 domain.foreach_set ([&init, &update, &nonUpdate, &scop](isl::set set) {
147147 auto setId = set.get_tuple_id ();
148148 Halide::Internal::Stmt initStmt;
149149 std::vector<size_t > reductionDims;
150150 if (isReductionUpdateId (setId, scop, initStmt, reductionDims)) {
151- update. emplace_back (setId );
151+ update = update. unite (set );
152152 init.emplace (statementId (scop, initStmt));
153153 } else {
154154 nonUpdate.emplace_back (set);
155155 }
156156 });
157157 // Then check if all the non-update statements are init statements
158158 // that correspond to the update statements found.
159- // If not, return an empty list of update statement identifiers .
159+ // If not, return an empty list of update statements .
160160 for (auto set : nonUpdate) {
161161 if (init.count (set.get_tuple_id ()) != 1 ) {
162- return std::pair<isl::union_set, std::vector< isl::id> >(
163- initUnion, std::vector<isl::id>( ));
162+ return std::pair<isl::union_set, isl::union_set >(
163+ initUnion, isl::union_set::empty (domain. get_space () ));
164164 }
165165 initUnion = initUnion.unite (set);
166166 }
167- return std::pair<isl::union_set, std::vector< isl::id> >(initUnion, update);
167+ return std::pair<isl::union_set, isl::union_set >(initUnion, update);
168168}
169169
170- int findFirstReductionDim (isl::multi_union_pw_aff islMupa, const Scop& scop) {
171- auto mupa = isl::MUPA (islMupa);
172- int reductionDim = -1 ;
173- int currentDim = 0 ;
174- for (auto const & upa : mupa) {
175- for (auto const & pa : upa) {
176- if (isAlmostIdentityReduction (pa.pa , scop)) {
177- reductionDim = currentDim;
178- break ;
179- }
180- }
181- if (reductionDim != -1 ) {
182- break ;
183- }
184- ++currentDim;
185- }
186- return reductionDim;
170+ bool isReductionMember (
171+ isl::union_pw_aff member,
172+ isl::union_set domain,
173+ const Scop& scop) {
174+ return domain.every_set ([member, &scop](isl::set set) {
175+ auto pa = member.extract_on_domain (set.get_space ());
176+ return isAlmostIdentityReduction (pa, scop);
177+ });
187178}
188179
189180} // namespace polyhedral
0 commit comments