1+ #pragma once
2+
#include "recorders.h"

#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
7+
8+ namespace celerity ::detail {
9+ // in c++23 replace this with mdspan
10+ template <typename T>
11+ struct mpi_multidim_send_wrapper {
12+ public:
13+ const T& operator [](std::pair<int , int > ij) const {
14+ assert (ij.first * m_width + ij.second < m_data.size ());
15+ return m_data[ij.first * m_width + ij.second ];
16+ }
17+
18+ T* data () { return m_data.data (); }
19+
20+ mpi_multidim_send_wrapper (size_t width, size_t height) : m_data(width * height), m_width(width){};
21+
22+ private:
23+ std::vector<T> m_data;
24+ const size_t m_width;
25+ };
26+
27+ // Probably replace this in c++20 with span
28+ template <typename T>
29+ struct window {
30+ public:
31+ window (const std::vector<T>& value) : m_value(value) {}
32+
33+ const T& operator [](size_t i) const {
34+ assert (i >= 0 && i < m_width);
35+ return m_value[m_offset + i];
36+ }
37+
38+ size_t size () {
39+ m_width = m_value.size () - m_offset;
40+ return m_width;
41+ }
42+
43+ void slide (size_t i) {
44+ assert (i == 0 || (i >= 0 && i <= m_width));
45+ m_offset += i;
46+ m_width -= i;
47+ }
48+
49+ private:
50+ const std::vector<T>& m_value;
51+ size_t m_offset = 0 ;
52+ size_t m_width = 0 ;
53+ };
54+
55+ using task_hash = size_t ;
56+ using task_hash_data = mpi_multidim_send_wrapper<task_hash>;
57+ using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
58+
59+ class abstract_block_chain {
60+ friend struct abstract_block_chain_testspy ;
61+
62+ public:
63+ virtual void start () { m_is_running = true ; };
64+ virtual void stop () { m_is_running = false ; };
65+
66+ abstract_block_chain (const abstract_block_chain&) = delete ;
67+ abstract_block_chain& operator =(const abstract_block_chain&) = delete ;
68+ abstract_block_chain& operator =(abstract_block_chain&&) = delete ;
69+
70+ abstract_block_chain (abstract_block_chain&&) = default ;
71+ virtual ~abstract_block_chain () { stop (); }
72+
73+ abstract_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm)
74+ : m_local_nid(local_nid), m_num_nodes(num_nodes), m_sizes(num_nodes), m_task_recorder_window(task_recorder), m_comm(comm) {}
75+
76+ protected:
77+ virtual void run () = 0;
78+
79+ virtual void divergence_out (const divergence_map& check_map, const int task_num) = 0;
80+
81+ void add_new_hashes ();
82+ void clear (const int min_progress);
83+ virtual void allgather_sizes ();
84+ virtual void allgather_hashes (const int max_size, task_hash_data& data);
85+ std::pair<int , int > collect_sizes ();
86+ task_hash_data collect_hashes (const int max_size);
87+ divergence_map create_check_map (const task_hash_data& task_graphs, const int task_num) const ;
88+
89+ void check_for_deadlock () const ;
90+
91+ static void print_node_divergences (const divergence_map& check_map, const int task_num);
92+
93+ static void print_task_record (const divergence_map& check_map, const task_record& task, const task_hash hash);
94+
95+ virtual void dedub_print_task_record (const divergence_map& check_map, const int task_num) const ;
96+
97+ bool check_for_divergence ();
98+
99+ protected:
100+ node_id m_local_nid;
101+ size_t m_num_nodes;
102+
103+ std::vector<task_hash> m_hashes;
104+ std::vector<int > m_sizes;
105+
106+ bool m_is_running = true ;
107+
108+ window<task_record> m_task_recorder_window;
109+
110+ std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
111+
112+ MPI_Comm m_comm;
113+ };
114+
115+ class single_node_test_divergence_block_chain : public abstract_block_chain {
116+ public:
117+ single_node_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm,
118+ const std::vector<std::reference_wrapper<const std::vector<task_record>>>& other_task_records)
119+ : abstract_block_chain(num_nodes, local_nid, task_recorder, comm), m_other_hashes(other_task_records.size()) {
120+ for (auto & tsk_rcd : other_task_records) {
121+ m_other_task_records.push_back (window<task_record>(tsk_rcd));
122+ }
123+ }
124+
125+ private:
126+ void run () override {}
127+
128+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
129+ void allgather_sizes () override ;
130+ void allgather_hashes (const int max_size, task_hash_data& data) override ;
131+
132+ void dedub_print_task_record (const divergence_map& check_map, const int task_num) const override ;
133+
134+ std::vector<std::vector<task_hash>> m_other_hashes;
135+ std::vector<window<task_record>> m_other_task_records;
136+
137+ int m_injected_delete_size = 0 ;
138+ };
139+
140+ class distributed_test_divergence_block_chain : public abstract_block_chain {
141+ public:
142+ distributed_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
143+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {}
144+
145+ private:
146+ void run () override {}
147+
148+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
149+ };
150+
151+ class divergence_block_chain : public abstract_block_chain {
152+ public:
153+ void start () override ;
154+ void stop () override ;
155+
156+ divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
157+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {
158+ start ();
159+ }
160+
161+ private:
162+ void run () override ;
163+
164+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
165+
166+ private:
167+ std::thread m_thread;
168+ };
169+ } // namespace celerity::detail
0 commit comments