@@ -33,11 +33,8 @@ void randomizeParameter(Parameter& param, RNG& rng) {
3333 param.selectOption (paramIndex);
3434}
3535
36- template <typename RNG>
37- void randomizePopulation (
38- GeneticSearch::Population::iterator begin,
39- GeneticSearch::Population::iterator end,
40- RNG& rng) {
36+ template <typename RNG, typename Iterator>
37+ void randomizePopulation (Iterator begin, Iterator end, RNG& rng) {
4138 for (auto candidate = begin; candidate != end; ++candidate) {
4239 auto & conf = (*candidate)->configuration ;
4340 do {
@@ -160,7 +157,8 @@ void dropInvalidConfigurations(GeneticSearch::Population& population) {
160157} // namespace
161158
162159#define VALIDATE () \
163- CHECK_LT (numberElites, maxPopulationSize); \
160+ CHECK_LT (maxPopulationSize, matingPoolSize); \
161+ CHECK_LT (maxPopulationSize, selectionPoolSize); \
164162 CHECK (mutationRate >= 0 and mutationRate <= 100 ) \
165163 << " the mutation rate (" << mutationRate \
166164 << " ) should be in the [0,100] interval" ; \
@@ -188,15 +186,16 @@ GeneticSearch::GeneticSearch(
188186 size_t populationSize,
189187 uint8_t crossOverRate,
190188 uint8_t mutationRate,
191- size_t numberElites)
189+ size_t matingPoolSize,
190+ size_t selectionPoolSize)
192191 : population(),
193192 lastBestConf (confs[0 ]),
194193 numGenerations(numGenerations),
195194 maxPopulationSize(populationSize),
196- matingPoolSize(populationSize * 3 ),
195+ matingPoolSize(matingPoolSize),
196+ selectionPoolSize(selectionPoolSize),
197197 crossOverRate(crossOverRate),
198198 mutationRate(mutationRate),
199- numberElites(numberElites),
200199 rng{std::random_device{}()} {
201200 restoreRngState (rng);
202201 VALIDATE ();
@@ -276,13 +275,6 @@ void GeneticSearch::breed() {
276275 auto matingPool =
277276 stochasticUniversalSampling (computeAccumulatedFitness (population));
278277
279- Population new_population;
280- new_population.reserve (matingPoolSize);
281- for (size_t c = 0 ; c < numberElites; ++c) {
282- new_population.push_back (
283- make_unique<CandidateConfiguration>(population.at (c)->configuration ));
284- }
285-
286278 auto select = [&]() -> TuningConfiguration& {
287279 auto idx = std::uniform_int_distribution<size_t >{
288280 size_t (0 ), matingPool.size () - 1 }(rng);
@@ -298,39 +290,20 @@ void GeneticSearch::breed() {
298290 return dist (rng);
299291 };
300292
301- while (new_population .size () < maxPopulationSize ) {
293+ while (selectionPool .size () < selectionPoolSize ) {
302294 if (shouldCrossOver ()) {
303295 auto parent1 = select ();
304296 auto parent2 = select ();
305297 auto parent3 = select ();
306- new_population .emplace_back (make_unique<CandidateConfiguration>(
298+ selectionPool .emplace_back (make_unique<CandidateConfiguration>(
307299 crossover (parent1, parent2, parent3)));
308300 } else {
309- new_population.emplace_back (
310- make_unique<CandidateConfiguration>(select ()));
301+ selectionPool.emplace_back (make_unique<CandidateConfiguration>(select ()));
311302 }
312303 }
313- population = std::move (new_population);
314304}
315305
316- void GeneticSearch::updateParameters () {
317- dropInvalidConfigurations (population);
318-
319- // Sort population before taking any decision
320- std::sort (
321- population.begin (),
322- population.end (),
323- [](const std::unique_ptr<CandidateConfiguration>& a,
324- const std::unique_ptr<CandidateConfiguration>& b) {
325- checkRuntimeRecorded (a->runtime );
326- checkRuntimeRecorded (b->runtime );
327- return a->runtime < b->runtime ;
328- });
329-
330- // Update failsafe lastBestConf
331- lastBestConf =
332- population.size () > 0 ? population.front ()->configuration : lastBestConf;
333-
306+ bool GeneticSearch::resetPopulationIfNotEnoughCandidates () {
334307 if (population.size () < minCandidatesForBreeding) {
335308 LOG_IF (ERROR, FLAGS_debug_tuner)
336309 << population.size () << " out of " << maxPopulationSize
@@ -341,30 +314,94 @@ void GeneticSearch::updateParameters() {
341314 " --tuner_min_launch_total_threads=1. This is mostly relevant "
342315 " when autotuning a TC operating on small tensors. The next "
343316 " generation will be randomly initialized." ;
344- population. resize ( 0 );
345- for (size_t i = 0 ; i < maxPopulationSize ; ++i) {
346- population .emplace_back (
317+ selectionPool. clear ( );
318+ for (size_t i = 0 ; i < selectionPoolSize ; ++i) {
319+ selectionPool .emplace_back (
347320 make_unique<CandidateConfiguration>(lastBestConf));
348321 }
349322 // Don't lose the first one which was the best from before
350- CHECK_LT (0u , population.size ());
351- randomizePopulation (population.begin () + 1 , population.end (), rng);
352- return ;
323+ randomizePopulation (selectionPool.begin () + 1 , selectionPool.end (), rng);
324+ return true ;
353325 }
326+ return false ;
327+ }
328+
329+ namespace {
330+ void sortByRuntime (GeneticSearch::Population& population) {
331+ std::sort (
332+ population.begin (),
333+ population.end (),
334+ [](const std::unique_ptr<CandidateConfiguration>& a,
335+ const std::unique_ptr<CandidateConfiguration>& b) {
336+ checkRuntimeRecorded (a->runtime );
337+ checkRuntimeRecorded (b->runtime );
338+ return a->runtime < b->runtime ;
339+ });
340+ }
341+ } // namespace
354342
343+ void GeneticSearch::generateSelectionPool () {
344+ dropInvalidConfigurations (population);
345+ sortByRuntime (population);
346+ lastBestConf =
347+ population.size () > 0 ? population.front ()->configuration : lastBestConf;
348+ if (resetPopulationIfNotEnoughCandidates ()) {
349+ return ;
350+ }
355351 breed ();
356- for (size_t i = numberElites; i < population.size (); ++i) {
357- mutate (*population[i], mutationRate, mutateIterations, rng);
352+ selectionPool.clear ();
353+ selectionPool.emplace_back (make_unique<CandidateConfiguration>(lastBestConf));
354+ breed ();
355+ for (size_t i = 1 ; i < selectionPool.size (); ++i) {
356+ mutate (*selectionPool[i], mutationRate, mutateIterations, rng);
357+ }
358+ }
359+
360+ void GeneticSearch::selectSurvivors () {
361+ dropInvalidConfigurations (selectionPool);
362+ sortByRuntime (selectionPool);
363+ population.clear ();
364+ std::transform (
365+ selectionPool.begin (),
366+ selectionPool.begin () + std::min (selectionPool.size (), maxPopulationSize),
367+ std::back_inserter (population),
368+ [](const std::unique_ptr<CandidateConfiguration>& c) {
369+ return make_unique<CandidateConfiguration>(c->configuration );
370+ });
371+
372+ if (selectionPool.size () < maxPopulationSize) {
373+ auto numberMissing = maxPopulationSize - selectionPool.size ();
374+
375+ for (size_t i = 0 ; i < numberMissing; ++i) {
376+ selectionPool.emplace_back (
377+ make_unique<CandidateConfiguration>(lastBestConf));
378+ }
379+ randomizePopulation (
380+ selectionPool.rbegin (), selectionPool.rbegin () + numberMissing, rng);
358381 }
359382}
360383
361384GeneticSearch::Population& GeneticSearch::candidatesOfStep (uint64_t step) {
362- if (step != 0 ) {
363- throw std::invalid_argument (" GeneticSearch has only one step" );
385+ if (step > 1 ) {
386+ throw std::invalid_argument (" GeneticSearch has only 2 steps." );
387+ }
388+ if (step == 0 ) {
389+ return population;
390+ } else {
391+ return selectionPool;
364392 }
365- return population;
366393}
367394
395+ void GeneticSearch::finishStep (uint64_t step) {
396+ if (step > 1 ) {
397+ throw std::invalid_argument (" GeneticSearch has only 2 steps." );
398+ }
399+ if (step == 0 ) {
400+ generateSelectionPool ();
401+ } else {
402+ selectSurvivors ();
403+ }
404+ }
368405} // namespace autotune
369406} // namespace tc
370407
0 commit comments