Skip to content

Commit da8059a

Browse files
Close connection properly from transport closing (#17)
1 parent 8b08788 commit da8059a

File tree

11 files changed

+322
-78
lines changed

11 files changed

+322
-78
lines changed

samples/HubConnectionSample/HubConnectionSample.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class logger : public signalr::log_writer
1313
// Inherited via log_writer
1414
virtual void __cdecl write(const std::string & entry) override
1515
{
16-
std::cout << entry << std::endl;
16+
std::cout << entry;
1717
}
1818
};
1919

@@ -77,12 +77,12 @@ void chat()
7777
}
7878

7979
std::cout << "Enter your message:";
80-
for (;;)
80+
while (connection.get_connection_state() == signalr::connection_state::connected)
8181
{
8282
std::string message;
8383
std::getline(std::cin, message);
8484

85-
if (message == ":q")
85+
if (message == ":q" || connection.get_connection_state() != signalr::connection_state::connected)
8686
{
8787
break;
8888
}

src/signalrclient/connection_impl.cpp

Lines changed: 71 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,7 @@ namespace signalr
145145
m_connection_id = "";
146146
}
147147

148-
start_negotiate(m_base_url, 0, [callback](std::exception_ptr exception)
149-
{
150-
callback(exception);
151-
});
148+
start_negotiate(m_base_url, 0, callback);
152149
}
153150

154151
void connection_impl::start_negotiate(const std::string& url, int redirect_count, std::function<void(std::exception_ptr)> callback)
@@ -311,6 +308,21 @@ namespace signalr
311308
auto transport = connection->m_transport_factory->create_transport(
312309
transport_type::websockets, connection->m_logger, connection->m_signalr_client_config);
313310

311+
transport->on_close([weak_connection](std::exception_ptr exception)
312+
{
313+
auto connection = weak_connection.lock();
314+
if (!connection)
315+
{
316+
return;
317+
}
318+
319+
// close callback will only be called if start on the transport has already returned
320+
// wait for the event in order to avoid a race where the state hasn't changed from connecting
321+
// yet and the transport errors out
322+
connection->m_start_completed_event.wait();
323+
connection->stop_connection(exception);
324+
});
325+
314326
transport->on_receive([disconnect_cts, connect_request_done, connect_request_lock, logger, weak_connection, callback](std::string&& message, std::exception_ptr exception)
315327
{
316328
if (exception == nullptr)
@@ -532,50 +544,7 @@ namespace signalr
532544
{
533545
m_logger.log(trace_level::info, "stopping connection");
534546

535-
auto connection = shared_from_this();
536-
shutdown([connection, callback](std::exception_ptr exception)
537-
{
538-
std::thread([connection, callback, exception]()
539-
{
540-
if (exception != nullptr)
541-
{
542-
callback(exception);
543-
return;
544-
}
545-
546-
{
547-
// the lock prevents a race where the user calls `stop` on a disconnected connection and calls `start`
548-
// on a different thread at the same time. In this case we must not null out the transport if we are
549-
// not in the `disconnecting` state to not affect the 'start' invocation.
550-
std::lock_guard<std::mutex> lock(connection->m_stop_lock);
551-
if (connection->change_state(connection_state::disconnecting, connection_state::disconnected))
552-
{
553-
// we do let the exception through (especially the task_canceled exception)
554-
connection->m_transport = nullptr;
555-
}
556-
}
557-
558-
try
559-
{
560-
connection->m_disconnected();
561-
}
562-
catch (const std::exception& e)
563-
{
564-
connection->m_logger.log(
565-
trace_level::errors,
566-
std::string("disconnected callback threw an exception: ")
567-
.append(e.what()));
568-
}
569-
catch (...)
570-
{
571-
connection->m_logger.log(
572-
trace_level::errors,
573-
std::string("disconnected callback threw an unknown exception"));
574-
}
575-
576-
callback(nullptr);
577-
}).detach();
578-
});
547+
shutdown(callback);
579548
}
580549

581550
// This function is called from the dtor so you must not use `shared_from_this` here (it will throw).
@@ -623,10 +592,61 @@ namespace signalr
623592
change_state(connection_state::disconnecting);
624593
}
625594

626-
m_transport->stop([callback](std::exception_ptr exception)
595+
m_transport->stop(callback);
596+
}
597+
598+
// do not use `shared_from_this` as it can be called via the destructor
599+
void connection_impl::stop_connection(std::exception_ptr error)
600+
{
601+
{
602+
// the lock prevents a race where the user calls `stop` on a disconnected connection and calls `start`
603+
// on a different thread at the same time. In this case we must not null out the transport if we are
604+
// not in the `disconnecting` state to not affect the 'start' invocation.
605+
std::lock_guard<std::mutex> lock(m_stop_lock);
606+
607+
if (m_connection_state == connection_state::disconnected)
627608
{
628-
callback(exception);
629-
});
609+
m_logger.log(trace_level::info, "Stopping was ignored because the connection is already in the disconnected state.");
610+
return;
611+
}
612+
613+
change_state(connection_state::disconnected);
614+
m_transport = nullptr;
615+
}
616+
617+
if (error)
618+
{
619+
try
620+
{
621+
std::rethrow_exception(error);
622+
}
623+
catch (const std::exception & ex)
624+
{
625+
m_logger.log(trace_level::errors, std::string("Connection closed with error: ").append(ex.what()));
626+
}
627+
}
628+
else
629+
{
630+
m_logger.log(trace_level::info, "Connection closed.");
631+
}
632+
633+
try
634+
{
635+
m_disconnected();
636+
}
637+
catch (const std::exception & e)
638+
{
639+
m_logger.log(
640+
trace_level::errors,
641+
std::string("disconnected callback threw an exception: ")
642+
.append(e.what()));
643+
}
644+
catch (...)
645+
{
646+
m_logger.log(
647+
trace_level::errors,
648+
std::string("disconnected callback threw an unknown exception"));
649+
}
630650
}
631651

