Skip to content

Commit 13c5753

Browse files
Merge pull request #11663 from felipepiovezan/felipe/adt_extras
🍒[ADT] Cherry-pick helpers for accumulate / sum_of
2 parents 0d00797 + 4f1ea07 commit 13c5753

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <iterator>
3636
#include <limits>
3737
#include <memory>
38+
#include <numeric>
3839
#include <optional>
3940
#include <tuple>
4041
#include <type_traits>
@@ -1731,6 +1732,34 @@ template <typename R> constexpr size_t range_size(R &&Range) {
17311732
return static_cast<size_t>(std::distance(adl_begin(Range), adl_end(Range)));
17321733
}
17331734

1735+
/// Wrapper for std::accumulate.
1736+
template <typename R, typename E> auto accumulate(R &&Range, E &&Init) {
1737+
return std::accumulate(adl_begin(Range), adl_end(Range),
1738+
std::forward<E>(Init));
1739+
}
1740+
1741+
/// Wrapper for std::accumulate with a binary operator.
1742+
template <typename R, typename E, typename BinaryOp>
1743+
auto accumulate(R &&Range, E &&Init, BinaryOp &&Op) {
1744+
return std::accumulate(adl_begin(Range), adl_end(Range),
1745+
std::forward<E>(Init), std::forward<BinaryOp>(Op));
1746+
}
1747+
1748+
/// Returns the sum of all values in `Range` with `Init` initial value.
1749+
/// The default initial value is 0.
1750+
template <typename R, typename E = detail::ValueOfRange<R>>
1751+
auto sum_of(R &&Range, E Init = E{0}) {
1752+
return accumulate(std::forward<R>(Range), std::move(Init));
1753+
}
1754+
1755+
/// Returns the product of all values in `Range` with `Init` initial value.
1756+
/// The default initial value is 1.
1757+
template <typename R, typename E = detail::ValueOfRange<R>>
1758+
auto product_of(R &&Range, E Init = E{1}) {
1759+
return accumulate(std::forward<R>(Range), std::move(Init),
1760+
std::multiplies<>{});
1761+
}
1762+
17341763
/// Provide wrappers to std::for_each which take ranges instead of having to
17351764
/// pass begin/end explicitly.
17361765
template <typename R, typename UnaryFunction>

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <array>
1515
#include <climits>
1616
#include <cstddef>
17+
#include <functional>
1718
#include <initializer_list>
1819
#include <iterator>
1920
#include <list>
@@ -1602,6 +1603,64 @@ TEST(STLExtrasTest, Fill) {
16021603
EXPECT_THAT(V2, ElementsAre(Val, Val, Val, Val));
16031604
}
16041605

1606+
TEST(STLExtrasTest, Accumulate) {
1607+
EXPECT_EQ(accumulate(std::vector<int>(), 0), 0);
1608+
EXPECT_EQ(accumulate(std::vector<int>(), 3), 3);
1609+
std::vector<int> V1 = {1, 2, 3, 4, 5};
1610+
EXPECT_EQ(accumulate(V1, 0), std::accumulate(V1.begin(), V1.end(), 0));
1611+
EXPECT_EQ(accumulate(V1, 10), std::accumulate(V1.begin(), V1.end(), 10));
1612+
EXPECT_EQ(accumulate(drop_begin(V1), 7),
1613+
std::accumulate(V1.begin() + 1, V1.end(), 7));
1614+
1615+
EXPECT_EQ(accumulate(V1, 2, std::multiplies<>{}), 240);
1616+
}
1617+
1618+
TEST(STLExtrasTest, SumOf) {
1619+
EXPECT_EQ(sum_of(std::vector<int>()), 0);
1620+
EXPECT_EQ(sum_of(std::vector<int>(), 1), 1);
1621+
std::vector<int> V1 = {1, 2, 3, 4, 5};
1622+
static_assert(std::is_same_v<decltype(sum_of(V1)), int>);
1623+
static_assert(std::is_same_v<decltype(sum_of(V1, 1)), int>);
1624+
EXPECT_EQ(sum_of(V1), 15);
1625+
EXPECT_EQ(sum_of(V1, 1), 16);
1626+
1627+
std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
1628+
static_assert(std::is_same_v<decltype(sum_of(V2)), float>);
1629+
static_assert(std::is_same_v<decltype(sum_of(V2), 1.0f), float>);
1630+
static_assert(std::is_same_v<decltype(sum_of(V2), 1.0), double>);
1631+
EXPECT_EQ(sum_of(V2), 7.0f);
1632+
EXPECT_EQ(sum_of(V2, 1.0f), 8.0f);
1633+
1634+
// Make sure that for a const argument the return value is non-const.
1635+
const std::vector<float> V3 = {1.0f, 2.0f};
1636+
static_assert(std::is_same_v<decltype(sum_of(V3)), float>);
1637+
EXPECT_EQ(sum_of(V3), 3.0f);
1638+
}
1639+
1640+
TEST(STLExtrasTest, ProductOf) {
1641+
EXPECT_EQ(product_of(std::vector<int>()), 1);
1642+
EXPECT_EQ(product_of(std::vector<int>(), 0), 0);
1643+
EXPECT_EQ(product_of(std::vector<int>(), 1), 1);
1644+
std::vector<int> V1 = {1, 2, 3, 4, 5};
1645+
static_assert(std::is_same_v<decltype(product_of(V1)), int>);
1646+
static_assert(std::is_same_v<decltype(product_of(V1, 1)), int>);
1647+
EXPECT_EQ(product_of(V1), 120);
1648+
EXPECT_EQ(product_of(V1, 1), 120);
1649+
EXPECT_EQ(product_of(V1, 2), 240);
1650+
1651+
std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
1652+
static_assert(std::is_same_v<decltype(product_of(V2)), float>);
1653+
static_assert(std::is_same_v<decltype(product_of(V2), 1.0f), float>);
1654+
static_assert(std::is_same_v<decltype(product_of(V2), 1.0), double>);
1655+
EXPECT_EQ(product_of(V2), 8.0f);
1656+
EXPECT_EQ(product_of(V2, 4.0f), 32.0f);
1657+
1658+
// Make sure that for a const argument the return value is non-const.
1659+
const std::vector<float> V3 = {1.0f, 2.0f};
1660+
static_assert(std::is_same_v<decltype(product_of(V3)), float>);
1661+
EXPECT_EQ(product_of(V3), 2.0f);
1662+
}
1663+
16051664
struct Foo;
16061665
struct Bar {};
16071666

0 commit comments

Comments
 (0)