diff --git a/Rx/v2/src/rxcpp/subjects/rx-subject.hpp b/Rx/v2/src/rxcpp/subjects/rx-subject.hpp index 4ce96b4733..ca611ac104 100644 --- a/Rx/v2/src/rxcpp/subjects/rx-subject.hpp +++ b/Rx/v2/src/rxcpp/subjects/rx-subject.hpp @@ -38,7 +38,7 @@ class multicast_observer , lifetime(cs) { } - std::mutex lock; + std::recursive_mutex lock; typename mode::type current; rxu::error_ptr error; composite_subscription lifetime; @@ -125,13 +125,13 @@ class multicast_observer return make_subscriber(get_id(), get_subscription(), observer>(*this)); } bool has_observers() const { - std::unique_lock guard(b->state->lock); + std::unique_lock guard(b->state->lock); return b->completer && !b->completer->observers.empty(); } template void add(const SubscriberFrom& sf, observer_type o) const { trace_activity().connect(sf, o); - std::unique_lock guard(b->state->lock); + std::unique_lock guard(b->state->lock); switch (b->state->current) { case mode::Casting: { @@ -140,7 +140,7 @@ class multicast_observer o.add([=](){ auto b = binder.lock(); if (b) { - std::unique_lock guard(b->state->lock); + std::unique_lock guard(b->state->lock); b->completer = std::make_shared(b->state, b->completer); } }); @@ -178,7 +178,7 @@ class multicast_observer void on_next(V v) const { auto current_completer = b->current_completer.lock(); if (!current_completer) { - std::unique_lock guard(b->state->lock); + std::unique_lock guard(b->state->lock); b->current_completer = b->completer; current_completer = b->current_completer.lock(); } @@ -192,7 +192,7 @@ class multicast_observer } } void on_error(rxu::error_ptr e) const { - std::unique_lock guard(b->state->lock); + std::unique_lock guard(b->state->lock); if (b->state->current == mode::Casting) { b->state->error = e; b->state->current = mode::Errored; @@ -211,7 +211,7 @@ class multicast_observer } } void on_completed() const { - std::unique_lock guard(b->state->lock); + std::unique_lock guard(b->state->lock); if (b->state->current == mode::Casting) { b->state->current = mode::Completed; auto s = b->state->lifetime; diff --git a/Rx/v2/test/CMakeLists.txt b/Rx/v2/test/CMakeLists.txt index c2d1530692..8d1fd1247b 100644 --- a/Rx/v2/test/CMakeLists.txt +++ b/Rx/v2/test/CMakeLists.txt @@ -21,6 +21,7 @@ set(TEST_DIR ${RXCPP_DIR}/Rx/v2/test) set(TEST_SOURCES ${TEST_DIR}/subscriptions/coroutine.cpp ${TEST_DIR}/subscriptions/observer.cpp + ${TEST_DIR}/subscriptions/race_condition.cpp ${TEST_DIR}/subscriptions/subscription.cpp ${TEST_DIR}/subjects/subject.cpp ${TEST_DIR}/sources/create.cpp diff --git a/Rx/v2/test/subscriptions/race_condition.cpp b/Rx/v2/test/subscriptions/race_condition.cpp new file mode 100644 index 0000000000..272f6b982c --- /dev/null +++ b/Rx/v2/test/subscriptions/race_condition.cpp @@ -0,0 +1,31 @@ +#include "../test.h" +#include "rxcpp/rx.hpp" +#include "rxcpp/operators/rx-observe_on.hpp" +#include "rxcpp/operators/rx-merge.hpp" +#include "rxcpp/rx-scheduler.hpp" + +SCENARIO("multicast_observer race condition") { + + // We loop this test many many times because it is attempting to trigger a + // race condition that is not guaranteed to occur, described in + // https://github.com/ReactiveX/RxCpp/issues/555 + for (std::size_t i=0; i < 5000; ++i) { + auto comp1 = rxcpp::composite_subscription(); + auto mco = rxcpp::subjects::detail::multicast_observer(comp1); + + auto comp2 = rxcpp::composite_subscription(); + auto obs = rxcpp::observer(); + auto sub = rxcpp::subscriber( + rxcpp::trace_id::make_next_id_subscriber(), + comp2, + obs); + + using namespace std::chrono_literals; + auto t = std::thread([&](){ + comp2.unsubscribe(); + }); + + mco.add(mco.get_subscription(), sub); + t.join(); + } +}