@@ -259,102 +259,6 @@ void emitReductionOpName(const Halide::Expr& e, const CodegenContext& context) {
259259 }
260260}
261261
262- namespace {
263- // Compute the range of parameter values in a given set. Both sides of the
264- // range are inclusive.
265- std::pair<isl::val, isl::val> computeParamRange (isl::set domain, int pos) {
266- // Coerce the set to the shape [N] -> {[i]: only N here }
267- domain = domain.params ().from_params ();
268- domain = domain.project_out (isl::dim_type::param, 0 , pos);
269- domain = domain.project_out (
270- isl::dim_type::param, 1 , domain.dim (isl::dim_type::param) - 1 );
271- domain = domain.insert_dims (isl::dim_type::set, 0 , 1 );
272-
273- // Connect parameter to a set dimension [N] -> {[i]: i = N and ...}
274- auto lspace = isl::local_space (domain.get_space ());
275- auto paramAff = isl::aff (lspace, isl::dim_type::param, 0 );
276- auto varAff = isl::aff (lspace, isl::dim_type::set, 0 );
277- domain = domain & (isl::aff_set (paramAff) == varAff);
278-
279- // Remove the remaining parameter to move its constraints to the set dimension
280- domain = domain.project_out (isl::dim_type::param, 0 , 1 );
281-
282- // Get min and max.
283- auto lower = domain.dim_min (0 );
284- auto upper = domain.dim_max (0 );
285-
286- // Compute the range
287- CHECK (lower.is_cst () && upper.is_cst ())
288- << " expected constant lower and upper bounds" ;
289-
290- // Without parameters at all, we must have a single piece in the bound PA.
291- auto lowerPA = isl::PA (lower);
292- auto upperPA = isl::PA (upper);
293- CHECK (lowerPA.size () == 1 && upperPA.size () == 1 );
294-
295- return std::make_pair (
296- lowerPA[0 ].second .get_constant_val (),
297- upperPA[0 ].second .get_constant_val ());
298- }
299-
300- // Given the iteratorMaps, whose domain was affected by the mapping filters, in
301- // the provided context, compute the range of thread mapping parameters. If
302- // the statement is not mapped to some threads, they will not appear in the
303- // result.
304- std::unordered_map<isl::id, long , isl::IslIdIslHash> activeThreadsInBlock (
305- const CodegenStatementContext& context) {
306- auto iterMap = context.iteratorMap ();
307- auto dom =
308- iterMap.domain ()
309- .intersect_params (context.mappedScop .scop ().globalParameterContext )
310- .params ()
311- .from_params ();
312-
313- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
314- std::vector<isl::id> threadIds{TX, TY, TZ};
315- std::unordered_map<isl::id, long , isl::IslIdIslHash> activeThreads;
316-
317- for (auto id : threadIds) {
318- int pos = dom.find_dim_by_id (isl::dim_type::param, id);
319- if (pos < 0 ) {
320- continue ;
321- }
322- auto range = computeParamRange (dom, pos);
323- CHECK_EQ (range.first .get_den_si (), 1 ) << " fractional parameters?" ;
324- CHECK_EQ (range.second .get_den_si (), 1 ) << " fractional parameters?" ;
325- CHECK_EQ (range.first .get_num_si (), 0 )
326- << " NYI: active threads starting not from 0" ;
327-
328- activeThreads[id] =
329- range.second .get_num_si () - range.first .get_num_si () + 1 ;
330- }
331- return activeThreads;
332- }
333-
334- // Given the iteratorMaps, whose domain was affected by the mapping filters, in
335- // the provided context, compute the range of thread mapping parameters. If
336- // the statement is not mapped to some threads, they will _still appear_ in the
337- // result with the range 1.
338- std::array<long , 3 > activeThreadsInBlockWithDefaults (
339- const CodegenStatementContext& context) {
340- auto active = activeThreadsInBlock (context);
341- std::array<long , 3 > result;
342-
343- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
344- std::vector<isl::id> threadIds{TX, TY, TZ};
345- int i = 0 ;
346- for (auto id : threadIds) {
347- if (active.count (id) != 1 ) {
348- result[i] = MappingId::unmapped;
349- } else {
350- result[i] = active[id];
351- }
352- ++i;
353- }
354- return result;
355- }
356- } // namespace
357-
358262// Emit a cross-thread tree reduce.
359263// For now this is only expected to work with threadIdx.x.
360264void emitTreeSyncCall (
@@ -373,25 +277,16 @@ void emitTreeSyncCall(
373277 std::array<size_t , 3 > dims = {TX.mappingSize (context.mappedScop .numThreads ),
374278 TY.mappingSize (context.mappedScop .numThreads ),
375279 TZ.mappingSize (context.mappedScop .numThreads )};
376- std::array<long , 3 > active = activeThreadsInBlockWithDefaults (context);
377-
378- for (int i = 0 ; i < 3 ; ++i) {
379- if (active[i] < dims[i]) {
380- LOG (INFO) << " Reduction statement " << updateId << " mapped to "
381- << dims[i] << " threads along dim: " << i << " but only "
382- << active[i] << " are non-empty" ;
383- }
384- }
385280
386281 context.ss << tc::code::cuda::kCUBReductionName ;
387282
388283 // Template mapping dimension
389284 context.ss << " <" ;
390- context.ss << active [0 ];
285+ context.ss << dims [0 ];
391286 context.ss << " ," ;
392- context.ss << active [1 ];
287+ context.ss << dims [1 ];
393288 context.ss << " ," ;
394- context.ss << active [2 ];
289+ context.ss << dims [2 ];
395290 context.ss << " ," ;
396291 emitReductionOpName (provide->values [0 ], context);
397292 context.ss << " >(" ;
0 commit comments