|
| 1 | +// Copyright 2025 The Cockroach Authors. |
| 2 | +// |
| 3 | +// Use of this software is governed by the CockroachDB Software License |
| 4 | +// included in the /LICENSE file. |
| 5 | + |
| 6 | +// Package taskset provides a generic work distribution mechanism for |
| 7 | +// coordinating parallel workers. TaskSet hands out integer identifiers |
| 8 | +// (TaskIDs) that workers can claim and process. The TaskIDs themselves have no |
| 9 | +// inherent meaning - it's up to the caller to map each TaskID to actual work |
| 10 | +// (e.g., file indices, key ranges, batch numbers, etc.). |
| 11 | +// |
| 12 | +// Example usage: |
| 13 | +// |
| 14 | +// tasks := taskset.MakeTaskSet(100, 4) // 100 work items, 4 workers |
| 15 | +// |
| 16 | +// // Worker goroutine |
| 17 | +// for taskID := tasks.ClaimFirst(); !taskID.IsDone(); taskID = tasks.ClaimNext(taskID) { |
| 18 | +// // Map taskID to actual work |
| 19 | +// processFile(files[taskID]) |
| 20 | +// // or: processKeyRange(splits[taskID], splits[taskID+1]) |
| 21 | +// // or: processBatch(taskID*batchSize, (taskID+1)*batchSize) |
| 22 | +// } |
| 23 | +package taskset |
| 24 | + |
// TaskID is an abstract integer identifier for a unit of work. The TaskID
// itself has no inherent meaning - callers decide what each TaskID represents
// (e.g., which file to process, which key range to handle, etc.).
type TaskID int64

// taskIDDone is an internal sentinel value indicating that no more tasks are
// available. Callers outside this package detect it with TaskID.IsDone.
const taskIDDone = TaskID(-1)

