Skip to content

Commit 3a49f0f

Browse files
committed
Add async_apply for map-like objects
1 parent 8605e60 commit 3a49f0f

File tree

5 files changed

+132
-91
lines changed

5 files changed

+132
-91
lines changed

include/albatross/Indexing

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <albatross/src/indexing/subset.hpp>
2020
#include <albatross/src/indexing/filter.hpp>
2121
#include <albatross/src/indexing/apply.hpp>
22+
#include "./utils/AsyncUtils"
2223
#include <albatross/src/indexing/group_by.hpp>
2324

2425
#endif

include/albatross/src/indexing/group_by.hpp

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ template <typename KeyType, typename ValueType> class GroupedBase {
115115
return albatross::apply(map_, std::forward<ApplyFunction>(f));
116116
}
117117

118+
template <typename ApplyFunction> auto async_apply(ApplyFunction &&f) const {
119+
return albatross::async_apply(map_, std::forward<ApplyFunction>(f));
120+
}
121+
118122
protected:
119123
std::map<KeyType, ValueType> map_;
120124
};
@@ -188,20 +192,13 @@ class Grouped<KeyType, GroupIndices>
188192
using Base = GroupedBase<KeyType, GroupIndices>;
189193
using Base::Base;
190194

191-
template <typename ApplyFunction,
192-
typename ApplyType = typename details::key_value_apply_result<
193-
ApplyFunction, KeyType, GroupIndices>::type,
194-
typename std::enable_if<
195-
details::is_valid_index_apply_function<ApplyFunction, KeyType,
196-
GroupIndices>::value &&
197-
!std::is_same<void, ApplyType>::value,
198-
int>::type = 0>
199-
auto index_apply(const ApplyFunction &f) const {
200-
Grouped<KeyType, ApplyType> output;
201-
for (const auto &pair : this->map_) {
202-
output.emplace(pair.first, f(pair.first, pair.second));
203-
}
204-
return output;
195+
template <typename ApplyFunction> auto index_apply(ApplyFunction &&f) const {
196+
return apply(this->map_, std::forward<ApplyFunction>(f));
197+
}
198+
199+
template <typename ApplyFunction>
200+
auto async_index_apply(ApplyFunction &&f) const {
201+
return async_apply(this->map_, std::forward<ApplyFunction>(f));
205202
}
206203
};
207204

