|
| 1 | +package threads |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "os" |
| 6 | + "os/signal" |
| 7 | + "sync" |
| 8 | + "sync/atomic" |
| 9 | + "syscall" |
| 10 | + |
| 11 | + "github.com/openmcp-project/controller-utils/pkg/logging" |
| 12 | +) |
| 13 | + |
// sigs is the package-wide channel on which SIGINT and SIGTERM are delivered.
// It is shared by every ThreadManager created in this process.
var sigs = make(chan os.Signal, 1)

// init subscribes the sigs channel to SIGINT and SIGTERM as soon as the package is loaded.
func init() {
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
}
| 20 | + |
// WorkFunc is the signature of the workload executed by a thread.
// The context passed to it is cancelled when the ThreadManager is stopped,
// so long-running workloads should watch ctx.Done() and return once it is closed.
type WorkFunc func(ctx context.Context) error
| 24 | + |
| 25 | +// OnFinishFunc can be used to react to a thread finishing. |
| 26 | +// Note that its context might already be cancelled (if the ThreadManager is being stopped). |
| 27 | +type OnFinishFunc func(context.Context, ThreadReturn) |
| 28 | + |
| 29 | +// NewThreadManager creates a new ThreadManager. |
| 30 | +// The mgrCtx is used for two purposes: |
| 31 | +// 1. If the context is cancelled, the ThreadManager is stopped. Alternatively, its Stop() method can be called. |
| 32 | +// 2. If the context contains a logger, it is used for logging. |
| 33 | +// |
| 34 | +// If onFinish is not nil, it will be called whenever a thread finishes. It is called after the thread's own onFinish function, if any. |
| 35 | +func NewThreadManager(mgrCtx context.Context, onFinish OnFinishFunc) *ThreadManager { |
| 36 | + return &ThreadManager{ |
| 37 | + returns: make(chan ThreadReturn, 100), |
| 38 | + onFinish: onFinish, |
| 39 | + log: logging.FromContextOrDiscard(mgrCtx), |
| 40 | + runOnStart: map[string]*Thread{}, |
| 41 | + mgrStop: mgrCtx.Done(), |
| 42 | + threadCancelFuncs: map[string]context.CancelFunc{}, |
| 43 | + } |
| 44 | +} |
| 45 | + |
| 46 | +type ThreadManager struct { |
| 47 | + lock sync.Mutex // generic lock for the ThreadManager |
| 48 | + lockThreadMap sync.Mutex // lock specifically for the threadCancelFuncs map |
| 49 | + returns chan ThreadReturn // channel to receive thread returns |
| 50 | + onFinish OnFinishFunc // function to call when a thread finishes |
| 51 | + log logging.Logger // logger for the ThreadManager |
| 52 | + runOnStart map[string]*Thread // is filled if threads are added before the ThreadManager is started |
| 53 | + mgrStop <-chan struct{} // channel to stop the ThreadManager |
| 54 | + stopped atomic.Bool // indicates if the ThreadManager is stopped |
| 55 | + waitForThreads sync.WaitGroup // used to wait for threads to finish when stopping the ThreadManager |
| 56 | + threadCancelFuncs map[string]context.CancelFunc // map of thread ids to cancel functions |
| 57 | +} |
| 58 | + |
| 59 | +// Start starts the ThreadManager. |
| 60 | +// This starts a goroutine that listens for thread returns and os signals. |
| 61 | +// Calling Start() multiple times is a no-op, unless the ThreadManager has already been stopped, then it panics. |
| 62 | +// It is possible to add threads before the ThreadManager is started, but they will only be run after Start() is called. |
| 63 | +// Threads added after Start() will be run immediately. |
| 64 | +// There are three ways to stop the ThreadManager again: |
| 65 | +// 1. Cancel the context passed to the ThreadManager during creation. |
| 66 | +// 2. Call the ThreadManager's Stop() method. |
| 67 | +// 3. Send a SIGINT or SIGTERM signal to the process. |
| 68 | +func (tm *ThreadManager) Start() { |
| 69 | + tm.lock.Lock() |
| 70 | + defer tm.lock.Unlock() |
| 71 | + if tm.stopped.Load() { |
| 72 | + panic("Start called on a stopped ThreadManager") |
| 73 | + } |
| 74 | + if tm.isStarted() { |
| 75 | + tm.log.Debug("Start called, but ThreadManager is already started, nothing to do") |
| 76 | + return |
| 77 | + } |
| 78 | + tm.log.Info("Starting ThreadManager") |
| 79 | + go func() { |
| 80 | + for { |
| 81 | + select { |
| 82 | + case tr, ok := <-tm.returns: |
| 83 | + if !ok { |
| 84 | + // channel has been closed, this means the Stop() method has been called |
| 85 | + return |
| 86 | + } |
| 87 | + if tr.Err != nil { |
| 88 | + tm.log.Error(tr.Err, "Error in thread", "thread", tr.Thread.id) |
| 89 | + } |
| 90 | + case sig := <-sigs: |
| 91 | + tm.log.Info("Received os signal, stopping ThreadManager", "signal", sig) |
| 92 | + tm.Stop() |
| 93 | + return |
| 94 | + case <-tm.mgrStop: |
| 95 | + tm.Stop() |
| 96 | + return |
| 97 | + } |
| 98 | + } |
| 99 | + }() |
| 100 | + runOnStart := tm.runOnStart |
| 101 | + tm.runOnStart = nil |
| 102 | + if len(runOnStart) > 0 { |
| 103 | + tm.log.Info("Running threads added before ThreadManager was started", "threadCount", len(runOnStart)) |
| 104 | + for _, t := range runOnStart { |
| 105 | + tm.run(t) |
| 106 | + } |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +// Stop stops the ThreadManager. |
| 111 | +// Panics if the ThreadManager has not been started yet. |
| 112 | +// Calling Stop() multiple times is a no-op. |
| 113 | +// It is not possible to start the ThreadManager again after it has been stopped, a new instance must be created. |
| 114 | +// Adding threads after the ThreadManager has been stopped is a no-op. |
| 115 | +// The ThreadManager is also stopped when the context passed to the ThreadManager during creation is cancelled or when a SIGINT or SIGTERM signal is received. |
| 116 | +func (tm *ThreadManager) Stop() { |
| 117 | + tm.lock.Lock() |
| 118 | + defer tm.lock.Unlock() |
| 119 | + if !tm.isStarted() { |
| 120 | + panic("Stop called on a ThreadManager that has not been started yet") |
| 121 | + } |
| 122 | + tm.stop() |
| 123 | +} |
| 124 | + |
| 125 | +func (tm *ThreadManager) stop() { |
| 126 | + if tm.stopped.Load() { |
| 127 | + tm.log.Debug("Stop called, but ThreadManager is already stopped, nothing to do") |
| 128 | + return |
| 129 | + } |
| 130 | + tm.log.Info("Stopping ThreadManager, waiting for remaining threads to finish") |
| 131 | + tm.stopped.Store(true) |
| 132 | + tm.lockThreadMap.Lock() |
| 133 | + for id, cancel := range tm.threadCancelFuncs { |
| 134 | + tm.log.Debug("Cancelling thread", "thread", id) |
| 135 | + cancel() |
| 136 | + } |
| 137 | + tm.lockThreadMap.Unlock() |
| 138 | + |
| 139 | + tm.waitForThreads.Wait() |
| 140 | + close(tm.returns) |
| 141 | + tm.log.Info("ThreadManager stopped") |
| 142 | +} |
| 143 | + |
| 144 | +// Run gives a new thread to run to the ThreadManager. |
| 145 | +// The context is used to create a new context with a cancel function for the thread. |
| 146 | +// id is used for logging and debugging purposes. |
| 147 | +// Note that when a thread with the same id as an already running thread is added, the running thread will be cancelled. |
| 148 | +// If the ThreadManager has not been started yet, the previously added thread with the conflicting id will be discarded and the newly added one will be run when the ThreadManager is started instead. |
| 149 | +// A thread MUST NOT start another thread with the same id as itself during its work function. If a thread wants to restart itself, this must happen in the onFinish function. |
| 150 | +// work is the actual workload of the thread. |
| 151 | +// onFinish can be used to react to the thread having finished. |
| 152 | +// There are some pre-defined functions that can be used as onFinish functions, e.g. the ThreadManager's Restart method. |
| 153 | +func (tm *ThreadManager) Run(ctx context.Context, id string, work func(context.Context) error, onFinish OnFinishFunc) { |
| 154 | + tm.RunThread(NewThread(ctx, id, work, onFinish)) |
| 155 | +} |
| 156 | + |
| 157 | +// RunThread is the same as Run, but takes a Thread struct instead of the individual parameters. |
| 158 | +func (tm *ThreadManager) RunThread(t Thread) { |
| 159 | + tm.lock.Lock() |
| 160 | + defer tm.lock.Unlock() |
| 161 | + tm.run(&t) |
| 162 | +} |
| 163 | + |
| 164 | +func (tm *ThreadManager) run(t *Thread) { |
| 165 | + if t == nil { |
| 166 | + tm.log.Error(nil, "run(t *Thread) called with nil Thread, this should never happen") |
| 167 | + return |
| 168 | + } |
| 169 | + if tm.stopped.Load() { |
| 170 | + tm.log.Info("Skipping thread run because ThreadManager is already stopped", "thread", t.id) |
| 171 | + return |
| 172 | + } |
| 173 | + if !tm.isStarted() { |
| 174 | + tm.log.Debug("ThreadManager has not been started yet, enqueuing thread to run on start", "thread", t.ID()) |
| 175 | + _, ok := tm.runOnStart[t.id] |
| 176 | + if ok { |
| 177 | + tm.log.Debug("Discarding thread with the same id that was already enqueued", "thread", t.id) |
| 178 | + } |
| 179 | + tm.runOnStart[t.id] = t |
| 180 | + return |
| 181 | + } |
| 182 | + tm.log.Debug("Running thread", "thread", t.id) |
| 183 | + tm.lockThreadMap.Lock() |
| 184 | + if cancel := tm.threadCancelFuncs[t.id]; cancel != nil { |
| 185 | + tm.log.Debug("A thread with the same id is already running, cancelling it", "thread", t.id) |
| 186 | + cancel() |
| 187 | + } |
| 188 | + tm.threadCancelFuncs[t.id] = t.cancel |
| 189 | + tm.lockThreadMap.Unlock() |
| 190 | + tm.waitForThreads.Add(1) |
| 191 | + go func() { |
| 192 | + defer tm.waitForThreads.Done() |
| 193 | + var err error |
| 194 | + if t.work != nil { |
| 195 | + err = t.work(t.ctx) |
| 196 | + } else { |
| 197 | + tm.log.Debug("Thread has no work function", "thread", t.id) |
| 198 | + } |
| 199 | + tm.lockThreadMap.Lock() |
| 200 | + // thread must be removed from the internal map here, because otherwise the thread might be restarted before the cancel function is removed |
| 201 | + // which would then wrongfully remove the cancel function of the new thread |
| 202 | + if cancelOld := tm.threadCancelFuncs[t.id]; cancelOld != nil { // this should always be true |
| 203 | + // cancel the thread's context, just to be sure that no running thread can 'leak' by losing its cancel function |
| 204 | + cancelOld() |
| 205 | + delete(tm.threadCancelFuncs, t.id) |
| 206 | + } |
| 207 | + tm.lockThreadMap.Unlock() |
| 208 | + tr := NewThreadReturn(t, err) |
| 209 | + if t.onFinish != nil { |
| 210 | + tm.log.Debug("Calling the thread's onFinish function", "thread", t.id) |
| 211 | + t.onFinish(t.ctx, tr) |
| 212 | + } |
| 213 | + if tm.onFinish != nil { |
| 214 | + tm.log.Debug("Calling the thread manager's onFinish function", "thread", tr.Thread.id) |
| 215 | + tm.onFinish(t.ctx, tr) |
| 216 | + } |
| 217 | + tm.returns <- tr |
| 218 | + tm.log.Debug("Thread finished", "thread", t.id) |
| 219 | + }() |
| 220 | +} |
| 221 | + |
| 222 | +func (tm *ThreadManager) isStarted() bool { |
| 223 | + return tm.runOnStart == nil |
| 224 | +} |
| 225 | + |
| 226 | +// IsStarted returns true if the ThreadManager has been started. |
| 227 | +// Note that this will return true if the ThreadManager has been started at some point, even if it has been stopped by now. |
| 228 | +func (tm *ThreadManager) IsStarted() bool { |
| 229 | + tm.lock.Lock() |
| 230 | + defer tm.lock.Unlock() |
| 231 | + return tm.isStarted() |
| 232 | +} |
| 233 | + |
| 234 | +// IsStopped returns true if the ThreadManager has been stopped. |
| 235 | +// Note that this will return false if the ThreadManager has not been started yet. |
| 236 | +func (tm *ThreadManager) IsStopped() bool { |
| 237 | + return tm.stopped.Load() |
| 238 | +} |
| 239 | + |
| 240 | +// IsRunning returns true if the ThreadManager is currently running, |
| 241 | +// meaning it has been started and not yet been stopped. |
| 242 | +// This is a convenience function that is equivalent to calling IsStarted() && !IsStopped(). |
| 243 | +func (tm *ThreadManager) IsRunning() bool { |
| 244 | + return tm.IsStarted() && !tm.IsStopped() |
| 245 | +} |
| 246 | + |
| 247 | +var _ OnFinishFunc = (*ThreadManager)(nil).Restart |
| 248 | + |
| 249 | +// Restart is a pre-defined onFinish function that can be used to restart a thread after it has finished. |
| 250 | +// This method is not meant to be called directly, instead pass it to the ThreadManager's Run method as the onFinish parameter: |
| 251 | +// |
| 252 | +// tm.Run(ctx, "myThread", myWorkFunc, tm.Restart) |
| 253 | +func (tm *ThreadManager) Restart(_ context.Context, tr ThreadReturn) { |
| 254 | + if tm.stopped.Load() { |
| 255 | + return |
| 256 | + } |
| 257 | + tm.RunThread(*tr.Thread) |
| 258 | +} |
| 259 | + |
| 260 | +var _ OnFinishFunc = (*ThreadManager)(nil).RestartOnError |
| 261 | + |
| 262 | +// RestartOnError is a pre-defined onFinish function that can be used to restart a thread after it has finished, if it finished with an error. |
| 263 | +// It is the opposite of RestartOnSuccess. |
| 264 | +// This method is not meant to be called directly, instead pass it to the ThreadManager's Run method as the onFinish parameter: |
| 265 | +// |
| 266 | +// tm.Run(ctx, "myThread", myWorkFunc, tm.RestartOnError) |
| 267 | +func (tm *ThreadManager) RestartOnError(ctx context.Context, tr ThreadReturn) { |
| 268 | + if tr.Err != nil { |
| 269 | + tm.Restart(ctx, tr) |
| 270 | + } |
| 271 | +} |
| 272 | + |
| 273 | +var _ OnFinishFunc = (*ThreadManager)(nil).RestartOnSuccess |
| 274 | + |
| 275 | +// RestartOnSuccess is a pre-defined onFinish function that can be used to restart a thread after it has finished, if it didn't throw an error. |
| 276 | +// It is the opposite of RestartOnError. |
| 277 | +// This method is not meant to be called directly, instead pass it to the ThreadManager's Run method as the onFinish parameter: |
| 278 | +// |
| 279 | +// tm.Run(ctx, "myThread", myWorkFunc, tm.RestartOnSuccess) |
| 280 | +func (tm *ThreadManager) RestartOnSuccess(ctx context.Context, tr ThreadReturn) { |
| 281 | + if tr.Err == nil { |
| 282 | + tm.Restart(ctx, tr) |
| 283 | + } |
| 284 | +} |
| 285 | + |
| 286 | +// NewThread creates a new thread with the given id, work function and onFinish function. |
| 287 | +// It is usually not required to call this function directly, instead use the ThreadManager's Run method. |
| 288 | +// A new context with a cancel function is derived from the context passed to the constructor. |
| 289 | +// The Thread's fields are considered immutable after creation. |
| 290 | +func NewThread(ctx context.Context, id string, work WorkFunc, onFinish OnFinishFunc) Thread { |
| 291 | + ctx, cancel := context.WithCancel(ctx) |
| 292 | + return Thread{ |
| 293 | + ctx: ctx, |
| 294 | + cancel: cancel, |
| 295 | + id: id, |
| 296 | + work: work, |
| 297 | + onFinish: onFinish, |
| 298 | + } |
| 299 | +} |
| 300 | + |
| 301 | +// Thread represents a thread that can be run by the ThreadManager. |
| 302 | +type Thread struct { |
| 303 | + ctx context.Context |
| 304 | + cancel context.CancelFunc |
| 305 | + id string |
| 306 | + work WorkFunc |
| 307 | + onFinish OnFinishFunc |
| 308 | +} |
| 309 | + |
| 310 | +// Context returns the context of the thread. |
| 311 | +func (t *Thread) Context() context.Context { |
| 312 | + return t.ctx |
| 313 | +} |
| 314 | + |
| 315 | +// Cancel cancels the thread's context. |
| 316 | +// The thread manager cancels all threads' contexts when it is stopped, so calling this manually is usually not necessary. |
| 317 | +func (t *Thread) Cancel() { |
| 318 | + t.cancel() |
| 319 | +} |
| 320 | + |
| 321 | +// ID returns the id of the thread. |
| 322 | +func (t *Thread) ID() string { |
| 323 | + return t.id |
| 324 | +} |
| 325 | + |
| 326 | +// WorkFunc returns the workload function of the thread. |
| 327 | +func (t *Thread) WorkFunc() WorkFunc { |
| 328 | + return t.work |
| 329 | +} |
| 330 | + |
| 331 | +// OnFinishFunc returns the onFinish function of the thread. |
| 332 | +func (t *Thread) OnFinishFunc() OnFinishFunc { |
| 333 | + return t.onFinish |
| 334 | +} |
| 335 | + |
| 336 | +// NewThreadReturn constructs a new ThreadReturn object. |
| 337 | +// This is used by the ThreadManager internally and it should rarely be necessary to call this function directly. |
| 338 | +func NewThreadReturn(thread *Thread, err error) ThreadReturn { |
| 339 | + return ThreadReturn{ |
| 340 | + Err: err, |
| 341 | + Thread: thread, |
| 342 | + } |
| 343 | +} |
| 344 | + |
| 345 | +// ThreadReturn represents the result of a thread's execution. |
| 346 | +// It contains a reference to the thread and an error, if any occurred. |
| 347 | +type ThreadReturn struct { |
| 348 | + Err error |
| 349 | + Thread *Thread |
| 350 | +} |
0 commit comments