632652
connection_state connection_impl::get_connection_state() const noexcept

src/signalrclient/connection_impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ namespace signalr
8080
void process_response(std::string&& response);
8181

8282
void shutdown(std::function<void(std::exception_ptr)> callback);
83+
void stop_connection(std::exception_ptr);
8384

8485
bool change_state(connection_state old_state, connection_state new_state);
8586
connection_state change_state(connection_state new_state);

src/signalrclient/hub_connection_impl.cpp

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ namespace signalr
6161
auto connection = weak_hub_connection.lock();
6262
if (connection)
6363
{
64+
// start may be waiting on the handshake response so we complete it here, this no-ops if already set
6465
connection->m_handshakeTask->set(std::make_exception_ptr(signalr_exception("connection closed while handshake was in progress.")));
66+
67+
connection->m_callback_manager.clear(signalr::value(std::map<std::string, signalr::value> { { std::string("error"), std::string("connection was stopped before invocation result was received") } }));
68+
6569
connection->m_disconnected();
6670
}
6771
});
@@ -125,7 +129,6 @@ namespace signalr
125129
callback(std::make_exception_ptr(signalr_exception("the hub connection has been deconstructed")));
126130
return;
127131
}
128-
connection->m_handshakeTask->get();
129132
}
130133
catch (...) {}
131134

@@ -171,8 +174,51 @@ namespace signalr
171174

172175
void hub_connection_impl::stop(std::function<void(std::exception_ptr)> callback) noexcept
173176
{
174-
m_callback_manager.clear(signalr::value(std::map<std::string, signalr::value> { { std::string("error"), std::string("connection was stopped before invocation result was received") } }));
175-
m_connection->stop(callback);
177+
if (get_connection_state() == connection_state::disconnected)
178+
{
179+
m_logger.log(trace_level::info, "Stop ignored because the connection is already disconnected.");
180+
callback(nullptr);
181+
return;
182+
}
183+
else
184+
{
185+
{
186+
std::lock_guard<std::mutex> lock(m_stop_callback_lock);
187+
m_stop_callbacks.push_back(callback);
188+
189+
if (m_stop_callbacks.size() > 1)
190+
{
191+
m_logger.log(trace_level::info, "Stop is already in progress, waiting for it to finish.");
192+
// we already registered the callback
193+
// so we can just return now as the in-progress stop will trigger the callback when it completes
194+
return;
195+
}
196+
}
197+
std::weak_ptr<hub_connection_impl> weak_connection = shared_from_this();
198+
m_connection->stop([weak_connection](std::exception_ptr exception)
199+
{
200+
auto connection = weak_connection.lock();
201+
if (!connection)
202+
{
203+
return;
204+
}
205+
206+
std::vector<std::function<void(std::exception_ptr)>> callbacks;
207+
208+
{
209+
std::lock_guard<std::mutex> lock(connection->m_stop_callback_lock);
210+
// copy the callbacks out and clear the list inside the lock
211+
// then run the callbacks outside of the lock
212+
callbacks = connection->m_stop_callbacks;
213+
connection->m_stop_callbacks.clear();
214+
}
215+
216+
for (auto& callback : callbacks)
217+
{
218+
callback(exception);
219+
}
220+
});
221+
}
176222
}
177223

