Skip to content

Commit 8477cf3

Browse files
committed
[libc++] Optimize ranges::for_each for iterating over __trees
[libc++] Optimize std::for_each for __tree iterators
1 parent f1c1063 commit 8477cf3

File tree

10 files changed

+344
-11
lines changed

10 files changed

+344
-11
lines changed

libcxx/include/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ set(files
194194
__algorithm/simd_utils.h
195195
__algorithm/sort.h
196196
__algorithm/sort_heap.h
197+
__algorithm/specialized_algorithms.h
197198
__algorithm/stable_partition.h
198199
__algorithm/stable_sort.h
199200
__algorithm/swap_ranges.h

libcxx/include/__algorithm/for_each.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#define _LIBCPP___ALGORITHM_FOR_EACH_H
1212

1313
#include <__algorithm/for_each_segment.h>
14+
#include <__algorithm/specialized_algorithms.h>
1415
#include <__config>
1516
#include <__functional/identity.h>
1617
#include <__iterator/segmented_iterator.h>
@@ -44,6 +45,19 @@ __for_each(_SegmentedIterator __first, _SegmentedIterator __last, _Func& __func,
4445
});
4546
return __last;
4647
}
48+
49+
template <class _InputIterator,
50+
class _Func,
51+
class _Proj,
52+
__enable_if_t<__specialized_algorithm<_Algorithm::__for_each,
53+
__iterator_pair<_InputIterator, _InputIterator>>::__has_algorithm,
54+
int> = 0>
55+
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _InputIterator
56+
__for_each(_InputIterator __first, _InputIterator __last, _Func& __func, _Proj& __proj) {
57+
__specialized_algorithm<_Algorithm::__for_each, __iterator_pair<_InputIterator, _InputIterator>>()(
58+
__first, __last, __func, __proj);
59+
return __last;
60+
}
4761
#endif // !_LIBCPP_CXX03_LANG
4862

4963
template <class _InputIterator, class _Func>

libcxx/include/__algorithm/ranges_for_each.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <__algorithm/for_each.h>
1313
#include <__algorithm/for_each_n.h>
1414
#include <__algorithm/in_fun_result.h>
15+
#include <__algorithm/specialized_algorithms.h>
1516
#include <__concepts/assignable.h>
1617
#include <__config>
1718
#include <__functional/identity.h>
@@ -20,6 +21,7 @@
2021
#include <__ranges/access.h>
2122
#include <__ranges/concepts.h>
2223
#include <__ranges/dangling.h>
24+
#include <__type_traits/remove_cvref.h>
2325
#include <__utility/move.h>
2426

2527
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -71,7 +73,13 @@ struct __for_each {
7173
indirectly_unary_invocable<projected<iterator_t<_Range>, _Proj>> _Func>
7274
_LIBCPP_HIDE_FROM_ABI constexpr for_each_result<borrowed_iterator_t<_Range>, _Func>
7375
operator()(_Range&& __range, _Func __func, _Proj __proj = {}) const {
74-
return __for_each_impl(ranges::begin(__range), ranges::end(__range), __func, __proj);
76+
using _SpecialAlg = __specialized_algorithm<_Algorithm::__for_each, remove_cvref_t<_Range>>;
77+
if constexpr (_SpecialAlg::__has_algorithm) {
78+
auto [__iter, __func2] = _SpecialAlg()(__range, std::move(__func), std::move(__proj));
79+
return {std::move(__iter), std::move(__func)};
80+
} else {
81+
return __for_each_impl(ranges::begin(__range), ranges::end(__range), __func, __proj);
82+
}
7583
}
7684
};
7785

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _LIBCPP___ALGORITHM_SPECIALIZED_ALGORITHMS_H
10+
#define _LIBCPP___ALGORITHM_SPECIALIZED_ALGORITHMS_H
11+
12+
#include <__config>
13+
14+
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
15+
# pragma GCC system_header
16+
#endif
17+
18+
_LIBCPP_BEGIN_NAMESPACE_STD
19+
20+
// FIXME: This should really be an enum
21+
namespace _Algorithm {
22+
struct __for_each {};
23+
} // namespace _Algorithm
24+
25+
template <class, class>
26+
struct __iterator_pair {};
27+
28+
template <class _Alg, class _Range>
29+
struct __specialized_algorithm {
30+
static const bool __has_algorithm = false;
31+
};
32+
33+
_LIBCPP_END_NAMESPACE_STD
34+
35+
#endif // _LIBCPP___ALGORITHM_SPECIALIZED_ALGORITHMS_H

