Skip to content

Commit 2230e24

Browse files
psalzGagaLP
authored andcommitted
Convert all task graph tests to builder pattern + query
1 parent df60b9b commit 2230e24

File tree

7 files changed

+351
-604
lines changed

7 files changed

+351
-604
lines changed

test/accessor_tests.cc

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <celerity.h>
88

9+
#include "task_graph_test_utils.h"
910
#include "test_utils.h"
1011

1112
namespace celerity {
@@ -230,34 +231,15 @@ namespace detail {
230231
}
231232

232233
TEST_CASE("conflicts between producer-accessors and reductions are reported", "[task-manager]") {
233-
test_utils::task_test_context tt;
234-
235-
auto buf_0 = tt.mbf.create_buffer(range<1>{1});
236-
237-
CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_conflict)>(tt.tm, [&](handler& cgh) {
238-
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
239-
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
240-
}));
241-
242-
CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
243-
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
244-
buf_0.get_access<access_mode::read>(cgh, fixed<1>({0, 1}));
245-
}));
246-
247-
CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
248-
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
249-
buf_0.get_access<access_mode::write>(cgh, fixed<1>({0, 1}));
250-
}));
251-
252-
CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
253-
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
254-
buf_0.get_access<access_mode::read_write>(cgh, fixed<1>({0, 1}));
255-
}));
256-
257-
CHECK_THROWS(test_utils::add_compute_task<class UKN(task_reduction_access_conflict)>(tt.tm, [&](handler& cgh) {
258-
test_utils::add_reduction(cgh, tt.mrf, buf_0, false);
259-
buf_0.get_access<access_mode::discard_write>(cgh, fixed<1>({0, 1}));
260-
}));
234+
test_utils::tdag_test_context tctx(1 /* num_collective_nodes */);
235+
236+
auto buf_0 = tctx.create_buffer(range<1>{1});
237+
238+
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).reduce(buf_0, false).submit());
239+
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).read(buf_0, all{}).submit());
240+
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).write(buf_0, all{}).submit());
241+
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).read_write(buf_0, all{}).submit());
242+
CHECK_THROWS(tctx.device_compute(range<1>{ones}).reduce(buf_0, false).discard_write(buf_0, all{}).submit());
261243
}
262244

263245
template <access_mode>

test/debug_naming_tests.cc

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <celerity.h>
99

10+
#include "task_graph_test_utils.h"
1011
#include "test_utils.h"
1112

1213
using namespace celerity;
@@ -15,34 +16,30 @@ using namespace celerity::detail;
1516
TEST_CASE("debug names can be set and retrieved from tasks", "[debug]") {
1617
const std::string task_name = "sample task";
1718

18-
auto tt = test_utils::task_test_context{};
19+
test_utils::tdag_test_context tctx(1 /* num_collective_nodes */);
1920

2021
SECTION("Host Task") {
21-
const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { celerity::debug::set_task_name(cgh, task_name); });
22+
const auto tid_a = tctx.master_node_host_task().name(task_name).submit();
23+
const auto tid_b = tctx.master_node_host_task().submit();
2224

23-
const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {});
24-
25-
CHECK(test_utils::get_task(tt.tdag, tid_a)->get_debug_name() == task_name);
26-
CHECK(test_utils::get_task(tt.tdag, tid_b)->get_debug_name().empty());
25+
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_a)->get_debug_name() == task_name);
26+
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_b)->get_debug_name().empty());
2727
}
2828

2929
SECTION("Compute Task") {
30-
const auto tid_a = test_utils::add_compute_task<class compute_task>(tt.tm, [&](handler& cgh) { celerity::debug::set_task_name(cgh, task_name); });
31-
32-
const auto tid_b = test_utils::add_compute_task<class compute_task_unnamed>(tt.tm, [&](handler& cgh) {});
30+
const auto tid_a = tctx.device_compute(range<1>(ones)).name(task_name).submit();
31+
const auto tid_b = tctx.device_compute<class compute_task_unnamed>(range<1>(ones)).submit();
3332

34-
CHECK(test_utils::get_task(tt.tdag, tid_a)->get_debug_name() == task_name);
35-
CHECK_THAT(test_utils::get_task(tt.tdag, tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("compute_task_unnamed"));
33+
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_a)->get_debug_name() == task_name);
34+
CHECK_THAT(test_utils::get_task(tctx.get_task_graph(), tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("compute_task_unnamed"));
3635
}
3736