178224
void hub_connection_impl::process_message(std::string&& response)
@@ -184,11 +230,11 @@ namespace signalr
184230
signalr::value handshake;
185231
std::tie(response, handshake) = handshake::parse_handshake(response);
186232

187-
auto obj = handshake.as_map();
233+
auto& obj = handshake.as_map();
188234
auto found = obj.find("error");
189235
if (found != obj.end())
190236
{
191-
auto error = found->second.as_string();
237+
auto& error = found->second.as_string();
192238
m_logger.log(trace_level::errors, std::string("handshake error: ")
193239
.append(error));
194240
m_handshakeTask->set(std::make_exception_ptr(signalr_exception(std::string("Received an error during handshake: ").append(error))));
@@ -291,13 +337,13 @@ namespace signalr
291337
throw signalr_exception("expected object");
292338
}
293339

294-
auto invocationId = message.as_map().at("invocationId");
340+
auto& invocationId = message.as_map().at("invocationId");
295341
if (!invocationId.is_string())
296342
{
297343
throw signalr_exception("invocationId is not a string");
298344
}
299345

300-
auto id = invocationId.as_string();
346+
auto& id = invocationId.as_string();
301347
if (!m_callback_manager.invoke_callback(id, message, true))
302348
{
303349
m_logger.log(trace_level::info, std::string("no callback found for id: ").append(id));

src/signalrclient/hub_connection_impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ namespace signalr
5959
signalr_client_config m_signalr_client_config;
6060
std::shared_ptr<hub_protocol> m_protocol;
6161

62+
std::mutex m_stop_callback_lock;
63+
std::vector<std::function<void(std::exception_ptr)>> m_stop_callbacks;
64+
6265
void initialize();
6366

6467
void process_message(std::string&& message);

src/signalrclient/hub_protocol.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ namespace signalr
1414
virtual std::vector<signalr::value> parse_messages(const std::string&) const = 0;
1515
virtual const std::string& name() const = 0;
1616
virtual int version() const = 0;
17+
virtual ~hub_protocol() {}
1718
};
1819
}

src/signalrclient/json_hub_protocol.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace signalr
2121
{
2222
return 1;
2323
}
24+
25+
~json_hub_protocol() {}
2426
private:
2527
signalr::value parse_message(const std::string&) const;
2628

src/signalrclient/websocket_transport.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ namespace signalr
226226
{
227227
try
228228
{
229+
close_callback(exception);
229230
if (exception != nullptr)
230231
{
231232
std::rethrow_exception(exception);
@@ -241,8 +242,6 @@ namespace signalr
241242

242243
callback(exception);
243244
}
244-
245-
close_callback(exception);
246245
});
247246
}
248247

test/signalrclienttests/connection_impl_tests.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1640,10 +1640,34 @@ TEST(connection_impl_stop, exception_for_disconnected_callback_caught_and_logged
16401640
mre.get();
16411641

16421642
auto log_entries = std::dynamic_pointer_cast<memory_log_writer>(writer)->get_log_entries();
1643-
ASSERT_EQ(1U, log_entries.size());
1643+
ASSERT_EQ(1U, log_entries.size()) << dump_vector(log_entries);
16441644
ASSERT_EQ("[error ] disconnected callback threw an unknown exception\n", remove_date_from_log_entry(log_entries[0]));
16451645
}
16461646

1647+
TEST(connection_impl_stop, transport_error_invokes_disconnected_callback)
1648+
{
1649+
auto websocket_client = create_test_websocket_client();
1650+
auto connection = create_connection(websocket_client);
1651+
1652+
auto disconnect_mre = manual_reset_event<void>();
1653+
connection->set_disconnected([&disconnect_mre]() { disconnect_mre.set(); });
1654+
1655+
auto mre = manual_reset_event<void>();
1656+
connection->start([&mre](std::exception_ptr exception)
1657+
{
1658+
mre.set(exception);
1659+
});
1660+
1661+
mre.get();
1662+
1663+
ASSERT_FALSE(websocket_client->receive_loop_started.wait(1000));
1664+
websocket_client->receive_message(std::make_exception_ptr(std::runtime_error("error")));
1665+
1666+
disconnect_mre.get();
1667+
1668+
ASSERT_EQ(connection_state::disconnected, connection->get_connection_state());
1669+
}
1670+
16471671
TEST(connection_impl_config, custom_headers_set_in_requests)
16481672
{
16491673
auto writer = std::shared_ptr<log_writer>{ std::make_shared<memory_log_writer>() };

0 commit comments

Comments
 (0)