1717
1818#include < algorithm>
1919#include < numeric>
20+ #include < tuple>
2021#include < unordered_set>
2122
2223#include " tc/core/constants.h"
@@ -238,7 +239,20 @@ isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable) {
238239 return context;
239240}
240241
241- isl::map extractAccess (
242+ // Extract a tagged affine access relation from Halide IR.
243+ // The relation is tagged with a unique identifier, i.e. it lives in the space
244+ // [D[...] -> __tc_ref_#[]] -> A[]
245+ // where # is a unique sequential number, D is the statement identifier
246+ // extracted from "domain" and A is the tensor identifier constructed from
247+ // "tensor". "accesses" map is updated to keep track of the Halide IR nodes in
248+ // which a particular reference # appeared.
249+ // Returns the access relation and a flag indicating whether this relation is
250+ // exact or not. The relation is overapproximated (that is, not exact) if it
251+ // represents a non-affine access, for example, an access with indirection such
252+ // as O(Index(i)) = 42. In such overapproximated access relation, dimensions
253+ // that correspond to affine subscripts are still exact while those that
254+ // correspond to non-affine subscripts are not constrained.
255+ std::pair<isl::map, bool > extractAccess (
242256 isl::set domain,
243257 const IRNode* op,
244258 const std::string& tensor,
@@ -267,6 +281,7 @@ isl::map extractAccess(
267281 isl::map map =
268282 isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
269283
284+ bool exact = true ;
270285 for (size_t i = 0 ; i < args.size (); i++) {
271286 // Then add one equality constraint per dimension to encode the
272287 // point in the allocation actually read/written for each point in
@@ -278,47 +293,64 @@ isl::map extractAccess(
278293 isl::pw_aff (isl::local_space (rangeSpace), isl::dim_type::set, i);
279294 // ... equals the coordinate accessed as a function of the domain.
280295 auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace, args[i]);
281- if (!domainPoint.is_null ()) {
296+ if (!domainPoint) {
297+ exact = false ;
298+ } else {
282299 map = map.intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
283300 }
284301 }
285302
286- return map;
303+ return std::make_pair ( map, exact) ;
287304}
288305
289- std::pair< isl::union_map, isl::union_map>
306+ std::tuple<isl::union_map, isl::union_map, isl::union_map>
290307extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
291308 class FindAccesses : public IRGraphVisitor {
292309 using IRGraphVisitor::visit;
293310
294311 void visit (const Call* op) override {
295312 IRGraphVisitor::visit (op);
296313 if (op->call_type == Call::Halide || op->call_type == Call::Image) {
297- reads = reads.unite (
298- extractAccess (domain, op, op->name , op->args , accesses));
314+ // Read relations can be safely overapproximated.
315+ isl::map read;
316+ std::tie (read, std::ignore) =
317+ extractAccess (domain, op, op->name , op->args , accesses);
318+ reads = reads.unite (read);
299319 }
300320 }
301321
302322 void visit (const Provide* op) override {
303323 IRGraphVisitor::visit (op);
304- writes =
305- writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
324+
325+ // If the write access relation is not exact, we consider that any
326+ // element _may_ be written by the statement. If it is exact, then we
327+ // can guarantee that all the elements specified by the relation _must_
328+ // be written and any previously stored value will be killed.
329+ isl::map write;
330+ bool exact;
331+ std::tie (write, exact) =
332+ extractAccess (domain, op, op->name , op->args , accesses);
333+ if (exact) {
334+ mustWrites = mustWrites.unite (write);
335+ }
336+ mayWrites = mayWrites.unite (write);
306337 }
307338
308339 const isl::set& domain;
309340 AccessMap* accesses;
310341
311342 public:
312- isl::union_map reads, writes ;
343+ isl::union_map reads, mayWrites, mustWrites ;
313344
314345 FindAccesses (const isl::set& domain, AccessMap* accesses)
315346 : domain(domain),
316347 accesses (accesses),
317348 reads(isl::union_map::empty(domain.get_space())),
318- writes(isl::union_map::empty(domain.get_space())) {}
349+ mayWrites(isl::union_map::empty(domain.get_space())),
350+ mustWrites(isl::union_map::empty(domain.get_space())) {}
319351 } finder(domain, accesses);
320352 s.accept(&finder);
321- return { finder.reads , finder.writes } ;
353+ return std::make_tuple( finder.reads, finder.mayWrites, finder.mustWrites) ;
322354}
323355
324356/*
@@ -343,7 +375,8 @@ isl::schedule makeScheduleTreeHelper(
343375 isl::set set,
344376 std::vector<std::string>& outer,
345377 isl::union_map* reads,
346- isl::union_map* writes,
378+ isl::union_map* mayWrites,
379+ isl::union_map* mustWrites,
347380 AccessMap* accesses,
348381 StatementMap* statements,
349382 IteratorMap* iterators) {
@@ -389,7 +422,8 @@ isl::schedule makeScheduleTreeHelper(
389422 set,
390423 outerNext,
391424 reads,
392- writes,
425+ mayWrites,
426+ mustWrites,
393427 accesses,
394428 statements,
395429 iterators);
@@ -422,7 +456,15 @@ isl::schedule makeScheduleTreeHelper(
422456 std::vector<isl::schedule> schedules;
423457 for (Stmt s : stmts) {
424458 schedules.push_back (makeScheduleTreeHelper (
425- s, set, outer, reads, writes, accesses, statements, iterators));
459+ s,
460+ set,
461+ outer,
462+ reads,
463+ mayWrites,
464+ mustWrites,
465+ accesses,
466+ statements,
467+ iterators));
426468 }
427469 schedule = schedules[0 ].sequence (schedules[1 ]);
428470
@@ -437,23 +479,25 @@ isl::schedule makeScheduleTreeHelper(
437479 isl::set domain = set.set_tuple_id (id);
438480 schedule = isl::schedule::from_domain (domain);
439481
440- isl::union_map newReads, newWrites ;
441- std::tie (newReads, newWrites ) =
482+ isl::union_map newReads, newMayWrites, newMustWrites ;
483+ std::tie (newReads, newMayWrites, newMustWrites ) =
442484 halide2isl::extractAccesses (domain, op, accesses);
443485
444486 *reads = reads->unite (newReads);
445- *writes = writes->unite (newWrites);
487+ *mayWrites = mayWrites->unite (newMayWrites);
488+ *mustWrites = mustWrites->unite (newMustWrites);
446489
447490 } else {
448491 LOG (FATAL) << " Unhandled Halide stmt: " << s;
449492 }
450493 return schedule;
451- };
494+ }
452495
453496ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
454497 ScheduleTreeAndAccesses result;
455498
456- result.writes = result.reads = isl::union_map::empty (paramSpace);
499+ result.mayWrites = result.mustWrites = result.reads =
500+ isl::union_map::empty (paramSpace);
457501
458502 // Walk the IR building a schedule tree
459503 std::vector<std::string> outer;
@@ -462,7 +506,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
462506 isl::set::universe (paramSpace),
463507 outer,
464508 &result.reads ,
465- &result.writes ,
509+ &result.mayWrites ,
510+ &result.mustWrites ,
466511 &result.accesses ,
467512 &result.statements ,
468513 &result.iterators );
0 commit comments