// IsDone reports whether t is the sentinel value that marks the end of the
// available work, i.e. there was no task left to claim.
func (t TaskID) IsDone() bool {
	return t == taskIDDone
}
| 37 | + |
| 38 | +// MakeTaskSet creates a new TaskSet with taskCount work items numbered 0 |
| 39 | +// through taskCount-1, pre-split for the expected number of workers. |
| 40 | +// |
| 41 | +// The TaskIDs are abstract identifiers with no inherent meaning - the caller |
| 42 | +// decides what each TaskID represents. For example: |
| 43 | +// - File processing: MakeTaskSet(100, 4) with TaskID N → files[N] |
| 44 | +// - Key ranges: MakeTaskSet(100, 4) with TaskID N → range [splits[N-1], splits[N]) |
| 45 | +// - Row batches: MakeTaskSet(100, 4) with TaskID N → rows [N*1000, (N+1)*1000) |
| 46 | +// |
| 47 | +// The numWorkers parameter enables better initial load balancing by dividing the |
| 48 | +// task range into numWorkers equal spans upfront. For example, with 100 tasks |
| 49 | +// and 4 workers: |
| 50 | +// - Worker 1: starts with task 0 from range [0, 25) |
| 51 | +// - Worker 2: starts with task 25 from range [25, 50) |
| 52 | +// - Worker 3: starts with task 50 from range [50, 75) |
| 53 | +// - Worker 4: starts with task 75 from range [75, 100) |
| 54 | +// |
| 55 | +// Each worker continues claiming sequential tasks from their region (maintaining |
| 56 | +// locality), and can steal from other regions if they finish early. |
| 57 | +// |
| 58 | +// If the number of workers is unknown, use numWorkers=1 for a single span. |
| 59 | +func MakeTaskSet(taskCount, numWorkers int64) TaskSet { |
| 60 | + if numWorkers <= 0 { |
| 61 | + numWorkers = 1 |
| 62 | + } |
| 63 | + if taskCount <= 0 { |
| 64 | + return TaskSet{unassigned: nil} |
| 65 | + } |
| 66 | + |
| 67 | + // Pre-split the task range into numWorkers equal spans |
| 68 | + spans := make([]taskSpan, 0, numWorkers) |
| 69 | + tasksPerWorker := taskCount / numWorkers |
| 70 | + remainder := taskCount % numWorkers |
| 71 | + |
| 72 | + start := TaskID(0) |
| 73 | + for i := int64(0); i < numWorkers; i++ { |
| 74 | + // Distribute remainder evenly by giving first 'remainder' workers one extra task |
| 75 | + spanSize := tasksPerWorker |
| 76 | + if i < remainder { |
| 77 | + spanSize++ |
| 78 | + } |
| 79 | + if spanSize > 0 { |
| 80 | + end := start + TaskID(spanSize) |
| 81 | + spans = append(spans, taskSpan{start: start, end: end}) |
| 82 | + start = end |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + return TaskSet{unassigned: spans} |
| 87 | +} |
| 88 | + |
| 89 | +// TaskSet is a generic work distribution coordinator that manages a collection |
| 90 | +// of abstract task identifiers (TaskIDs) that can be claimed by workers. |
| 91 | +// |
| 92 | +// TaskSet implements a work-stealing algorithm optimized for task locality: |
| 93 | +// - When a worker completes task N, it tries to claim task N+1 (sequential locality) |
| 94 | +// - If task N+1 is unavailable, it falls back to round-robin claiming from the first span |
| 95 | +// - This balances load across workers while maintaining locality within each worker |
| 96 | +// |
| 97 | +// The TaskIDs themselves are just integers (0 through taskCount-1) with no |
| 98 | +// inherent meaning. Callers map these identifiers to actual work units such as: |
| 99 | +// - File indices (TaskID 5 → process files[5]) |
| 100 | +// - Key ranges (TaskID 5 → process range [splits[4], splits[5])) |
| 101 | +// - Batch numbers (TaskID 5 → process rows [5000, 6000)) |
| 102 | +// |
| 103 | +// TaskSet is NOT safe for concurrent use. Callers must ensure external |
| 104 | +// synchronization if the TaskSet is accessed from multiple goroutines. |
| 105 | +type TaskSet struct { |
| 106 | + unassigned []taskSpan |
| 107 | +} |
| 108 | + |
| 109 | +// ClaimFirst should be called when a worker claims its first task. It returns |
| 110 | +// an abstract TaskID to process. The caller decides what this TaskID represents |
| 111 | +// (e.g., which file to process, which key range to handle). Returns a TaskID |
| 112 | +// where .IsDone() is true if no tasks are available. |
| 113 | +// |
| 114 | +// ClaimFirst is distinct from ClaimNext because ClaimFirst will always take |
| 115 | +// from the first span and rotate it to the end (round-robin), whereas ClaimNext |
| 116 | +// tries to claim the next sequential task for locality. |
| 117 | +func (t *TaskSet) ClaimFirst() TaskID { |
| 118 | + if len(t.unassigned) == 0 { |
| 119 | + return taskIDDone |
| 120 | + } |
| 121 | + |
| 122 | + // Take the first task from the first span, then rotate that span to the end. |
| 123 | + // This provides round-robin distribution, ensuring each worker gets tasks |
| 124 | + // from different regions initially for better load balancing. |
| 125 | + span := t.unassigned[0] |
| 126 | + if span.size() == 0 { |
| 127 | + return taskIDDone |
| 128 | + } |
| 129 | + |
| 130 | + task := span.start |
| 131 | + span.start += 1 |
| 132 | + |
| 133 | + if span.size() == 0 { |
| 134 | + // Span is exhausted, remove it |
| 135 | + t.removeSpan(0) |
| 136 | + } else { |
| 137 | + // Move the span to the end for round-robin distribution |
| 138 | + t.unassigned = append(t.unassigned[1:], span) |
| 139 | + } |
| 140 | + |
| 141 | + return task |
| 142 | +} |
| 143 | + |
| 144 | +// ClaimNext should be called when a worker has completed its current task. It |
| 145 | +// returns the next abstract TaskID to process. The caller decides what this |
| 146 | +// TaskID represents. Returns a TaskID where .IsDone() is true if no tasks are |
| 147 | +// available. |
| 148 | +// |
| 149 | +// ClaimNext optimizes for locality by attempting to claim lastTask+1 first. If |
| 150 | +// that task is unavailable, it falls back to ClaimFirst behavior (round-robin |
| 151 | +// from the first span). |
| 152 | +func (t *TaskSet) ClaimNext(lastTask TaskID) TaskID { |
| 153 | + next := lastTask + 1 |
| 154 | + |
| 155 | + for i, span := range t.unassigned { |
| 156 | + if span.start != next { |
| 157 | + continue |
| 158 | + } |
| 159 | + |
| 160 | + span.start += 1 |
| 161 | + |
| 162 | + if span.size() == 0 { |
| 163 | + t.removeSpan(i) |
| 164 | + return next |
| 165 | + } |
| 166 | + |
| 167 | + t.unassigned[i] = span |
| 168 | + return next |
| 169 | + } |
| 170 | + |
| 171 | + // If we didn't find the next task in the unassigned set, then we've |
| 172 | + // exhausted the span and need to claim from a different span. |
| 173 | + return t.ClaimFirst() |
| 174 | +} |
| 175 | + |
| 176 | +func (t *TaskSet) removeSpan(index int) { |
| 177 | + t.unassigned = append(t.unassigned[:index], t.unassigned[index+1:]...) |
| 178 | +} |
0 commit comments