33// See the LICENSE file in the project root for more information.
44
55#include " stdafx.h"
6- #include < thread>
76#include < algorithm>
87#include " constants.h"
98#include " connection_impl.h"
1716#include < assert.h>
1817#include " signalrclient/websocket_client.h"
1918#include " default_websocket_client.h"
19+ #include " signalr_default_scheduler.h"
2020
2121namespace signalr
2222{
@@ -26,42 +26,25 @@ namespace signalr
2626 }
2727
2828 std::shared_ptr<connection_impl> connection_impl::create (const std::string& url, trace_level trace_level, const std::shared_ptr<log_writer>& log_writer,
29- std::shared_ptr<http_client> http_client , std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
29+ std::function<std:: shared_ptr<http_client>( const signalr_client_config&)> http_client_factory , std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
3030 {
3131 return std::shared_ptr<connection_impl>(new connection_impl (url, trace_level,
32- log_writer ? log_writer : std::make_shared<trace_log_writer>(), http_client , websocket_factory, skip_negotiation));
32+ log_writer ? log_writer : std::make_shared<trace_log_writer>(), http_client_factory , websocket_factory, skip_negotiation));
3333 }
3434
3535 connection_impl::connection_impl (const std::string& url, trace_level trace_level, const std::shared_ptr<log_writer>& log_writer,
36- std::unique_ptr<http_client> http_client, std::unique_ptr<transport_factory> transport_factory, const bool skip_negotiation)
37- : m_base_url(url), m_connection_state(connection_state::disconnected), m_logger(log_writer, trace_level), m_transport(nullptr ),
38- m_transport_factory (std::move(transport_factory)), m_skip_negotiation(skip_negotiation), m_message_received([](const std::string&) noexcept {}), m_disconnected([]() noexcept {})
39- {
40- if (http_client != nullptr )
41- {
42- m_http_client = std::move (http_client);
43- }
44- else
45- {
46- #ifdef USE_CPPRESTSDK
47- m_http_client = std::unique_ptr<class http_client >(new default_http_client ());
48- #endif
49- }
50- }
51-
52- connection_impl::connection_impl (const std::string& url, trace_level trace_level, const std::shared_ptr<log_writer>& log_writer,
53- std::shared_ptr<http_client> http_client, std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
36+ std::function<std::shared_ptr<http_client>(const signalr_client_config&)> http_client_factory, std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
5437 : m_base_url(url), m_connection_state(connection_state::disconnected), m_logger(log_writer, trace_level), m_transport(nullptr ), m_skip_negotiation(skip_negotiation),
55- m_message_received ([](const std::string&) noexcept {}), m_disconnected([]() noexcept {})
38+ m_message_received ([](const std::string&) noexcept {}), m_disconnected([]() noexcept {}), m_disconnect_cts(std::make_shared<cancellation_token>())
5639 {
57- if (http_client != nullptr )
40+ if (http_client_factory != nullptr )
5841 {
59- m_http_client = std::move (http_client );
42+ m_http_client_factory = std::move (http_client_factory );
6043 }
6144 else
6245 {
6346#ifdef USE_CPPRESTSDK
64- m_http_client = std::unique_ptr<class http_client >(new default_http_client ());
47+ m_http_client_factory = []( const signalr_client_config&) { return std::unique_ptr<class http_client >(new default_http_client ()); } ;
6548#endif
6649 }
6750
@@ -72,7 +55,7 @@ namespace signalr
7255#endif
7356 }
7457
75- m_transport_factory = std::unique_ptr<transport_factory>(new transport_factory (m_http_client , websocket_factory));
58+ m_transport_factory = std::unique_ptr<transport_factory>(new transport_factory (m_http_client_factory , websocket_factory));
7659 }
7760
7861 connection_impl::~connection_impl ()
@@ -138,11 +121,18 @@ namespace signalr
138121 // there should not be any active transport at this point
139122 assert (!m_transport);
140123
141- m_disconnect_cts = std::make_shared<cancellation_token> ();
124+ m_disconnect_cts-> reset ();
142125 m_start_completed_event.reset ();
143126 m_connection_id = " " ;
144127 }
145128
129+ m_scheduler = m_signalr_client_config.get_scheduler ();
130+ if (!m_scheduler)
131+ {
132+ m_scheduler = std::make_shared<signalr_default_scheduler>();
133+ m_signalr_client_config.set_scheduler (m_scheduler);
134+ }
135+
146136 start_negotiate (m_base_url, 0 , callback);
147137 }
148138
@@ -157,7 +147,7 @@ namespace signalr
157147 }
158148
159149 std::weak_ptr<connection_impl> weak_connection = shared_from_this ();
160- const auto & token = m_disconnect_cts;
150+ const auto token = m_disconnect_cts;
161151
162152 const auto transport_started = [weak_connection, callback, token](std::shared_ptr<transport> transport, std::exception_ptr exception)
163153 {
@@ -225,7 +215,8 @@ namespace signalr
225215 return start_transport (url, transport_started);
226216 }
227217
228- negotiate::negotiate (*m_http_client, url, m_signalr_client_config,
218+ auto http_client = m_http_client_factory (m_signalr_client_config);
219+ negotiate::negotiate (http_client, url, m_signalr_client_config,
229220 [callback, weak_connection, redirect_count, token, url, transport_started](negotiation_response&& response, std::exception_ptr exception)
230221 {
231222 auto connection = weak_connection.lock ();
@@ -320,7 +311,7 @@ namespace signalr
320311 std::shared_ptr<std::mutex> connect_request_lock = std::make_shared<std::mutex>();
321312
322313 auto weak_connection = std::weak_ptr<connection_impl>(connection);
323- const auto & disconnect_cts = m_disconnect_cts;
314+ const auto disconnect_cts = m_disconnect_cts;
324315 const auto & logger = m_logger;
325316
326317 auto transport = connection->m_transport_factory ->create_transport (
@@ -406,39 +397,51 @@ namespace signalr
406397 }
407398 });
408399
409- std::thread ([disconnect_cts, connect_request_done, connect_request_lock, callback, weak_connection]()
410- {
411- disconnect_cts->wait (5000 );
400+ disconnect_cts->register_callback ([connect_request_done, connect_request_lock, callback]()
401+ {
402+ bool run_callback = false ;
403+ {
404+ std::lock_guard<std::mutex> lock (*connect_request_lock);
412405
406+ // no op after connection started successfully
407+ if (*connect_request_done == false )
408+ {
409+ *connect_request_done = true ;
410+ run_callback = true ;
411+ }
412+ } // unlock
413+
414+ if (run_callback)
415+ {
416+ // The callback checks the disconnect_cts token and will handle it appropriately
417+ callback ({}, nullptr );
418+ }
419+ });
420+
421+ timer (m_scheduler, [connect_request_done, connect_request_lock, callback](std::chrono::milliseconds duration) {
413422 bool run_callback = false ;
414423 {
415424 std::lock_guard<std::mutex> lock (*connect_request_lock);
425+
416426 // no op after connection started successfully
417427 if (*connect_request_done == false )
418428 {
429+ if (duration < std::chrono::seconds (5 ))
430+ {
431+ return false ;
432+ }
419433 *connect_request_done = true ;
420434 run_callback = true ;
421435 }
422- }
436+ } // unlock
423437
424- // if the disconnect_cts is canceled it means that the connection has been stopped or went out of scope in
425- // which case we should not throw due to timeout.
426- if (disconnect_cts->is_canceled ())
438+ if (run_callback)
427439 {
428- if (run_callback)
429- {
430- // The callback checks the disconnect_cts token and will handle it appropriately
431- callback ({}, nullptr );
432- }
433- }
434- else
435- {
436- if (run_callback)
437- {
438- callback ({}, std::make_exception_ptr (signalr_exception (" transport timed out when trying to connect" )));
439- }
440+ callback ({}, std::make_exception_ptr (signalr_exception (" transport timed out when trying to connect" )));
440441 }
441- }).detach ();
442+
443+ return true ;
444+ });
442445
443446 connection->send_connect_request (transport, url, [callback, connect_request_done, connect_request_lock, transport](std::exception_ptr exception)
444447 {
@@ -597,6 +600,7 @@ namespace signalr
597600 const auto current_state = get_connection_state ();
598601 if (current_state == connection_state::disconnected)
599602 {
603+ m_disconnect_cts->cancel ();
600604 callback (nullptr );
601605 return ;
602606 }
0 commit comments