From c9efa008f16172d931fd3b0df3be55c99acc3f60 Mon Sep 17 00:00:00 2001
From: jason
Date: Fri, 7 Nov 2025 23:29:07 +0800
Subject: [PATCH] warmup before global rebalance

---
 src/scheduling/scheduler.py | 76 ++++++++++++++++++++++++++++---------
 1 file changed, 59 insertions(+), 17 deletions(-)

diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py
index 8d9ab23..3a7b0d2 100644
--- a/src/scheduling/scheduler.py
+++ b/src/scheduling/scheduler.py
@@ -152,25 +152,48 @@ def list_node_allocations(self) -> List[Tuple[str, int, int]]:
         return self.layer_allocator.list_node_allocations()
 
     # Warm-up and re-shard
-    def _run_warmup_and_truncate(self) -> None:
+    def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None:
         """Run a brief warm-up to detect truncation points and shrink shards.
 
         Uses layer-level DP turning points (node_id, layer_idx, kind):
         - kind == "tail": drop [layer_idx, end) on that node
        - kind == "head": drop [start, layer_idx) on that node
+
+        Note: Always uses DynamicProgrammingRouting to find turning points,
+        regardless of the current request_router type, since turning-point
+        detection requires layer-level DP analysis.
+
+        Args:
+            override_warmup_count: If > 0, use this value instead of request_warm_up_for_reshard.
+                Default is 0, which means use request_warm_up_for_reshard.
         """
         nodes_list = list(self.nodes)
         if not nodes_list:
             return
         num_layers = self.model_info.num_layers
+
         # The number of warm-up requests can be used to repeat detection, but a
         # single pass is sufficient with our DP model; we repeat to smooth noise.
+        warmup_count = (
+            override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard
+        )
+
+        # Always use the DP router to find turning points, regardless of the current
+        # router type: turning-point detection requires layer-level DP analysis.
+        dp_router = DynamicProgrammingRouting()
+
         agg_turns: Dict[Tuple[str, int, str], int] = {}
-        for _ in range(self.request_warm_up_for_reshard):
-            turns = self.request_router.find_turning_points(nodes_list, num_layers)
+        for _ in range(warmup_count):
+            turns = dp_router.find_turning_points(nodes_list, num_layers)
             for t in turns:
                 agg_turns[t] = agg_turns.get(t, 0) + 1
+
+        if not agg_turns:
+            return
+
         # Apply truncation for consistently observed turning points
+        # Note: must go through layer_allocator.allocate/deallocate so internal
+        # state (the node_allocation dict and layer_to_load) stays consistent
         for node_id, layer_idx, kind in agg_turns:
             node = next((n for n in self.nodes if n.node_id == node_id), None)
             if node is None or node.start_layer is None or node.end_layer is None:
@@ -178,10 +201,12 @@ def _run_warmup_and_truncate(self) -> None:
             start, end = node.start_layer, node.end_layer
             if kind == "tail":
                 if layer_idx < end:
-                    node.set_layer_allocation(start, layer_idx)
+                    self.layer_allocator.deallocate(node)
+                    self.layer_allocator.allocate(node, start, layer_idx)
             elif kind == "head":
                 if layer_idx > start:
-                    node.set_layer_allocation(layer_idx, end)
+                    self.layer_allocator.deallocate(node)
+                    self.layer_allocator.allocate(node, layer_idx, end)
 
     def update_node_info(
         self,
@@ -291,6 +316,22 @@ def join(self, node: Node, bootstrap: bool = False) -> None:
         with self._node_count_cv:
             self._node_count_cv.notify_all()
 
+    def _perform_global_rebalance(self) -> None:
+        """Perform global rebalancing: deallocate all nodes and reallocate."""
+        logger.debug("Performing global rebalance")
+        self._bootstrapped = False
+        self._bootstrapped_event.clear()
+        for n in self.nodes:
+            if n.start_layer is not None and n.end_layer is not None:
+                self.layer_allocator.deallocate(n)
+        success = self.layer_allocator.global_allocation()
+        if not success:
+            logger.warning("Global rebalance failed to produce a full pipeline")
+        else:
+            logger.debug("Global rebalance completed successfully")
+        self._bootstrapped = True
+        self._bootstrapped_event.set()
+
     def leave(self, node_id: str) -> None:
         """Remove a node from allocation and refresh plan and materialized nodes."""
         if node_id not in self.layer_allocator.node_id_to_node:
@@ -316,19 +357,20 @@ def leave(self, node_id: str) -> None:
                     f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance"
                 )
             else:
-                # All nodes are automatic, proceed with rebalance
-                self._bootstrapped = False
-                self._bootstrapped_event.clear()
-                for n in self.nodes:
-                    if n.start_layer is not None and n.end_layer is not None:
-                        self.layer_allocator.deallocate(n)
-                success = self.layer_allocator.global_allocation()
-                if not success:
-                    logger.warning("Global rebalance failed to produce a full pipeline")
+                # All nodes are automatic: try adjustment first, then rebalance if needed
+                if not self.layer_allocator.has_full_pipeline():
+                    logger.debug(
+                        "No full pipeline after node leave, attempting warmup and truncate"
+                    )
+                    self._run_warmup_and_truncate(override_warmup_count=1)
+                    if not self.layer_allocator.has_full_pipeline():
+                        self._perform_global_rebalance()
+                    else:
+                        logger.debug(
+                            "Pipeline recovered through warmup and truncate, skipping global rebalance"
+                        )
                 else:
-                    logger.debug("Global rebalance completed successfully")
-                self._bootstrapped = True
-                self._bootstrapped_event.set()
+                    self._perform_global_rebalance()
         with self._node_count_cv:
             self._node_count_cv.notify_all()
 
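
Notes: the new leave() control flow is easiest to read as the sketch below. This
is a minimal sketch, assuming only the methods visible in this patch
(layer_allocator.has_full_pipeline(), _run_warmup_and_truncate(),
_perform_global_rebalance()); the standalone wrapper name rebalance_after_leave
and its scheduler parameter are hypothetical, not part of the change.

    def rebalance_after_leave(scheduler) -> None:
        alloc = scheduler.layer_allocator
        if alloc.has_full_pipeline():
            # Pipeline still intact after the departure: go straight to the
            # full rebalance so the remaining nodes get a fresh global layout.
            scheduler._perform_global_rebalance()
            return
        # Cheap path first: a single warm-up pass finds DP turning points and
        # truncates shards in place via deallocate/allocate.
        scheduler._run_warmup_and_truncate(override_warmup_count=1)
        if not alloc.has_full_pipeline():
            # Truncation alone could not close the gap; fall back to the
            # expensive deallocate-everything global rebalance.
            scheduler._perform_global_rebalance()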