Skip to content

Commit f220384

Browse files
allnesaobolensk
andauthored
Increase coverage (learning-process#498)
Co-authored-by: Arseniy Obolenskiy <gooddoog@student.su>
1 parent de81c9f commit f220384

File tree

8 files changed

+374
-284
lines changed

8 files changed

+374
-284
lines changed

modules/core/performance/tests/perf_tests.cpp

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class GetStringTaskTypeTest : public ::testing::TestWithParam<TaskTypeTestCase>
141141
(*j)["tasks"]["tbb"] = "TBB";
142142
(*j)["tasks"]["seq"] = "SEQ";
143143

144-
std::ofstream(temp_path) << (*j).dump();
144+
std::ofstream(temp_path) << j->dump();
145145
}
146146

147147
void TearDown() override { std::filesystem::remove(temp_path); }
@@ -238,13 +238,8 @@ TEST(TaskTest, GetDynamicTypeReturnsCorrectEnum) {
238238
}
239239

240240
TEST(TaskTest, DestructorTerminatesIfWrongOrder) {
241-
testing::FLAGS_gtest_death_test_style = "threadsafe";
242-
ASSERT_DEATH_IF_SUPPORTED(
243-
{
244-
DummyTask task;
245-
task.Run();
246-
},
247-
"");
241+
DummyTask task;
242+
EXPECT_THROW(task.Run(), std::runtime_error);
248243
}
249244