@@ -401,8 +398,12 @@ template <typename Derived> class GroupByBase {
401398

402399
std::size_t size() const { return indexers().size(); }
403400

404-
template <typename ApplyFunction> auto apply(const ApplyFunction &f) const {
405-
return groups().apply(f);
401+
template <typename ApplyFunction> auto apply(ApplyFunction &&f) const {
402+
return groups().apply(std::forward<ApplyFunction>(f));
403+
}
404+
405+
template <typename ApplyFunction> auto async_apply(ApplyFunction &&f) const {
406+
return groups().async_apply(std::forward<ApplyFunction>(f));
406407
}
407408

408409
ValueType get_group(const KeyType &key) const {
@@ -415,38 +416,17 @@ template <typename Derived> class GroupByBase {
415416
albatross::subset(parent_, first_indexer.second));
416417
}
417418

418-
template <typename ApplyFunction,
419-
typename ApplyType = typename details::key_value_apply_result<
420-
ApplyFunction, KeyType, GroupIndices>::type,
421-
typename std::enable_if<
422-
details::is_valid_index_apply_function<ApplyFunction, KeyType,
423-
GroupIndices>::value &&
424-
!std::is_same<void, ApplyType>::value,
425-
int>::type = 0>
426-
auto index_apply(const ApplyFunction &f) const {
427-
Grouped<KeyType, ApplyType> output;
428-
for (const auto &pair : indexers()) {
429-
output.emplace(pair.first, f(pair.first, pair.second));
430-
}
431-
return output;
419+
template <typename ApplyFunction> auto index_apply(ApplyFunction &&f) const {
420+
return albatross::apply(indexers(), std::forward<ApplyFunction>(f));
432421
}
433422

434-
template <typename ApplyFunction,
435-
typename ApplyType = typename details::key_value_apply_result<
436-
ApplyFunction, KeyType, GroupIndices>::type,
437-
typename std::enable_if<
438-
details::is_valid_index_apply_function<ApplyFunction, KeyType,
439-
GroupIndices>::value &&
440-
std::is_same<void, ApplyType>::value,
441-
int>::type = 0>
442-
void index_apply(const ApplyFunction &f) const {
443-
for (const auto &pair : indexers()) {
444-
f(pair.first, pair.second);
445-
}
423+
template <typename ApplyFunction>
424+
auto async_index_apply(ApplyFunction &&f) const {
425+
return albatross::async_apply(indexers(), std::forward<ApplyFunction>(f));
446426
}
447427

448-
template <typename FilterFunction> auto filter(FilterFunction f) const {
449-
return groups().filter(f);
428+
template <typename FilterFunction> auto filter(FilterFunction &&f) const {
429+
return groups().filter(std::forward<FilterFunction>(f));
450430
}
451431

452432
Grouped<KeyType, std::size_t> counts() const {

include/albatross/src/utils/async_utils.hpp

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ template <typename ValueType, typename ApplyFunction,
3232
ApplyFunction, ValueType>::value &&
3333
std::is_same<void, ApplyType>::value,
3434
int>::type = 0>
35-
void async_apply(const std::vector<ValueType> &xs, const ApplyFunction &f) {
35+
inline void async_apply(const std::vector<ValueType> &xs,
36+
const ApplyFunction &f) {
3637
std::vector<std::future<void>> futures;
3738
for (const auto &x : xs) {
3839
futures.emplace_back(async_safe(f, x));
@@ -49,7 +50,8 @@ template <typename ValueType, typename ApplyFunction,
4950
ApplyFunction, ValueType>::value &&
5051
!std::is_same<void, ApplyType>::value,
5152
int>::type = 0>
52-
auto async_apply(const std::vector<ValueType> &xs, const ApplyFunction &f) {
53+
inline auto async_apply(const std::vector<ValueType> &xs,
54+
const ApplyFunction &f) {
5355
std::vector<std::future<ApplyType>> futures;
5456
for (const auto &x : xs) {
5557
futures.emplace_back(async_safe(f, x));
@@ -62,6 +64,92 @@ auto async_apply(const std::vector<ValueType> &xs, const ApplyFunction &f) {
6264
return output;
6365
}
6466

67+
// Map
68+
69+
template <
70+
template <typename...> class Map, typename KeyType, typename ValueType,
71+
typename ApplyFunction,
72+
typename ApplyType = typename details::key_value_apply_result<
73+
ApplyFunction, KeyType, ValueType>::type,
74+
typename std::enable_if<details::is_valid_key_value_apply_function<
75+
ApplyFunction, KeyType, ValueType>::value &&
76+
std::is_same<void, ApplyType>::value,
77+
int>::type = 0>
78+
inline void async_apply(const Map<KeyType, ValueType> &map, ApplyFunction &&f) {
79+
std::vector<std::future<void>> futures;
80+
for (const auto &pair : map) {
81+
futures.emplace_back(async_safe(f, pair.first, pair.second));
82+
}
83+
for (auto &f : futures) {
84+
f.get();
85+
}
86+
}
87+
88+
template <
89+
template <typename...> class Map, typename KeyType, typename ValueType,
90+
typename ApplyFunction,
91+
typename ApplyType = typename details::key_value_apply_result<
92+
ApplyFunction, KeyType, ValueType>::type,
93+
typename std::enable_if<details::is_valid_key_value_apply_function<
94+
ApplyFunction, KeyType, ValueType>::value &&
95+
!std::is_same<void, ApplyType>::value,
96+
int>::type = 0>
97+
inline Grouped<KeyType, ApplyType>
98+
async_apply(const Map<KeyType, ValueType> &map, ApplyFunction &&f) {
99+
100+
std::map<KeyType, std::future<ApplyType>> futures;
101+
for (const auto &pair : map) {
102+
futures[pair.first] = async_safe(f, pair.first, pair.second);
103+
}
104+
105+
Grouped<KeyType, ApplyType> output;
106+
for (auto &pair : futures) {
107+
output.emplace(pair.first, pair.second.get());
108+
}
109+
return output;
110+
}
111+
112+
template <template <typename...> class Map, typename KeyType,
113+
typename ValueType, typename ApplyFunction,
114+
typename ApplyType = typename details::value_only_apply_result<
115+
ApplyFunction, ValueType>::type,
116+
typename std::enable_if<details::is_valid_value_only_apply_function<
117+
ApplyFunction, ValueType>::value &&
118+
!std::is_same<void, ApplyType>::value,
119+
int>::type = 0>
120+
inline Grouped<KeyType, ApplyType>
121+
async_apply(const Map<KeyType, ValueType> &map, ApplyFunction &&f) {
122+
123+
std::map<KeyType, std::future<ApplyType>> futures;
124+
for (const auto &pair : map) {
125+
futures[pair.first] = async_safe(f, pair.second);
126+
}
127+
128+
Grouped<KeyType, ApplyType> output;
129+
for (auto &pair : futures) {
130+
output.emplace(pair.first, pair.second.get());
131+
}
132+
return output;
133+
}
134+
135+
template <template <typename...> class Map, typename KeyType,
136+
typename ValueType, typename ApplyFunction,
137+
typename ApplyType = typename details::value_only_apply_result<
138+
ApplyFunction, ValueType>::type,
139+
typename std::enable_if<details::is_valid_value_only_apply_function<
140+
ApplyFunction, ValueType>::value &&
141+
std::is_same<void, ApplyType>::value,
142+
int>::type = 0>
143+
inline void async_apply(const Map<KeyType, ValueType> &map, ApplyFunction &&f) {
144+
std::vector<std::future<void>> futures;
145+
for (const auto &pair : map) {
146+
futures.emplace_back(async_safe(f, pair.second));
147+
}
148+
for (auto &f : futures) {
149+
f.get();
150+
}
151+
}
152+
65153
} // namespace albatross
66154

67155
#endif /* INCLUDE_ALBATROSS_SRC_UTILS_ASYNC_UTILS_HPP_ */

tests/CMakeLists.txt

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,5 @@
11
add_executable(albatross_unit_tests
2-
test_apply.cc
3-
test_async_utils.cc
4-
test_block_utils.cc
5-
test_call_trace.cc
6-
test_callers.cc
7-
test_concatenate.cc
8-
test_core_dataset.cc
9-
test_core_distribution.cc
10-
test_core_model.cc
11-
test_covariance_function.cc
12-
test_covariance_functions.cc
13-
test_cross_validation.cc
14-
test_csv_utils.cc
15-
test_distance_metrics.cc
16-
test_eigen_utils.cc
17-
test_evaluate.cc
18-
test_gp.cc
192
test_group_by.cc
20-
test_indexing.cc
21-
test_linalg_utils.cc
22-
test_map_utils.cc
23-
test_model_adapter.cc
24-
test_model_metrics.cc
25-
test_models.cc
26-
test_parameter_handling_mixin.cc
27-
test_patchwork_gp.cc
28-
test_prediction.cc
29-
test_radial.cc
30-
test_random_utils.cc
31-
test_ransac.cc
32-
test_samplers.cc
33-
test_scaling_function.cc
34-
test_serializable_ldlt.cc
35-
test_serialize.cc
36-
test_sparse_gp.cc
37-
test_stats.cc
38-
test_traits_cereal.cc
39-
test_traits_core.cc
40-
test_traits_details.cc
41-
test_traits_covariance_functions.cc
42-
test_traits_evaluation.cc
43-
test_traits_indexing.cc
44-
test_tune.cc
45-
test_variant_utils.cc
463
)
474
target_include_directories(albatross_unit_tests SYSTEM PRIVATE
485
"${gtest_SOURCE_DIR}"

tests/test_group_by.cc

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ TYPED_TEST_P(GroupByTester, test_groupby_apply_combine) {
283283

284284
const auto combined = grouped.apply(only_keep_one).combine();
285285

286+
const auto async_combined = grouped.async_apply(only_keep_one).combine();
287+
288+
// Async an normal apply should match
289+
EXPECT_EQ(combined, async_combined);
290+
286291
// Same number of final combined elements as there are groups.
287292
EXPECT_EQ(grouped.size(), combined.size());
288293
}
@@ -296,7 +301,11 @@ TYPED_TEST_P(GroupByTester, test_groupby_apply_void) {
296301
const auto increment_count = [&](const auto &, const auto &) { ++count; };
297302

298303
grouped.apply(increment_count);
304+
EXPECT_EQ(grouped.size(), count);
299305

306+
// Again but with async
307+
count = 0;
308+
grouped.async_apply(increment_count);
300309
EXPECT_EQ(grouped.size(), count);
301310
}
302311

@@ -305,11 +314,14 @@ TYPED_TEST_P(GroupByTester, test_groupby_apply_value_only) {
305314
const auto grouped = group_by(parent, this->test_case.get_grouper());
306315

307316
std::size_t count = 0;
308-
309317
const auto increment_count = [&](const auto &) { ++count; };
310318

311319
grouped.apply(increment_count);
320+
EXPECT_EQ(grouped.size(), count);
312321

322+
// Again but with async
323+
count = 0;
324+
grouped.async_apply(increment_count);
313325
EXPECT_EQ(grouped.size(), count);
314326
}
315327

@@ -318,13 +330,16 @@ TYPED_TEST_P(GroupByTester, test_groupby_index_apply) {
318330
const auto grouped = group_by(parent, this->test_case.get_grouper());
319331

320332
std::size_t count = 0;
321-
322333
const auto increment_count = [&](const auto &, const GroupIndices &) {
323334
++count;
324335
};
325336

326337
grouped.index_apply(increment_count);
338+
EXPECT_EQ(grouped.size(), count);
327339

340+
// Again but with async
341+
count = 0;
342+
grouped.async_index_apply(increment_count);
328343
EXPECT_EQ(grouped.size(), count);
329344
}
330345

0 commit comments

Comments
 (0)