libcxx/include/__tree

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#define _LIBCPP___TREE
1212

1313
#include <__algorithm/min.h>
14+
#include <__algorithm/specialized_algorithms.h>
1415
#include <__assert>
1516
#include <__config>
1617
#include <__fwd/pair.h>
@@ -717,6 +718,59 @@ private:
717718
friend class __tree_const_iterator;
718719
};
719720

721+
template <class _Reference, class _EndNodePtr, class _NodePtr, class _Func, class _Proj>
722+
_LIBCPP_HIDE_FROM_ABI bool __tree_iterate_from_root(_EndNodePtr __last, _NodePtr __root, _Func& __func, _Proj& __proj) {
723+
if (__root->__left_) {
724+
if (std::__tree_iterate_from_root<_Reference>(__last, static_cast<_NodePtr>(__root->__left_), __func, __proj))
725+
return true;
726+
}
727+
if (__root == __last)
728+
return true;
729+
__func(static_cast<_Reference>(__root->__get_value()));
730+
if (__root->__right_)
731+
return std::__tree_iterate_from_root<_Reference>(__last, static_cast<_NodePtr>(__root->__right_), __func, __proj);
732+
return false;
733+
}
734+
735+
template <class _Reference, class _NodePtr, class _EndNodePtr, class _Func, class _Proj>
736+
_LIBCPP_HIDE_FROM_ABI void
737+
__tree_iterate_from_begin(_EndNodePtr __first, _EndNodePtr __last, _Func& __func, _Proj& __proj) {
738+
while (true) {
739+
if (__first == __last)
740+
return;
741+
auto __nfirst = static_cast<_NodePtr>(__first);
742+
__func(static_cast<_Reference>(__nfirst->__get_value()));
743+
if (__nfirst->__right_) {
744+
if (std::__tree_iterate_from_root<_Reference>(__last, static_cast<_NodePtr>(__nfirst->__right_), __func, __proj))
745+
return;
746+
}
747+
if (std::__tree_is_left_child(__nfirst)) {
748+
__first = __nfirst->__parent_;
749+
} else {
750+
do {
751+
__first = __nfirst->__parent_;
752+
} while (!std::__tree_is_left_child(__nfirst));
753+
}
754+
}
755+
}
756+
757+
#ifndef _LIBCPP_CXX03_LANG
758+
template <class _Tp, class _NodePtr, class _DiffType>
759+
struct __specialized_algorithm<
760+
_Algorithm::__for_each,
761+
__iterator_pair<__tree_iterator<_Tp, _NodePtr, _DiffType>, __tree_iterator<_Tp, _NodePtr, _DiffType>>> {
762+
static const bool __has_algorithm = true;
763+
764+
using __iterator _LIBCPP_NODEBUG = __tree_iterator<_Tp, _NodePtr, _DiffType>;
765+
766+
template <class _Func, class _Proj>
767+
_LIBCPP_HIDE_FROM_ABI static void operator()(__iterator __first, __iterator __last, _Func& __func, _Proj& __proj) {
768+
std::__tree_iterate_from_begin<typename __iterator::reference, _NodePtr>(
769+
__first.__ptr_, __last.__ptr_, __func, __proj);
770+
}
771+
};
772+
#endif
773+
720774
template <class _Tp, class _NodePtr, class _DiffType>
721775
class __tree_const_iterator {
722776
using _NodeTypes _LIBCPP_NODEBUG = __tree_node_types<_NodePtr>;
@@ -780,8 +834,28 @@ private:
780834

781835
template <class, class, class>
782836
friend class __tree;
837+
838+
friend struct __specialized_algorithm<_Algorithm::__for_each,
839+
__iterator_pair<__tree_const_iterator, __tree_const_iterator> >;
783840
};
784841

