1515 */
1616#include < atomic>
1717#include < chrono>
18+ #include < functional>
1819#include < numeric>
1920#include < thread>
2021
@@ -48,8 +49,6 @@ TuningHarness<Backend>::TuningHarness(
4849 baseMapping_(baseMapping),
4950 inputs_(inputs),
5051 outputs_(outputs),
51- bestTime_(Duration::max()),
52- bestMappingOptions_(baseMapping),
5352 optionsCache_(optionsCache) {}
5453
5554template <typename Backend>
@@ -67,13 +66,6 @@ void TuningHarness<Backend>::stopAfterCurrentIteration() {
6766 stopRequested_ = true ;
6867}
6968
70- template <typename Backend>
71- const typename Backend::MappingOptionsType&
72- TuningHarness<Backend>::bestMappingOptions() const {
73- std::lock_guard<std::mutex> lock (bestTimeMutex_);
74- return bestMappingOptions_;
75- }
76-
7769#define LOG_LINE_BY_LINE (GSTREAM, ISTREAM ) \
7870 for (std::string line; std::getline(ISTREAM, line);) { \
7971 LOG (GSTREAM) << line; \
@@ -180,11 +172,14 @@ void TuningHarness<Backend>::doEvaluate(
180172
181173 std::vector<Duration> runtimes{Duration::max ()};
182174 try {
183- Duration bestTimeSoFar (Duration::max ());
184- {
185- std::lock_guard<std::mutex> lock (bestTimeMutex_);
186- bestTimeSoFar = bestTime_;
187- }
175+ auto vBest = optionsCache_->getTopKEntries (
176+ lang::canonicalTc (tcTree_),
177+ makeTensorInfoVector (inputs),
178+ makeTensorInfoVector (outputs),
179+ Backend::backendString (),
180+ 1 );
181+ Duration bestTimeSoFar =
182+ (vBest.size () > 0 ) ? vBest[0 ].second : Duration::max ();
188183 auto prune = detail::skipExecutionOrWarmup<Backend>(
189184 *pExecutor, outputs, inputs, bestTimeSoFar);
190185 if (prune) {
@@ -234,15 +229,6 @@ void TuningHarness<Backend>::doEvaluate(
234229 Backend::backendString (),
235230 options,
236231 prof);
237-
238- // Save best time under lock
239- {
240- std::lock_guard<std::mutex> lock (bestTimeMutex_);
241- if (prof < bestTime_) {
242- bestTime_ = prof;
243- bestMappingOptions_ = options;
244- }
245- }
246232 } // end while
247233}
248234
@@ -310,7 +296,14 @@ void TuningHarness<Backend>::runOneIteration(
310296 LOG (INFO) << " [TUNER][ITERATION LOG] best option so far:" ;
311297 std::stringstream ssInfo;
312298 typename Backend::MappingOptionsCppPrinter infoPrinter (ssInfo);
313- infoPrinter << bestMappingOptions ();
299+ auto vBest = optionsCache_->getTopKOptions (
300+ lang::canonicalTc (tcTree_),
301+ makeTensorInfoVector (inputs_.begin ()->second ),
302+ makeTensorInfoVector (outputs_.begin ()->second ),
303+ Backend::backendString (),
304+ 1 );
305+ CHECK_GT (vBest.size (), 0 );
306+ infoPrinter << vBest[0 ];
314307 LOG_LINE_BY_LINE (INFO, ssInfo);
315308 }
316309 searchStrategy.updateParameters ();
@@ -426,6 +419,7 @@ Autotuner<Backend, SearchStrategy>::tune(
426419 const std::unordered_map<size_t , std::vector<const DLConstTensor*>>& inputs,
427420 std::unordered_map<size_t , std::vector<const DLTensor*>>& outputs,
428421 const std::vector<typename Backend::MappingOptionsType>& baseMappings,
422+ size_t topK,
429423 const TuningParameterFixer& fixedParams) {
430424 std::map<std::string, lang::TreeRef> tcEntryPointMap (tc::detail::parse (tc));
431425 TC_CHECK_EQ (tcEntryPointMap.count (tcEntryPoint), 1u )
@@ -511,7 +505,12 @@ Autotuner<Backend, SearchStrategy>::tune(
511505 std::rethrow_exception (tuningHarnessThreadEx);
512506 }
513507
514- return {tuningHarness.bestMappingOptions ()};
508+ return optionsCache->getTopKOptions (
509+ lang::canonicalTc (tcEntryPointMap.at (tcEntryPoint)),
510+ makeTensorInfoVector (inputs.begin ()->second ),
511+ makeTensorInfoVector (outputs.begin ()->second ),
512+ Backend::backendString (),
513+ topK);
515514}
516515} // namespace autotune
517516} // namespace tc
0 commit comments