 #pragma once

 #include <thread>
+#include <variant>

 #include "synchronized_queue.hpp"

@@ -33,7 +34,7 @@ class ThreadedCallbackWrapper {
             return CallbackStatus::STOP;
         }

-        m_squeue.push({step, num_steps, latent});
+        m_squeue.push(std::make_tuple(step, num_steps, latent));

         return CallbackStatus::RUNNING;
     }
@@ -44,7 +45,7 @@ class ThreadedCallbackWrapper {
         }

         m_status = CallbackStatus::STOP;
-        m_squeue.empty();
+        m_squeue.push(std::monostate());

         if (m_worker_thread && m_worker_thread->joinable()) {
             m_worker_thread->join();
@@ -58,18 +59,23 @@ class ThreadedCallbackWrapper {
 private:
     std::function<bool(size_t, size_t, ov::Tensor&)> m_callback = nullptr;
     std::shared_ptr<std::thread> m_worker_thread = nullptr;
-    SynchronizedQueue<std::tuple<size_t, size_t, ov::Tensor>> m_squeue;
+    SynchronizedQueue<std::variant<std::tuple<size_t, size_t, ov::Tensor>, std::monostate>> m_squeue;

     std::atomic<CallbackStatus> m_status = CallbackStatus::RUNNING;

     void _worker() {
         while (m_status == CallbackStatus::RUNNING) {
-            // wait for queue pull
-            auto [step, num_steps, latent] = m_squeue.pull();
-
-            if (m_callback(step, num_steps, latent)) {
-                m_status = CallbackStatus::STOP;
-                m_squeue.empty();
+            auto item = m_squeue.pull();
+
+            if (auto callback_data = std::get_if<std::tuple<size_t, size_t, ov::Tensor>>(&item)) {
+                auto& [step, num_steps, latent] = *callback_data;
+                const auto should_stop = m_callback(step, num_steps, latent);
+
+                if (should_stop) {
+                    m_status = CallbackStatus::STOP;
+                }
+            } else if (std::get_if<std::monostate>(&item)) {
+                break;
             }
         }
     }
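
Context for the change: the worker blocks inside `m_squeue.pull()`, so setting `m_status` alone cannot stop it; pushing a `std::monostate` sentinel through the `std::variant` payload wakes the blocked `pull()` and lets the worker exit its loop. The sketch below illustrates this variant-plus-sentinel pattern in isolation. It uses a simplified, hypothetical stand-in for `SynchronizedQueue` (the real implementation lives in `synchronized_queue.hpp` and is not part of this diff) and an `int` payload in place of the `(step, num_steps, latent)` tuple.

```cpp
// Minimal sketch of the sentinel-through-variant pattern, assuming a blocking
// pull(); this SynchronizedQueue is a simplified stand-in, not the real one.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <variant>

template <typename T>
class SynchronizedQueue {
public:
    void push(T value) {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_queue.push(std::move(value));
        }
        m_cv.notify_one();
    }

    // Blocks until an element is available.
    T pull() {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cv.wait(lock, [this] { return !m_queue.empty(); });
        T value = std::move(m_queue.front());
        m_queue.pop();
        return value;
    }

private:
    std::mutex m_mutex;
    std::condition_variable m_cv;
    std::queue<T> m_queue;
};

int main() {
    // std::monostate serves as the "stop" sentinel, as in the wrapper above.
    SynchronizedQueue<std::variant<int, std::monostate>> queue;

    std::thread worker([&queue] {
        while (true) {
            auto item = queue.pull();
            if (auto payload = std::get_if<int>(&item)) {
                std::cout << "processed " << *payload << '\n';
            } else {
                // monostate: wake-up request with no work attached, exit the loop
                break;
            }
        }
    });

    queue.push(1);
    queue.push(2);
    queue.push(std::monostate());  // unblocks pull() so the worker can finish
    worker.join();
}
```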