|
| 1 | +/** |
| 2 | + * Copyright (C) 2025-present MongoDB, Inc. |
| 3 | + * |
| 4 | + * This program is free software: you can redistribute it and/or modify |
| 5 | + * it under the terms of the Server Side Public License, version 1, |
| 6 | + * as published by MongoDB, Inc. |
| 7 | + * |
| 8 | + * This program is distributed in the hope that it will be useful, |
| 9 | + * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 10 | + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 11 | + * Server Side Public License for more details. |
| 12 | + * |
| 13 | + * You should have received a copy of the Server Side Public License |
| 14 | + * along with this program. If not, see |
| 15 | + * <http://www.mongodb.com/licensing/server-side-public-license>. |
| 16 | + * |
| 17 | + * As a special exception, the copyright holders give permission to link the |
| 18 | + * code of portions of this program with the OpenSSL library under certain |
| 19 | + * conditions as described in each individual source file and distribute |
| 20 | + * linked combinations including the program with the OpenSSL library. You |
| 21 | + * must comply with the Server Side Public License in all respects for |
| 22 | + * all of the code used other than as permitted herein. If you modify file(s) |
| 23 | + * with this exception, you may extend this exception to your version of the |
| 24 | + * file(s), but you are not obligated to do so. If you do not wish to do so, |
| 25 | + * delete this exception statement from your version. If you delete this |
| 26 | + * exception statement from all source files in the program, then also delete |
| 27 | + * it in the license file. |
| 28 | + */ |
| 29 | + |
| 30 | +#include "mongo/util/moving_average.h" |
| 31 | + |
| 32 | +#include "mongo/unittest/barrier.h" |
| 33 | +#include "mongo/unittest/join_thread.h" |
| 34 | +#include "mongo/unittest/unittest.h" |
| 35 | + |
| 36 | +#include <algorithm> |
| 37 | +#include <cmath> |
| 38 | +#include <numbers> |
| 39 | +#include <thread> |
| 40 | +#include <vector> |
| 41 | + |
| 42 | +namespace mongo { |
| 43 | +namespace { |
| 44 | + |
| 45 | +template <typename Func> |
| 46 | +void runForAlphas(Func&& func) { |
| 47 | + const double alphas[] = {0.05, 0.1, 0.2, 0.4, 0.8, 0.9, 0.95}; |
| 48 | + for (const double alpha : alphas) { |
| 49 | + MovingAverage avg{alpha}; |
| 50 | + func(avg); |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +// Verify that `MovingAverage::get()` returns `boost::none` if `addSample` has |
| 55 | +// never been called on the object. |
| 56 | +TEST(MovingAverageTest, StartsWithNone) { |
| 57 | + runForAlphas([](auto& avg) { ASSERT_EQ(avg.get(), boost::none) << " alpha=" << avg.alpha(); }); |
| 58 | +} |
| 59 | + |
| 60 | +// Verify that if `MovingAverage::addSample` has been called on the object only |
| 61 | +// once, then `get()` returns the value of that sample. |
| 62 | +TEST(MovingAverageTest, FirstSampleIsAverage) { |
| 63 | + const double first = -1.337; |
| 64 | + runForAlphas([=](auto& avg) { |
| 65 | + avg.addSample(first); |
| 66 | + ASSERT_EQ(avg.get(), first) << "alpha=" << avg.alpha(); |
| 67 | + }); |
| 68 | +} |
| 69 | + |
| 70 | +// Verify that adding a sample to an exponential moving average results in a |
| 71 | +// new average that is between the previous average and the sample. |
| 72 | +// Sample from the sine function, for example. |
| 73 | +TEST(MovingAverageTest, AverageMovesTowardsSamples) { |
| 74 | + runForAlphas([](auto& avg) { |
| 75 | + double theta = 0; |
| 76 | + double sample = std::sin(theta); |
| 77 | + double oldAvg = avg.addSample(sample); |
| 78 | + const double delta = 0.1; |
| 79 | + theta += delta; |
| 80 | + do { |
| 81 | + sample = std::sin(theta); |
| 82 | + const double newAvg = avg.addSample(sample); |
| 83 | + const auto [below, above] = std::minmax(oldAvg, sample); |
| 84 | + |
| 85 | + ASSERT_LTE(below, newAvg) |
| 86 | + << "theta=" << theta << " sin(theta)=" << sample << " alpha=" << avg.alpha(); |
| 87 | + ASSERT_GTE(above, newAvg) |
| 88 | + << "theta=" << theta << " sin(theta)=" << sample << " alpha=" << avg.alpha(); |
| 89 | + |
| 90 | + oldAvg = newAvg; |
| 91 | + theta += delta; |
| 92 | + } while (theta < 2 * std::numbers::pi); |
| 93 | + }); |
| 94 | +} |
| 95 | + |
| 96 | +// Verify that `get()` returns the most recent average. |
| 97 | +TEST(MovingAverageTest, GetIsConsistentWithAddSampleAndIsIdempotent) { |
| 98 | + runForAlphas([](auto& avg) { |
| 99 | + // arbitrary history of samples |
| 100 | + const double warmup[] = {9898344, -309409, 2.7e-12, 42}; |
| 101 | + |
| 102 | + double mostRecentAvg; |
| 103 | + for (const double sample : warmup) { |
| 104 | + mostRecentAvg = avg.addSample(sample); |
| 105 | + } |
| 106 | + |
| 107 | + ASSERT_EQ(avg.get(), mostRecentAvg) << "alpha=" << avg.alpha(); |
| 108 | + ASSERT_EQ(avg.get(), mostRecentAvg) << "alpha=" << avg.alpha(); |
| 109 | + }); |
| 110 | +} |
| 111 | + |
| 112 | +// Verify that two or more threads can concurrently call any combination of |
| 113 | +// `get()` and `addSample(...)` without upsetting code sanitizers like |
| 114 | +// ThreadSanitizer (tsan), AddressSansitizer (asan), and |
| 115 | +// UndefinedBehaviorSanitizer (ubsan). This test is only relevant when the test |
| 116 | +// driver is built with sanitizers (e.g. `--config=dbg_tsan` or |
| 117 | +// `--config=dbg_aubsan`). |
| 118 | +TEST(MovingAverageTest, ThreadSafe) { |
| 119 | + // At most as many threads as logical cores, or two threads if we don't |
| 120 | + // know the core count. |
| 121 | + const unsigned maxThreads = std::max(2u, std::thread::hardware_concurrency()); |
| 122 | + |
| 123 | + for (unsigned nThreads = 2; nThreads <= maxThreads; ++nThreads) { |
| 124 | + runForAlphas([=](auto& avg) { |
| 125 | + unittest::Barrier startingLine{nThreads}; |
| 126 | + std::vector<unittest::JoinThread> threads; |
| 127 | + for (unsigned id = 0; id < nThreads; ++id) { |
| 128 | + threads.emplace_back([&, id]() { |
| 129 | + // Wait for the other threads to spawn. |
| 130 | + startingLine.countDownAndWait(); |
| 131 | + |
| 132 | + // Bang on `avg` for a while. |
| 133 | + double scratch = id; |
| 134 | + for (int i = 0; i < 1'000; ++i) { |
| 135 | + avg.addSample(scratch); |
| 136 | + const auto got = avg.get(); |
| 137 | + ASSERT_NE(got, boost::none); |
| 138 | + scratch = *got; |
| 139 | + } |
| 140 | + }); |
| 141 | + } |
| 142 | + }); |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | +} // namespace |
| 147 | +} // namespace mongo |
0 commit comments