55
66#include " command_graph_generator_test_utils.h"
77#include " instruction_graph_test_utils.h"
8+ #include " task_graph_test_utils.h"
89#include " test_utils.h"
910
1011using namespace celerity ;
@@ -14,24 +15,18 @@ using namespace celerity::test_utils;
1415namespace acc = celerity::access;
1516
1617TEST_CASE (" task-graph printing is unchanged" , " [print_graph][task-graph]" ) {
17- auto tt = test_utils::task_test_context{} ;
18+ tdag_test_context tctx ( 1 /* num_collective_nodes */ ) ;
1819
19- auto range = celerity::range<1 >(64 );
20- auto buf_0 = tt. mbf .create_buffer (range);
21- auto buf_1 = tt. mbf .create_buffer (celerity::range<1 >(1 ));
20+ const auto range = celerity::range<1 >(64 );
21+ auto buf_0 = tctx .create_buffer (range);
22+ auto buf_1 = tctx .create_buffer (celerity::range<1 >(1 ));
2223
2324 // graph copied from graph_gen_reduction_tests "command_graph_generator generates reduction command trees"
2425
25- test_utils::add_compute_task (tt.tm , [&](handler& cgh) { buf_1.get_access <access_mode::discard_write>(cgh, acc::one_to_one{}); }, range);
26- test_utils::add_compute_task (tt.tm , [&](handler& cgh) { buf_0.get_access <access_mode::discard_write>(cgh, acc::one_to_one{}); }, range);
27- test_utils::add_compute_task (
28- tt.tm ,
29- [&](handler& cgh) {
30- buf_0.get_access <access_mode::read>(cgh, acc::one_to_one{});
31- test_utils::add_reduction (cgh, tt.mrf , buf_1, true /* include_current_buffer_value */ );
32- },
33- range);
34- test_utils::add_compute_task (tt.tm , [&](handler& cgh) { buf_1.get_access <access_mode::read>(cgh, acc::fixed<1 >({0 , 1 })); }, range);
26+ tctx.device_compute (range).discard_write (buf_1, acc::one_to_one{}).submit ();
27+ tctx.device_compute (range).discard_write (buf_0, acc::one_to_one{}).submit ();
28+ tctx.device_compute (range).read (buf_0, acc::one_to_one{}).reduce (buf_1, true /* include_current_buffer_value */ ).submit ();
29+ tctx.device_compute (range).read (buf_1, acc::fixed<1 >({0 , 1 })).submit ();
3530
3631 // Smoke test: It is valid for the dot output to change with updates to graph generation. If this test fails, verify that the printed graph is sane and
3732 // replace the `expected` value with the new dot graph.
@@ -42,7 +37,7 @@ TEST_CASE("task-graph printing is unchanged", "[print_graph][task-graph]") {
4237 " <i>read_write</i> B1 {[0,0,0] - [1,1,1]}<br/><i>read</i> B0 {[0,0,0] - [64,1,1]}>];4[shape=box style=rounded label=<T4<br/><b>device-compute</b> "
4338 " [0,0,0] + [64,1,1]<br/><i>read</i> B1 {[0,0,0] - [1,1,1]}>];0->1[color=orchid];0->2[color=orchid];1->3[];2->3[];3->4[];}" ;
4439
45- const auto dot = print_task_graph (tt. trec );
40+ const auto dot = print_task_graph (tctx. get_task_recorder () );
4641 CHECK (dot == expected);
4742 if (dot != expected) { fmt::print (" \n {}:\n\n got:\n\n {}\n\n expected:\n\n {}\n\n " , Catch::getResultCapture ().getCurrentTestName (), dot, expected); }
4843}
@@ -358,13 +353,13 @@ template <int X>
358353class name_class {};
359354
360355TEST_CASE (" task-graph names are escaped" , " [print_graph][task-graph][task-name]" ) {
361- auto tt = test_utils::task_test_context{} ;
356+ test_utils::tdag_test_context tctx ( 1 /* num_collective_nodes */ ) ;
362357
363- auto range = celerity::range<1 >(64 );
364- auto buf = tt. mbf .create_buffer (range);
358+ const auto range = celerity::range<1 >(64 );
359+ auto buf = tctx .create_buffer (range);
365360
366- test_utils::add_compute_task <name_class<5 >>(tt. tm , [&](handler& cgh) { buf. get_access <access_mode:: discard_write>(cgh , acc::one_to_one{}); }, range );
361+ tctx. device_compute <name_class<5 >>(range). discard_write (buf , acc::one_to_one{}). submit ( );
367362
368363 const auto * escaped_name = " \" name_class<...>\" " ;
369- REQUIRE_THAT (print_task_graph (tt. trec ), Catch::Matchers::ContainsSubstring (escaped_name));
364+ REQUIRE_THAT (print_task_graph (tctx. get_task_recorder () ), Catch::Matchers::ContainsSubstring (escaped_name));
370365}
0 commit comments