@@ -356,26 +356,61 @@ size_t largestDim(const std::vector<const DLConstTensor*>& inputs) {
356356 return (*maxElement)->ndim ;
357357}
358358
359- // Creates well-chosen parameter sizes to match the input shapes.
360- void setupTuningParameters (
359+ // Creates well-chosen generic parameter sizes to match the input shapes.
360+ template <typename MappingOptionsType>
361+ inline std::pair<TuningConfiguration, std::vector<size_t >>
362+ setupGenericTuningParametersAndGetRange (
361363 const std::vector<const DLConstTensor*>& inputs,
362- TuningConfiguration& configuration ) {
364+ const std::vector<MappingOptionsType>& baseMappings ) {
363365 TC_CHECK_GE (inputs.size (), 1u );
364366 auto range = inputDivisorsAndPowers2 (inputs);
365367 // 0 is a valid tiling annotation and signals no tiling of that dimension
366368 // 0 is not a valid block / grid annotation
367369 auto nTilesDim = largestDim (inputs) + 1 ;
368370 auto tileRange = range;
369371 tileRange.push_back (0 );
372+
373+ TuningConfiguration configuration;
370374 configuration.tilingParams .setRange (nTilesDim, tileRange);
371- configuration.blockParams .setRange (range, " b" );
372- configuration.gridParams .setRange (range, " g" );
373375 configuration.unrollFactor =
374376 RangeParameter (powers2 (FLAGS_tuner_max_unroll_size), " unroll" );
377+
378+ return {configuration, range};
379+ }
380+
381+ // Creates well-chosen parameter sizes to match the input shapes.
382+ inline TuningConfiguration setupTuningParameters (
383+ const std::vector<const DLConstTensor*>& inputs,
384+ const std::vector<CudaMappingOptions>& baseMappings) {
385+ std::vector<size_t > range;
386+ TuningConfiguration configuration;
387+ std::tie (configuration, range) =
388+ setupGenericTuningParametersAndGetRange (inputs, baseMappings);
389+ auto blockRange = range;
390+ auto gridRange = range;
391+
392+ for (const auto & baseMapping : baseMappings) {
393+ blockRange =
394+ mergeVectors (std::move (blockRange), baseMapping.block .extractVector ());
395+ gridRange =
396+ mergeVectors (std::move (gridRange), baseMapping.grid .extractVector ());
397+ }
398+
399+ configuration.blockParams .setRange (blockRange, " b" );
400+ configuration.gridParams .setRange (gridRange, " g" );
375401 configuration.privateDepth =
376402 RangeParameter ({0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 }, " pdepth" );
377403 configuration.sharedDepth =
378404 RangeParameter ({0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 }, " sdepth" );
405+
406+ return configuration;
407+ }
408+
409+ // Creates well-chosen parameter sizes to match the input shapes.
410+ inline TuningConfiguration setupTuningParameters (
411+ const std::vector<const DLConstTensor*>& inputs,
412+ const std::vector<CpuMappingOptions>& baseMappings) {
413+ return setupGenericTuningParametersAndGetRange (inputs, baseMappings).first ;
379414}
380415} // namespace
381416
@@ -397,9 +432,9 @@ Autotuner<Backend, SearchStrategy>::tune(
397432 << " Error looking up " << tcEntryPoint;
398433
399434 // Initialize a model configuration
400- TuningConfiguration modelConfiguration;
401435 TC_CHECK_GE (inputs.size (), 1u );
402- setupTuningParameters (inputs.begin ()->second , modelConfiguration);
436+ auto modelConfiguration =
437+ setupTuningParameters (inputs.begin ()->second , baseMappings);
403438 modelConfiguration.fixParameters (fixedParams);
404439
405440 // Create initial configs based on options + model configuration
0 commit comments