3837
SECTION("ND Range Task") {
39-
const auto tid_a =
40-
test_utils::add_nd_range_compute_task<class nd_range_task>(tt.tm, [&](handler& cgh) { celerity::debug::set_task_name(cgh, task_name); });
41-
42-
const auto tid_b = test_utils::add_compute_task<class nd_range_task_unnamed>(tt.tm, [&](handler& cgh) {});
38+
const auto tid_a = tctx.device_compute(nd_range<1>{range<1>{1}, range<1>{1}}).name(task_name).submit();
39+
const auto tid_b = tctx.device_compute<class nd_range_task_unnamed>(nd_range<1>{range<1>{1}, range<1>{1}}).submit();
4340

44-
CHECK(test_utils::get_task(tt.tdag, tid_a)->get_debug_name() == task_name);
45-
CHECK_THAT(test_utils::get_task(tt.tdag, tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("nd_range_task_unnamed"));
41+
CHECK(test_utils::get_task(tctx.get_task_graph(), tid_a)->get_debug_name() == task_name);
42+
CHECK_THAT(test_utils::get_task(tctx.get_task_graph(), tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("nd_range_task_unnamed"));
4643
}
4744
}
4845

test/graph_test_utils.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class task_builder {
2828
class step {
2929
public:
3030
step(TestContext& tctx, action command, std::vector<action> requirements = {})
31-
: m_tctx(tctx), m_command(std::move(command)), m_requirements(std::move(requirements)), m_uncaught_exceptions_before(std::uncaught_exceptions()) {}
31+
: m_tctx(&tctx), m_command(std::move(command)), m_requirements(std::move(requirements)), m_uncaught_exceptions_before(std::uncaught_exceptions()) {}
3232

3333
~step() noexcept(false) { // NOLINT(bugprone-exception-escape)
3434
if(std::uncaught_exceptions() == m_uncaught_exceptions_before && (m_command || !m_requirements.empty())) {
@@ -37,13 +37,13 @@ class task_builder {
3737
}
3838

3939
step(const step&) = delete;
40-
step(step&&) = delete;
40+
step(step&&) = default;
4141
step& operator=(const step&) = delete;
42-
step& operator=(step&&) = delete;
42+
step& operator=(step&&) = default;
4343

4444
task_id submit() {
4545
assert(m_command);
46-
const auto tid = m_tctx.submit_command_group([this](handler& cgh) {
46+
const auto tid = m_tctx->submit_command_group([this](handler& cgh) {
4747
for(auto& a : m_requirements) {
4848
a(cgh);
4949
}
@@ -78,10 +78,15 @@ class task_builder {
7878
return chain<step>([&buf, rmfn](handler& cgh) { buf.template get_access<access_mode::discard_write>(cgh, rmfn); });
7979
}
8080

81+
template <typename BufferT, typename RangeMapper>
82+
step discard_read_write(BufferT& buf, RangeMapper rmfn) {
83+
return chain<step>([&buf, rmfn](handler& cgh) { buf.template get_access<access_mode::discard_read_write>(cgh, rmfn); });
84+
}
85+
8186
template <typename BufferT>
8287
inline step reduce(BufferT& buf, const bool include_current_buffer_value) {
8388
return chain<step>([this, &buf, include_current_buffer_value](
84-
handler& cgh) { add_reduction(cgh, m_tctx.create_reduction(buf.get_id(), include_current_buffer_value)); });
89+
handler& cgh) { add_reduction(cgh, m_tctx->create_reduction(buf.get_id(), include_current_buffer_value)); });
8590
}
8691

8792
template <typename HostObjT>
@@ -107,7 +112,7 @@ class task_builder {
107112
}
108113

109114
private:
110-
TestContext& m_tctx;
115+
TestContext* m_tctx;
111116
action m_command;
112117
std::vector<action> m_requirements;
113118
int m_uncaught_exceptions_before;
@@ -121,7 +126,7 @@ class task_builder {
121126
auto command = std::move(m_command);
122127
m_requirements = {};
123128
m_command = {};
124-
return StepT{m_tctx, std::move(command), std::move(requirements)};
129+
return StepT{*m_tctx, std::move(command), std::move(requirements)};
125130
}
126131
};
127132

test/print_graph_tests.cc

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "command_graph_generator_test_utils.h"
77
#include "instruction_graph_test_utils.h"
8+
#include "task_graph_test_utils.h"
89
#include "test_utils.h"
910

1011
using namespace celerity;
@@ -14,24 +15,18 @@ using namespace celerity::test_utils;
1415
namespace acc = celerity::access;
1516

1617
TEST_CASE("task-graph printing is unchanged", "[print_graph][task-graph]") {
17-
auto tt = test_utils::task_test_context{};
18+
tdag_test_context tctx(1 /* num_collective_nodes */);
1819

19-
auto range = celerity::range<1>(64);
20-
auto buf_0 = tt.mbf.create_buffer(range);
21-
auto buf_1 = tt.mbf.create_buffer(celerity::range<1>(1));
20+
const auto range = celerity::range<1>(64);
21+
auto buf_0 = tctx.create_buffer(range);
22+
auto buf_1 = tctx.create_buffer(celerity::range<1>(1));
2223

2324
// graph copied from graph_gen_reduction_tests "command_graph_generator generates reduction command trees"
2425

25-
test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_1.get_access<access_mode::discard_write>(cgh, acc::one_to_one{}); }, range);
26-
test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_0.get_access<access_mode::discard_write>(cgh, acc::one_to_one{}); }, range);
27-
test_utils::add_compute_task(
28-
tt.tm,
29-
[&](handler& cgh) {
30-
buf_0.get_access<access_mode::read>(cgh, acc::one_to_one{});
31-
test_utils::add_reduction(cgh, tt.mrf, buf_1, true /* include_current_buffer_value */);
32-
},
33-
range);
34-
test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_1.get_access<access_mode::read>(cgh, acc::fixed<1>({0, 1})); }, range);
26+
tctx.device_compute(range).discard_write(buf_1, acc::one_to_one{}).submit();
27+
tctx.device_compute(range).discard_write(buf_0, acc::one_to_one{}).submit();
28+
tctx.device_compute(range).read(buf_0, acc::one_to_one{}).reduce(buf_1, true /* include_current_buffer_value */).submit();
29+
tctx.device_compute(range).read(buf_1, acc::fixed<1>({0, 1})).submit();
3530

3631
// Smoke test: It is valid for the dot output to change with updates to graph generation. If this test fails, verify that the printed graph is sane and
3732
// replace the `expected` value with the new dot graph.
@@ -42,7 +37,7 @@ TEST_CASE("task-graph printing is unchanged", "[print_graph][task-graph]") {
4237
"<i>read_write</i> B1 {[0,0,0] - [1,1,1]}<br/><i>read</i> B0 {[0,0,0] - [64,1,1]}>];4[shape=box style=rounded label=<T4<br/><b>device-compute</b> "
4338
"[0,0,0] + [64,1,1]<br/><i>read</i> B1 {[0,0,0] - [1,1,1]}>];0->1[color=orchid];0->2[color=orchid];1->3[];2->3[];3->4[];}";
4439

45-
const auto dot = print_task_graph(tt.trec);
40+
const auto dot = print_task_graph(tctx.get_task_recorder());
4641
CHECK(dot == expected);
4742
if(dot != expected) { fmt::print("\n{}:\n\ngot:\n\n{}\n\nexpected:\n\n{}\n\n", Catch::getResultCapture().getCurrentTestName(), dot, expected); }
4843
}
@@ -358,13 +353,13 @@ template <int X>
358353
class name_class {};
359354

360355
TEST_CASE("task-graph names are escaped", "[print_graph][task-graph][task-name]") {
361-
auto tt = test_utils::task_test_context{};
356+
test_utils::tdag_test_context tctx(1 /* num_collective_nodes */);
362357

363-
auto range = celerity::range<1>(64);
364-
auto buf = tt.mbf.create_buffer(range);
358+
const auto range = celerity::range<1>(64);
359+
auto buf = tctx.create_buffer(range);
365360

366-
test_utils::add_compute_task<name_class<5>>(tt.tm, [&](handler& cgh) { buf.get_access<access_mode::discard_write>(cgh, acc::one_to_one{}); }, range);
361+
tctx.device_compute<name_class<5>>(range).discard_write(buf, acc::one_to_one{}).submit();
367362

368363
const auto* escaped_name = "\"name_class&lt;...&gt;\"";
369-
REQUIRE_THAT(print_task_graph(tt.trec), Catch::Matchers::ContainsSubstring(escaped_name));
364+
REQUIRE_THAT(print_task_graph(tctx.get_task_recorder()), Catch::Matchers::ContainsSubstring(escaped_name));
370365
}

0 commit comments

Comments
 (0)