842+
#ifndef _LIBCPP_CXX03_LANG
843+
template <class _Tp, class _NodePtr, class _DiffType>
844+
struct __specialized_algorithm<
845+
_Algorithm::__for_each,
846+
__iterator_pair<__tree_const_iterator<_Tp, _NodePtr, _DiffType>, __tree_const_iterator<_Tp, _NodePtr, _DiffType>>> {
847+
static const bool __has_algorithm = true;
848+
849+
using __iterator = __tree_const_iterator<_Tp, _NodePtr, _DiffType>;
850+
851+
template <class _Func, class _Proj>
852+
_LIBCPP_HIDE_FROM_ABI static void operator()(__iterator __first, __iterator __last, _Func& __func, _Proj& __proj) {
853+
std::__tree_iterate_from_begin<typename __iterator::reference, _NodePtr>(
854+
__first.__ptr_, __last.__ptr_, __func, __proj);
855+
}
856+
};
857+
#endif
858+
785859
template <class _Tp, class _Compare>
786860
#ifndef _LIBCPP_CXX03_LANG
787861
_LIBCPP_DIAGNOSE_WARNING(!__is_invocable_v<_Compare const&, _Tp const&, _Tp const&>,
@@ -1466,7 +1540,36 @@ private:
14661540

14671541
return __dest;
14681542
}
1543+
1544+
friend struct __specialized_algorithm<_Algorithm::__for_each, __tree>;
1545+
};
1546+
1547+
#if _LIBCPP_STD_VER >= 14
1548+
template <class _Tp, class _Compare, class _Allocator>
1549+
struct __specialized_algorithm<_Algorithm::__for_each, __tree<_Tp, _Compare, _Allocator> > {
1550+
static const bool __has_algorithm = true;
1551+
1552+
using __node_pointer _LIBCPP_NODEBUG = typename __tree<_Tp, _Compare, _Allocator>::__node_pointer;
1553+
1554+
template <class _Func, class _Proj>
1555+
#ifndef _LIBCPP_COMPILER_GCC
1556+
_LIBCPP_HIDE_FROM_ABI
1557+
#endif
1558+
static void __impl(__node_pointer __root, _Func& __func, _Proj& __proj) {
1559+
if (__root->__left_)
1560+
__impl(static_cast<__node_pointer>(__root->__left_), __func, __proj);
1561+
__func(__root->__get_value());
1562+
if (__root->__right_)
1563+
__impl(static_cast<__node_pointer>(__root->__right_), __func, __proj);
1564+
}
1565+
1566+
template <class _Tree, class _Func, class _Proj>
1567+
_LIBCPP_HIDE_FROM_ABI static auto operator()(_Tree&& __range, _Func __func, _Proj __proj) {
1568+
__impl(__range.__root(), __func, __proj);
1569+
return std::make_pair(__range.end(), std::move(__func));
1570+
}
14691571
};
1572+
#endif
14701573

14711574
// Precondition: __size_ != 0
14721575
template <class _Tp, class _Compare, class _Allocator>

libcxx/include/map

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ erase_if(multimap<Key, T, Compare, Allocator>& c, Predicate pred); // C++20
577577
# include <__algorithm/equal.h>
578578
# include <__algorithm/lexicographical_compare.h>
579579
# include <__algorithm/lexicographical_compare_three_way.h>
580+
# include <__algorithm/specialized_algorithms.h>
580581
# include <__assert>
581582
# include <__config>
582583
# include <__functional/binary_function.h>
@@ -1375,6 +1376,8 @@ private:
13751376
# ifdef _LIBCPP_CXX03_LANG
13761377
_LIBCPP_HIDE_FROM_ABI __node_holder __construct_node_with_key(const key_type& __k);
13771378
# endif
1379+
1380+
friend struct __specialized_algorithm<_Algorithm::__for_each, map>;
13781381
};
13791382

