Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gloo/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(GLOO_COMMON_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/logging.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/utils.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
)

set(GLOO_COMMON_HDRS
Expand Down
46 changes: 46 additions & 0 deletions gloo/common/error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <atomic>
#include <list>

#include "gloo/common/error.h"

namespace gloo {


// File-local state backing the abort mechanism. It lives in an anonymous
// namespace so the symbols get internal linkage: other translation units are
// meant to go through the functions declared in error.h, not these variables.
namespace {

// Condition variables that abort() notifies. Non-owning pointers: entries are
// added by _register_cv() and removed by _deregister_cv().
std::list<std::condition_variable*> _cvs;

// Guards every access to _cvs.
std::mutex _cvs_mutex;

// Becomes true once abort() has been called; nothing in this file resets it.
std::atomic_bool _is_aborted_flag(false);

} // namespace

// Reports whether abort() has been called.
// Returns the current value of the abort flag (an atomic load).
bool _is_aborted() {
  const bool aborted = _is_aborted_flag.load();
  return aborted;
}

// Requests a global abort: raises the abort flag, wakes every registered
// condition variable, and finally reports the abort through GLOO_THROW
// (a macro defined outside this file; presumably it throws).
//
// The flag is stored before notifying so that a woken waiter whose predicate
// calls _is_aborted() sees the abort. The registry mutex is held for the
// whole loop so _register_cv()/_deregister_cv() cannot modify the list while
// it is being iterated; the guard is released when this scope is left.
//
// NOTE(review): the notification is sent without holding the mutex that each
// waiter pairs with its condition variable. A waiter that evaluated its
// predicate just before the store but has not blocked yet could miss this
// wakeup and only observe the abort on its next wakeup or timeout -- confirm
// whether that window is acceptable.
void abort() {
  _is_aborted_flag.store(true);
  std::lock_guard<std::mutex> guard(_cvs_mutex);
  for (auto* cv : _cvs) {
    // _register_cv() does not reject null pointers, so skip them here.
    if (cv != nullptr) {
      cv->notify_all();
    }
  }
  GLOO_THROW("GLOO ABORTED");
}

// Adds `cv` to the list of condition variables that abort() notifies.
// The pointer is stored as-is and is not owned by the registry; it stays in
// the list until _deregister_cv() is called with the same pointer.
void _register_cv(std::condition_variable *cv) {
  const std::lock_guard<std::mutex> lock(_cvs_mutex);
  _cvs.insert(_cvs.end(), cv);
}

// Removes every occurrence of `cv` from the list that abort() notifies.
// Does nothing if the pointer was never registered.
void _deregister_cv(std::condition_variable *cv) {
  const std::lock_guard<std::mutex> lock(_cvs_mutex);
  _cvs.remove_if(
      [cv](const std::condition_variable* entry) { return entry == cv; });
}
} // namespace gloo
6 changes: 6 additions & 0 deletions gloo/common/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <chrono>
#include <exception>
#include <condition_variable>

#include "gloo/common/string.h"

Expand All @@ -20,6 +21,11 @@ namespace gloo {

const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero();

bool _is_aborted();
void abort();
void _register_cv(std::condition_variable *cv);
void _deregister_cv(std::condition_variable *cv);

// A base class for all gloo runtime errors
struct Exception : public std::runtime_error {
Exception() = delete;
Expand Down
1 change: 1 addition & 0 deletions gloo/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
find_package(OpenSSL 1.1 REQUIRED EXACT)

set(GLOO_TEST_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/abort_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allgatherv_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc"
Expand Down
145 changes: 145 additions & 0 deletions gloo/test/abort_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <functional>
#include <thread>
#include <vector>

#include "gloo/barrier_all_to_all.h"
#include "gloo/barrier_all_to_one.h"
#include "gloo/broadcast.h"
#include "gloo/test/base_test.h"

namespace gloo {
namespace test {
namespace {

// Signature of a test body: it receives a context and is expected to
// instantiate and run the algorithm under test on it.
using Func = void(std::shared_ptr<::gloo::Context>);

// Test parameterization: (transport, context size, test body).
using Param = std::tuple<Transport, int, std::function<Func>>;

// Test fixture for the class-style barrier algorithms, parameterized by Param.
// NOTE(review): the suite is named BarrierTest although this file is
// abort_test.cc. If another source file linked into the same test binary also
// defines a BarrierTest suite, googletest will report a fixture-class
// mismatch at runtime -- confirm and rename if so.
class BarrierTest : public BaseTest,
public ::testing::WithParamInterface<Param> {};

// Unpacks (transport, context size, test body) from the test parameter and
// hands the test body to spawn() for that transport and context size.
TEST_P(BarrierTest, SinglePointer) {
  const auto& param = GetParam();
  const auto transport = std::get<0>(param);
  const auto contextSize = std::get<1>(param);
  const auto fn = std::get<2>(param);

  spawn(transport, contextSize, [&fn](std::shared_ptr<Context> ctx) {
    fn(ctx);
  });
}

// Test body: constructs a BarrierAllToAll on the given context and runs it
// once.
static std::function<Func> barrierAllToAll =
    [](std::shared_ptr<::gloo::Context> ctx) {
      ::gloo::BarrierAllToAll(ctx).run();
    };

// Instantiates BarrierTest with the barrierAllToAll body for every transport
// in kTransportsForClassAlgorithms and for context sizes 2 through 15
// (::testing::Range excludes its end value).
INSTANTIATE_TEST_CASE_P(
BarrierAllToAll,
BarrierTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForClassAlgorithms),
::testing::Range(2, 16),
::testing::Values(barrierAllToAll)));

// Test body: constructs a BarrierAllToOne on the given context and runs it
// once.
static std::function<Func> barrierAllToOne =
    [](std::shared_ptr<::gloo::Context> ctx) {
      ::gloo::BarrierAllToOne(ctx).run();
    };

// Instantiates BarrierTest with the barrierAllToOne body for every transport
// in kTransportsForClassAlgorithms and for context sizes 2 through 15
// (::testing::Range excludes its end value).
INSTANTIATE_TEST_CASE_P(
BarrierAllToOne,
BarrierTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForClassAlgorithms),
::testing::Range(2, 16),
::testing::Values(barrierAllToOne)));

// Synchronized version of std::chrono::clock::now().
// All processes participating in the specified context will
// see the same value.
template <typename clock>
std::chrono::time_point<clock> syncNow(std::shared_ptr<Context> context) {
const typename clock::time_point now = clock::now();
typename clock::duration::rep count = now.time_since_epoch().count();
BroadcastOptions opts(context);
opts.setRoot(0);
opts.setOutput(&count, 1);
broadcast(opts);
return typename clock::time_point(typename clock::duration(count));
}

// Parameterization for the function-style barrier tests:
// (transport, context size).
using NewParam = std::tuple<Transport, int>;

// Fixture for tests that call barrier(BarrierOptions) directly.
class BarrierNewTest : public BaseTest,
public ::testing::WithParamInterface<NewParam> {};

// Runs a barrier on every process, then calls abort() inside a loop in which
// one process per iteration sleeps before the barrier.
//
// NOTE(review): the unqualified abort() below presumably resolves to
// gloo::abort() from gloo/common/error.h. That function ends with
// GLOO_THROW("GLOO ABORTED"); if the macro throws unconditionally, nothing
// after the abort() call is reached and the exception leaves the spawn
// callback. Confirm the intended expectation (for example, asserting that
// abort() throws) before relying on this test.
TEST_P(BarrierNewTest, Default) {
const auto transport = std::get<0>(GetParam());
const auto contextSize = std::get<1>(GetParam());

spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
BarrierOptions opts(context);

// Run barrier to synchronize processes after starting.
barrier(opts);

// Take turns in sleeping for a bit and checking that all processes
// saw that artificial delay through the barrier.
auto singleProcessDelay = std::chrono::milliseconds(1000);
for (size_t i = 0; i < context->size; i++) {
// Common start time for this iteration, shared by all processes.
const auto start = syncNow<std::chrono::high_resolution_clock>(context);
if (i == context->rank) {
/* sleep override */
std::this_thread::sleep_for(singleProcessDelay);
}

// NOTE(review): this barrier runs to completion before abort() is called,
// so it appears to wait for the sleeping process; abort() is not
// interrupting a blocked wait here -- confirm this is the intended scenario.
barrier(opts);
abort();

// Expect all processes to have taken less than the sleep, as abort was called
// NOTE(review): the assertion below is ASSERT_LE, so an elapsed time equal
// to the sleep also passes. If the barrier above waited for the sleeping
// process, the elapsed time would be at least the sleep -- verify.
auto stop = std::chrono::high_resolution_clock::now();
auto delta = std::chrono::duration_cast<decltype(singleProcessDelay)>(
stop - start);
ASSERT_LE(delta.count(), singleProcessDelay.count());
}
});
}

// Instantiates BarrierNewTest for every transport in
// kTransportsForFunctionAlgorithms with context sizes 1, 2, 4 and 7.
INSTANTIATE_TEST_CASE_P(
BarrierNewDefault,
BarrierNewTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForFunctionAlgorithms),
::testing::Values(1, 2, 4, 7)));