250245
namespace my {
@@ -263,15 +258,69 @@ using TestTypes = ::testing::Types<my::nested::Type, my::Another, int>;
263258
TYPED_TEST_SUITE(GetNamespaceTest, TestTypes);
264259

265260
TYPED_TEST(GetNamespaceTest, ExtractsNamespaceCorrectly) {
266-
constexpr std::string_view kNs = ppc::util::GetNamespace<TypeParam>();
261+
std::string k_ns = ppc::util::GetNamespace<TypeParam>();
267262

268263
if constexpr (std::is_same_v<TypeParam, my::nested::Type>) {
269-
EXPECT_EQ(kNs, "my::nested");
264+
EXPECT_EQ(k_ns, "my::nested");
270265
} else if constexpr (std::is_same_v<TypeParam, my::Another>) {
271-
EXPECT_EQ(kNs, "my");
266+
EXPECT_EQ(k_ns, "my");
272267
} else if constexpr (std::is_same_v<TypeParam, int>) {
273-
EXPECT_EQ(kNs, "");
268+
EXPECT_EQ(k_ns, "");
274269
} else {
275270
FAIL() << "Unhandled type in test";
276271
}
277272
}
273+
274+
TEST(PerfTest, PipelineRunAndTaskRun) {
275+
auto task_ptr = std::make_shared<DummyTask>();
276+
ppc::core::Perf<int, int> perf(task_ptr);
277+
278+
ppc::core::PerfAttr attr;
279+
double time = 0.0;
280+
attr.num_running = 2;
281+
attr.current_timer = [&time]() {
282+
double t = time;
283+
time += 1.0;
284+
return t;
285+
};
286+
287+
EXPECT_NO_THROW(perf.PipelineRun(attr));
288+
auto res_pipeline = perf.GetPerfResults();
289+
EXPECT_EQ(res_pipeline.type_of_running, ppc::core::PerfResults::kPipeline);
290+
EXPECT_GT(res_pipeline.time_sec, 0.0);
291+
292+
EXPECT_NO_THROW(perf.TaskRun(attr));
293+
auto res_taskrun = perf.GetPerfResults();
294+
EXPECT_EQ(res_taskrun.type_of_running, ppc::core::PerfResults::kTaskRun);
295+
EXPECT_GT(res_taskrun.time_sec, 0.0);
296+
}
297+
298+
TEST(PerfTest, PrintPerfStatisticThrowsOnNone) {
299+
{
300+
auto task_ptr = std::make_shared<DummyTask>();
301+
ppc::core::Perf<int, int> perf(task_ptr);
302+
EXPECT_THROW(perf.PrintPerfStatistic("test"), std::runtime_error);
303+
}
304+
EXPECT_TRUE(ppc::util::DestructorFailureFlag::Get());
305+
ppc::util::DestructorFailureFlag::Unset();
306+
}
307+
308+
TEST(PerfTest, GetStringParamNameTest) {
309+
EXPECT_EQ(GetStringParamName(ppc::core::PerfResults::kTaskRun), "task_run");
310+
EXPECT_EQ(GetStringParamName(ppc::core::PerfResults::kPipeline), "pipeline");
311+
EXPECT_EQ(GetStringParamName(ppc::core::PerfResults::kNone), "none");
312+
}
313+
314+
TEST(TaskTest, Destructor_InvalidPipelineOrderTerminates_PartialPipeline) {
315+
{
316+
struct BadTask : ppc::core::Task<int, int> {
317+
bool ValidationImpl() override { return true; }
318+
bool PreProcessingImpl() override { return true; }
319+
bool RunImpl() override { return true; }
320+
bool PostProcessingImpl() override { return true; }
321+
} task;
322+
task.Validation();
323+
}
324+
EXPECT_TRUE(ppc::util::DestructorFailureFlag::Get());
325+
ppc::util::DestructorFailureFlag::Unset();
326+
}

modules/core/runners/include/runners.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,10 @@ class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener {
4343
/// finalization fails.
4444
int Init(int argc, char** argv);
4545

46+
/// @brief Initializes the testing environment only for gtest.
47+
/// @param argc Argument count.
48+
/// @param argv Argument vector.
49+
/// @return Exit code from RUN_ALL_TESTS.
50+
int SimpleInit(int argc, char** argv);
51+
4652
} // namespace ppc::core

modules/core/runners/src/runners.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <format>
88
#include <iostream>
99
#include <memory>
10+
#include <stdexcept>
1011
#include <string>
1112

1213
#include "core/util/include/util.hpp"
@@ -82,6 +83,11 @@ int Init(int argc, char** argv) {
8283
listeners.Append(new ppc::core::UnreadMessagesDetector());
8384
auto status = RUN_ALL_TESTS();
8485

86+
if (ppc::util::DestructorFailureFlag::Get()) {
87+
throw std::runtime_error(
88+
std::format("[ ERROR ] Destructor failed with code {}", ppc::util::DestructorFailureFlag::Get()));
89+
}
90+
8591
const int finalize_res = MPI_Finalize();
8692
if (finalize_res != MPI_SUCCESS) {
8793
std::cerr << std::format("[ ERROR ] MPI_Finalize failed with code {}", finalize_res) << '\n';
@@ -91,4 +97,17 @@ int Init(int argc, char** argv) {
9197
return status;
9298
}
9399

100+
int SimpleInit(int argc, char** argv) {
101+
// Limit the number of threads in TBB
102+
tbb::global_control control(tbb::global_control::max_allowed_parallelism, ppc::util::GetNumThreads());
103+
104+
testing::InitGoogleTest(&argc, argv);
105+
auto status = RUN_ALL_TESTS();
106+
if (ppc::util::DestructorFailureFlag::Get()) {
107+
throw std::runtime_error(
108+
std::format("[ ERROR ] Destructor failed with code {}", ppc::util::DestructorFailureFlag::Get()));
109+
}
110+
return status;
111+
}
112+
94113
} // namespace ppc::core

modules/core/task/include/task.hpp

Lines changed: 67 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <omp.h>
44

55
#include <algorithm>
6+
#include <array>
67
#include <chrono>
78
#include <core/util/include/util.hpp>
89
#include <cstdint>
@@ -16,6 +17,7 @@
1617
#include <sstream>
1718
#include <stdexcept>
1819
#include <string>
20+
#include <utility>
1921
#include <vector>
2022

2123
namespace ppc::core {
@@ -39,6 +41,25 @@ enum TypeOfTask : uint8_t {
3941
kUnknown
4042
};
4143

44+
using TaskMapping = std::pair<TypeOfTask, std::string>;
45+
using TaskMappingArray = std::array<TaskMapping, 6>;
46+
47+
const TaskMappingArray kTaskTypeMappings = {{{TypeOfTask::kALL, "all"},
48+
{TypeOfTask::kMPI, "mpi"},
49+
{TypeOfTask::kOMP, "omp"},
50+
{TypeOfTask::kSEQ, "seq"},
51+
{TypeOfTask::kSTL, "stl"},
52+
{TypeOfTask::kTBB, "tbb"}}};
53+
54+
inline std::string TypeOfTaskToString(TypeOfTask type) {
55+
for (const auto &[key, value] : kTaskTypeMappings) {
56+
if (key == type) {
57+
return value;
58+
}
59+
}
60+
return "unknown";
61+
}
62+
4263
/// @brief Indicates whether a task is enabled or disabled.
4364
enum StatusOfTask : uint8_t {
4465
/// Task is enabled and should be executed
@@ -71,29 +92,12 @@ inline std::string GetStringTaskType(TypeOfTask type_of_task, const std::string
7192
auto list_settings = ppc::util::InitJSONPtr();
7293
file >> *list_settings;
7394

74-
auto to_type_str = [&](const std::string &type) -> std::string {
75-
return type + "_" + std::string((*list_settings)["tasks"][type]);
76-
};
77-
78-
if (type_of_task == TypeOfTask::kALL) {
79-
return to_type_str("all");
80-
}
81-
if (type_of_task == TypeOfTask::kSTL) {
82-
return to_type_str("stl");
95+
std::string type_str = TypeOfTaskToString(type_of_task);
96+
if (type_str == "unknown") {
97+
return type_str;
8398
}
84-
if (type_of_task == TypeOfTask::kOMP) {
85-
return to_type_str("omp");
86-
}
87-
if (type_of_task == TypeOfTask::kMPI) {
88-
return to_type_str("mpi");
89-
}
90-
if (type_of_task == TypeOfTask::kTBB) {
91-
return to_type_str("tbb");
92-
}
93-
if (type_of_task == TypeOfTask::kSEQ) {
94-
return to_type_str("seq");
95-
}
96-
return "unknown";
99+
100+
return type_str + "_" + std::string((*list_settings)["tasks"][type_str]);
97101
}
98102

99103
enum StateOfTesting : uint8_t { kFunc, kPerf };
@@ -104,39 +108,56 @@ template <typename InType, typename OutType>
104108
/// @tparam OutType Output data type.
105109
class Task {
106110
public:
107-
/// @brief Constructs a new Task object.
108-
explicit Task(StateOfTesting /*state_of_testing*/ = StateOfTesting::kFunc) { functions_order_.clear(); }
109-
110111
/// @brief Validates input data and task attributes before execution.
111112
/// @return True if validation is successful.
112113
virtual bool Validation() final {
113-
InternalOrderTest(ppc::util::FuncName());
114+
if (stage_ == PipelineStage::kNone || stage_ == PipelineStage::kDone) {
115+
stage_ = PipelineStage::kValidation;
116+
} else {
117+
stage_ = PipelineStage::kException;
118+
throw std::runtime_error("Validation should be called before preprocessing");
119+
}
114120
return ValidationImpl();
115121
}
116122

117123
/// @brief Performs preprocessing on the input data.
118124
/// @return True if preprocessing is successful.
119125
virtual bool PreProcessing() final {
120-
InternalOrderTest(ppc::util::FuncName());
126+
if (stage_ == PipelineStage::kValidation) {
127+
stage_ = PipelineStage::kPreProcessing;
128+
} else {
129+
stage_ = PipelineStage::kException;
130+
throw std::runtime_error("Preprocessing should be called after validation");
131+
}
121132
if (state_of_testing_ == StateOfTesting::kFunc) {
122-
InternalTimeTest(ppc::util::FuncName());
133+
InternalTimeTest();
123134
}
124135
return PreProcessingImpl();
125136
}
126137

127138
/// @brief Executes the main logic of the task.
128139
/// @return True if execution is successful.
129140
virtual bool Run() final {
130-
InternalOrderTest(ppc::util::FuncName());
141+
if (stage_ == PipelineStage::kPreProcessing || stage_ == PipelineStage::kRun) {
142+
stage_ = PipelineStage::kRun;
143+
} else {
144+
stage_ = PipelineStage::kException;
145+
throw std::runtime_error("Run should be called after preprocessing");
146+
}
131147
return RunImpl();
132148
}
133149

134150
/// @brief Performs postprocessing on the output data.
135151
/// @return True if postprocessing is successful.
136152
virtual bool PostProcessing() final {
137-
InternalOrderTest(ppc::util::FuncName());
153+
if (stage_ == PipelineStage::kRun) {
154+
stage_ = PipelineStage::kDone;
155+
} else {
156+
stage_ = PipelineStage::kException;
157+
throw std::runtime_error("Postprocessing should be called after run");
158+
}
138159
if (state_of_testing_ == StateOfTesting::kFunc) {
139-
InternalTimeTest(ppc::util::FuncName());
160+
InternalTimeTest();
140161
}
141162
return PostProcessingImpl();
142163
}
@@ -170,41 +191,25 @@ class Task {
170191
OutType &GetOutput() { return output_; }
171192

172193
/// @brief Destructor. Verifies that the pipeline was executed in the correct order.
173-
/// @note Terminates the program if pipeline order is incorrect or incomplete.
194+
/// @note Terminates the program if the pipeline order is incorrect or incomplete.
174195
virtual ~Task() {
175-
if (!functions_order_.empty() || !was_worked_) {
176-
std::cerr << "ORDER OF FUNCTIONS IS NOT RIGHT! \n Expected - \"Validation\", \"PreProcessing\", \"Run\", "
177-
"\"PostProcessing\" \n";
178-
std::terminate();
179-
} else {
180-
functions_order_.clear();
196+
if (stage_ != PipelineStage::kDone && stage_ != PipelineStage::kException) {
197+
ppc::util::DestructorFailureFlag::Set();
181198
}
182199
#if _OPENMP >= 201811
183200
omp_pause_resource_all(omp_pause_soft);
184201
#endif
185202
}
186203

187204
protected:
188-
/// @brief Verifies the correct order of pipeline method calls.
189-
/// @param str Name of the method being called.
190-
virtual void InternalOrderTest(const std::string &str) final {
191-
functions_order_.push_back(str);
192-
if (str == "PostProcessing" && IsFullPipelineStage()) {
193-
functions_order_.clear();
194-
} else {
195-
was_worked_ = true;
196-
}
197-
}
198-
199205
/// @brief Measures execution time between preprocessing and postprocessing steps.
200-
/// @param str Name of the method being timed.
201206
/// @throws std::runtime_error If execution exceeds the allowed time limit.
202-
virtual void InternalTimeTest(const std::string &str) final {
203-
if (str == "PreProcessing") {
207+
virtual void InternalTimeTest() final {
208+
if (stage_ == PipelineStage::kPreProcessing) {
204209
tmp_time_point_ = std::chrono::high_resolution_clock::now();
205210
}
206211

207-
if (str == "PostProcessing") {
212+
if (stage_ == PipelineStage::kDone) {
208213
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::high_resolution_clock::now() -
209214
tmp_time_point_)
210215
.count();
@@ -244,26 +249,16 @@ class Task {
244249
StateOfTesting state_of_testing_ = kFunc;
245250
TypeOfTask type_of_task_ = kUnknown;
246251
StatusOfTask status_of_task_ = kEnabled;
247-
std::vector<std::string> functions_order_;
248-
std::vector<std::string> right_functions_order_ = {"Validation", "PreProcessing", "Run", "PostProcessing"};
249252
static constexpr double kMaxTestTime = 1.0;
250253
std::chrono::high_resolution_clock::time_point tmp_time_point_;
251-
bool was_worked_ = false;
252-
253-
bool IsFullPipelineStage() {
254-
if (functions_order_.size() < 4) {
255-
return false;
256-
}
257-
258-
auto it = std::adjacent_find(functions_order_.begin() + 2,
259-
functions_order_.begin() + static_cast<long>(functions_order_.size() - 2),
260-
std::not_equal_to<>());
261-
262-
return (functions_order_[0] == "Validation" && functions_order_[1] == "PreProcessing" &&
263-
functions_order_[2] == "Run" &&
264-
it == (functions_order_.begin() + static_cast<long>(functions_order_.size() - 2)) &&
265-
functions_order_[functions_order_.size() - 1] == "PostProcessing");
266-
}
254+
enum class PipelineStage : uint8_t {
255+
kNone,
256+
kValidation,
257+
kPreProcessing,
258+
kRun,
259+
kDone,
260+
kException
261+
} stage_ = PipelineStage::kNone;
267262
};
268263

269264
/// @brief Smart pointer alias for Task.
@@ -276,7 +271,7 @@ using TaskPtr = std::shared_ptr<Task<InType, OutType>>;
276271
/// @tparam TaskType Type of the task to create.
277272
/// @tparam InType Type of the input.
278273
/// @param in Input to pass to the task constructor.
279-
/// @return Shared pointer to the newly created task.
274+
/// @return Shared a pointer to the newly created task.
280275
template <typename TaskType, typename InType>
281276
std::shared_ptr<TaskType> TaskGetter(InType in) {
282277
return std::make_shared<TaskType>(in);

0 commit comments

Comments
 (0)