2020#include " tc/core/polyhedral/cuda/mapped_scop.h"
2121#include " tc/core/polyhedral/cuda/mapping_types.h"
2222#include " tc/core/polyhedral/exceptions.h"
23- #include " tc/core/polyhedral/functional.h"
24- #include " tc/core/polyhedral/schedule_tree.h"
25- #include " tc/core/polyhedral/schedule_utils.h"
26- #include " tc/core/polyhedral/scop.h"
2723
2824namespace tc {
2925namespace polyhedral {
3026namespace {
31- // This returns the (inclusive) range of the mapping parameter "mappingId"
32- // within the context "mappingContext".
33- // This range corresponds to the blocks/threads active at the particular
34- // location in the tree where this mapping is active.
35- //
36- // This is used to tighten the kernel to only launch on the necessary amount
37- // of resources.
38- //
39- // When the range is unbounded on the right, we return the maximal positive
40- // range (0, max_size_t). This needs to be intersected with launch bounds to
41- // obtain the proper finite range.
42- // Otherwise, the range is asserted bounded on the left and to lie in the
43- // positive half of the integer axis.
44- std::pair<size_t , size_t > rangeOfMappingParameter (
45- isl::set mappingContext,
46- mapping::MappingId mappingId) {
47- if (!mappingContext.involves_param (mappingId)) {
48- return std::make_pair (0 , std::numeric_limits<size_t >::max ());
49- }
50- auto space = mappingContext.get_space ();
51- isl::aff a (isl::aff::param_on_domain_space (space, mappingId));
52- auto max = mappingContext.max_val (a);
53- if (max.is_nan () || max.is_infty ()) {
54- return std::make_pair (0 , std::numeric_limits<size_t >::max ());
55- }
56- TC_CHECK (max.is_int ()) << max.to_str ();
57- TC_CHECK (max.is_nonneg ()) << max.to_str ();
58- auto min = mappingContext.min_val (a);
59- TC_CHECK (min.is_int ()) << max.to_str ();
60- TC_CHECK (min.is_nonneg ()) << max.to_str ();
61-
62- return std::make_pair (
63- static_cast <size_t >(min.get_num_si ()),
64- static_cast <size_t >(max.get_num_si ()));
65- }
66-
6727/*
68- * Compute the maximal value attained by the mapping parameter "id".
28+ * Return the mapping to MappingTypeId, i.e., either the mapping to blocks or
29+ * the mapping to threads.
6930 */
70- template <typename MappingIdType>
71- size_t maxValue (const Scop& scop, const MappingIdType& id) {
72- using namespace polyhedral ::detail;
73-
74- auto root = scop.scheduleRoot ();
75- auto params = scop.context ();
76- size_t sizetMax = std::numeric_limits<size_t >::max ();
77- size_t max = 0 ;
78- size_t min = sizetMax;
79- auto filters = root->collect (root, ScheduleTreeType::Mapping);
80- filters = functional::Filter (isMappingTo<MappingIdType>, filters);
81- for (auto p : filters) {
82- auto mappingNode = p->as <ScheduleTreeMapping>();
83- auto active = activeDomainPoints (root, p).intersect_params (params);
84- active = active.intersect (mappingNode->filter_ );
85- auto range = rangeOfMappingParameter (active.params (), id);
86- min = std::min (min, range.first );
87- max = std::max (max, range.second );
88- }
89- TC_CHECK (max < sizetMax) << " missing mapping to " << id << " \n " << *root;
90- TC_CHECK (min < sizetMax) << " missing mapping to " << id << " type\n " << *root;
91- // Inclusive range needs + 1 to translate to sizes
92- return max + 1 ;
31+ template <typename MappingTypeId>
32+ static isl::multi_union_pw_aff mappingSchedule (const MappedScop& mscop);
33+ template <>
34+ isl::multi_union_pw_aff mappingSchedule<mapping::BlockId>(
35+ const MappedScop& mscop) {
36+ return mscop.blockMappingSchedule (mscop.schedule ());
37+ }
38+ template <>
39+ isl::multi_union_pw_aff mappingSchedule<mapping::ThreadId>(
40+ const MappedScop& mscop) {
41+ return mscop.threadMappingSchedule (mscop.schedule ());
9342}
9443
9544/*
@@ -100,8 +49,17 @@ template <typename MappingIdType, typename Size>
10049Size launchBounds (const MappedScop& mscop, Size size) {
10150 Size tightened;
10251
52+ auto params = mscop.scop ().context ();
53+ auto mapping = mappingSchedule<MappingIdType>(mscop);
54+ mapping = mapping.intersect_params (params);
55+ auto max = mapping.max_multi_val ();
56+
10357 for (size_t i = 0 ; i < size.view .size (); ++i) {
104- tightened.view [i] = maxValue (mscop.scop (), MappingIdType::makeId (i));
58+ auto maxVal = max.get_val (i);
59+ TC_CHECK (maxVal.is_int ()) << maxVal.to_str ();
60+ TC_CHECK (maxVal.is_nonneg ()) << maxVal.to_str ();
61+ // Inclusive range needs + 1 to translate to sizes
62+ tightened.view [i] = maxVal.get_num_si () + 1 ;
10563 }
10664
10765 return tightened;
0 commit comments