13801383
# if _LIBCPP_STD_VER >= 17
@@ -1427,6 +1430,23 @@ map(initializer_list<pair<_Key, _Tp>>, _Allocator)
14271430
-> map<remove_const_t<_Key>, _Tp, less<remove_const_t<_Key>>, _Allocator>;
14281431
# endif
14291432

1433+
# if _LIBCPP_STD_VER >= 14
1434+
template <class _Key, class _Tp, class _Compare, class _Allocator>
1435+
struct __specialized_algorithm<_Algorithm::__for_each, map<_Key, _Tp, _Compare, _Allocator>> {
1436+
using __map _LIBCPP_NODEBUG = map<_Key, _Tp, _Compare, _Allocator>;
1437+
1438+
static const bool __has_algorithm = true;
1439+
1440+
// set's begin() and end() are identical with and without const qualifiaction
1441+
template <class _Map, class _Func>
1442+
_LIBCPP_HIDE_FROM_ABI static auto operator()(_Map&& __map, _Func __func) {
1443+
auto [_, __func2] = __specialized_algorithm<_Algorithm::__for_each, typename __map::__base>()(
1444+
__map.__tree_, std::move(__func));
1445+
return std::make_pair(__map.end(), std::move(__func2));
1446+
}
1447+
};
1448+
# endif
1449+
14301450
# ifndef _LIBCPP_CXX03_LANG
14311451
template <class _Key, class _Tp, class _Compare, class _Allocator>
14321452
map<_Key, _Tp, _Compare, _Allocator>::map(map&& __m, const allocator_type& __a)
@@ -1940,6 +1960,8 @@ private:
19401960

19411961
typedef __map_node_destructor<__node_allocator> _Dp;
19421962
typedef unique_ptr<__node, _Dp> __node_holder;
1963+
1964+
friend struct __specialized_algorithm<_Algorithm::__for_each, multimap>;
19431965
};
19441966

19451967
# if _LIBCPP_STD_VER >= 17
@@ -1992,6 +2014,23 @@ multimap(initializer_list<pair<_Key, _Tp>>, _Allocator)
19922014
-> multimap<remove_const_t<_Key>, _Tp, less<remove_const_t<_Key>>, _Allocator>;
19932015
# endif
19942016

2017+
# if _LIBCPP_STD_VER >= 14
2018+
template <class _Key, class _Tp, class _Compare, class _Allocator>
2019+
struct __specialized_algorithm<_Algorithm::__for_each, multimap<_Key, _Tp, _Compare, _Allocator>> {
2020+
using __map _LIBCPP_NODEBUG = multimap<_Key, _Tp, _Compare, _Allocator>;
2021+
2022+
static const bool __has_algorithm = true;
2023+
2024+
// set's begin() and end() are identical with and without const qualifiaction
2025+
template <class _Map, class _Func>
2026+
_LIBCPP_HIDE_FROM_ABI static auto operator()(_Map&& __map, _Func __func) {
2027+
auto [_, __func2] = __specialized_algorithm<_Algorithm::__for_each, typename __map::__base>()(
2028+
__map.__tree_, std::move(__func));
2029+
return std::make_pair(__map.end(), std::move(__func2));
2030+
}
2031+
};
2032+
# endif
2033+
19952034
# ifndef _LIBCPP_CXX03_LANG
19962035
template <class _Key, class _Tp, class _Compare, class _Allocator>
19972036
multimap<_Key, _Tp, _Compare, _Allocator>::multimap(multimap&& __m, const allocator_type& __a)