// Two-process TCP context with a 10 ms barrier timeout. Only rank 0 enters
// the barrier, and the test requires that call to fail with an IoException
// whose message contains "Timed out".
TEST_F(BarrierNewTest, TestTimeout) {
  spawn(Transport::TCP, 2, [&](std::shared_ptr<Context> context) {
    BarrierOptions opts(context);
    opts.setTimeout(std::chrono::milliseconds(10));

    // Every rank other than 0 stays out of the barrier.
    if (context->rank != 0) {
      return;
    }

    try {
      barrier(opts);
      FAIL() << "Expected exception to be thrown";
    } catch (const ::gloo::IoException& e) {
      const std::string message(e.what());
      ASSERT_NE(message.find("Timed out"), std::string::npos);
    }
  });
}

} // namespace
} // namespace test
} // namespace gloo
16 changes: 14 additions & 2 deletions gloo/transport/tcp/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
// Removes both of this buffer's condition variables from the abort registry,
// so gloo::abort() no longer holds pointers to them once this object is gone.
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(m_);
Expand Down Expand Up @@ -60,6 +66,9 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(lock, timeout, [&] {
throwIfException();
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
Expand Down Expand Up @@ -111,6 +120,9 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(lock, timeout, [&] {
throwIfException();
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
Expand Down
26 changes: 20 additions & 6 deletions gloo/transport/uv/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
// Removes both of this buffer's condition variables from the abort registry,
// so gloo::abort() no longer holds pointers to them once this object is gone.
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -58,8 +64,12 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
}

if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(
lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; });
auto done = recvCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down Expand Up @@ -94,8 +104,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
}

if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(
lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; });
auto done = sendCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down