libcxx/include/module.modulemap.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,7 @@ module std [system] {
838838
module simd_utils { header "__algorithm/simd_utils.h" }
839839
module sort_heap { header "__algorithm/sort_heap.h" }
840840
module sort { header "__algorithm/sort.h" }
841+
module specialized_algorithms { header "__algorithm/specialized_algorithms.h" }
841842
module stable_partition { header "__algorithm/stable_partition.h" }
842843
module stable_sort {
843844
header "__algorithm/stable_sort.h"

libcxx/include/set

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ erase_if(multiset<Key, Compare, Allocator>& c, Predicate pred); // C++20
518518
# include <__algorithm/equal.h>
519519
# include <__algorithm/lexicographical_compare.h>
520520
# include <__algorithm/lexicographical_compare_three_way.h>
521+
# include <__algorithm/specialized_algorithms.h>
521522
# include <__assert>
522523
# include <__config>
523524
# include <__functional/is_transparent.h>
@@ -902,6 +903,9 @@ public:
902903
return __tree_.__equal_range_multi(__k);
903904
}
904905
# endif
906+
907+
template <class, class>
908+
friend struct __specialized_algorithm;
905909
};
906910

907911
# if _LIBCPP_STD_VER >= 17
@@ -948,6 +952,21 @@ template <class _Key, class _Allocator, class = enable_if_t<__is_allocator_v<_Al
948952
set(initializer_list<_Key>, _Allocator) -> set<_Key, less<_Key>, _Allocator>;
949953
# endif
950954

955+
# if _LIBCPP_STD_VER >= 14
956+
template <class _Alg, class _Key, class _Compare, class _Allocator>
957+
struct __specialized_algorithm<_Alg, set<_Key, _Compare, _Allocator>> {
958+
using __set _LIBCPP_NODEBUG = set<_Key, _Compare, _Allocator>;
959+
960+
static const bool __has_algorithm = __specialized_algorithm<_Alg, typename __set::__base>::__has_algorithm;
961+
962+
// set's begin() and end() are identical with and without const qualifiaction
963+
template <class... _Args>
964+
_LIBCPP_HIDE_FROM_ABI static auto operator()(const __set& __set, _Args&&... __args) {
965+
return __specialized_algorithm<_Alg, typename __set::__base>()(__set.__tree_, std::forward<_Args>(__args)...);
966+
}
967+
};
968+
# endif
969+
951970
# ifndef _LIBCPP_CXX03_LANG
952971

953972
template <class _Key, class _Compare, class _Allocator>
@@ -1362,6 +1381,9 @@ public:
13621381
return __tree_.__equal_range_multi(__k);
13631382
}
13641383
# endif
1384+
1385+
template <class, class>
1386+
friend struct __specialized_algorithm;
13651387
};
13661388

13671389
# if _LIBCPP_STD_VER >= 17
@@ -1409,6 +1431,21 @@ template <class _Key, class _Allocator, class = enable_if_t<__is_allocator_v<_Al
14091431
multiset(initializer_list<_Key>, _Allocator) -> multiset<_Key, less<_Key>, _Allocator>;
14101432
# endif
14111433

1434+
# if _LIBCPP_STD_VER >= 14
1435+
template <class _Alg, class _Key, class _Compare, class _Allocator>
1436+
struct __specialized_algorithm<_Alg, multiset<_Key, _Compare, _Allocator>> {
1437+
using __set _LIBCPP_NODEBUG = multiset<_Key, _Compare, _Allocator>;
1438+
1439+
static const bool __has_algorithm = __specialized_algorithm<_Alg, typename __set::__base>::__has_algorithm;
1440+
1441+
// set's begin() and end() are identical with and without const qualifiaction
1442+
template <class... _Args>
1443+
_LIBCPP_HIDE_FROM_ABI static auto operator()(const __set& __set, _Args&&... __args) {
1444+
return __specialized_algorithm<_Alg, typename __set::__base>()(__set.__tree_, std::forward<_Args>(__args)...);
1445+
}
1446+
};
1447+
# endif
1448+
14121449
# ifndef _LIBCPP_CXX03_LANG
14131450

14141451
template <class _Key, class _Compare, class _Allocator>

0 commit comments

Comments
 (0)