76 changes: 59 additions & 17 deletions src/scheduling/scheduler.py
@@ -152,36 +152,61 @@ def list_node_allocations(self) -> List[Tuple[str, int, int]]:
return self.layer_allocator.list_node_allocations()

# Warm-up and re-shard
def _run_warmup_and_truncate(self) -> None:
def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None:
"""Run a brief warm-up to detect truncation points and shrink shards.

Uses layer-level DP turning points (node_id, layer_idx, kind):
- kind == "tail": drop [layer_idx, end) on that node
- kind == "head": drop [start, layer_idx) on that node

        Note: Always uses DynamicProgrammingRouting to find turning points,
        regardless of the current request_router type, since turning-point
        detection requires layer-level DP analysis.

        Args:
            override_warmup_count: If > 0, run this many warm-up passes instead
                of request_warm_up_for_reshard; the default of 0 falls back to
                request_warm_up_for_reshard.
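
        Example (illustrative, hypothetical layer ranges):
            For a node holding layers [4, 12), a turning point
            (node_id, 9, "tail") shrinks it to [4, 9), while
            (node_id, 7, "head") shrinks it to [7, 12).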
"""
nodes_list = list(self.nodes)
if not nodes_list:
return
num_layers = self.model_info.num_layers

        # A single DP pass is sufficient to find turning points; we repeat the
        # detection warmup_count times to smooth out noise.
warmup_count = (
override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard
)
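        # leave() passes override_warmup_count=1 to force a single quick pass
        # during node-leave recovery (see the changes to leave() below).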

        # Always use the DP router to find turning points, regardless of the
        # current router type: turning-point detection requires layer-level DP analysis.
dp_router = DynamicProgrammingRouting()

agg_turns: Dict[Tuple[str, int, str], int] = {}
for _ in range(self.request_warm_up_for_reshard):
turns = self.request_router.find_turning_points(nodes_list, num_layers)
for _ in range(warmup_count):
turns = dp_router.find_turning_points(nodes_list, num_layers)
for t in turns:
agg_turns[t] = agg_turns.get(t, 0) + 1
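        # agg_turns now maps each observed turning point to its occurrence
        # count, e.g. {("node-0", 9, "tail"): 3} (hypothetical node id).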

if not agg_turns:
return

        # Apply truncation for each observed turning point.
        # Note: Must use layer_allocator.allocate/deallocate to keep internal
        # state (node_allocation dict and layer_to_load) in sync.
for node_id, layer_idx, kind in agg_turns:
node = next((n for n in self.nodes if n.node_id == node_id), None)
if node is None or node.start_layer is None or node.end_layer is None:
continue
start, end = node.start_layer, node.end_layer
if kind == "tail":
if layer_idx < end:
node.set_layer_allocation(start, layer_idx)
self.layer_allocator.deallocate(node)
self.layer_allocator.allocate(node, start, layer_idx)
elif kind == "head":
if layer_idx > start:
node.set_layer_allocation(layer_idx, end)
self.layer_allocator.deallocate(node)
self.layer_allocator.allocate(node, layer_idx, end)

def update_node_info(
self,
@@ -291,6 +316,22 @@ def join(self, node: Node, bootstrap: bool = False) -> None:
with self._node_count_cv:
self._node_count_cv.notify_all()

def _perform_global_rebalance(self) -> None:
"""Perform global rebalancing: deallocate all nodes and reallocate."""
logger.debug("Performing global rebalance")
self._bootstrapped = False
self._bootstrapped_event.clear()
for n in self.nodes:
if n.start_layer is not None and n.end_layer is not None:
self.layer_allocator.deallocate(n)
success = self.layer_allocator.global_allocation()
if not success:
logger.warning("Global rebalance failed to produce a full pipeline")
else:
logger.debug("Global rebalance completed successfully")
self._bootstrapped = True
self._bootstrapped_event.set()

def leave(self, node_id: str) -> None:
"""Remove a node from allocation and refresh plan and materialized nodes."""
if node_id not in self.layer_allocator.node_id_to_node:
@@ -316,19 +357,20 @@ def leave(self, node_id: str) -> None:
f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance"
)
else:
# All nodes are automatic, proceed with rebalance
self._bootstrapped = False
self._bootstrapped_event.clear()
for n in self.nodes:
if n.start_layer is not None and n.end_layer is not None:
self.layer_allocator.deallocate(n)
success = self.layer_allocator.global_allocation()
if not success:
logger.warning("Global rebalance failed to produce a full pipeline")
            # All nodes are automatic; try a cheap warm-up/truncate adjustment
            # first, and fall back to a global rebalance if needed
if not self.layer_allocator.has_full_pipeline():
logger.debug(
"No full pipeline after node leave, attempting warmup and truncate"
)
self._run_warmup_and_truncate(override_warmup_count=1)
if not self.layer_allocator.has_full_pipeline():
self._perform_global_rebalance()
else:
logger.debug(
"Pipeline recovered through warmup and truncate, skipping global rebalance"
)
else:
logger.debug("Global rebalance completed successfully")
self._bootstrapped = True
self._bootstrapped_event.set()
self._perform_global_rebalance()

with self._node_count_cv:
self._node_count_cv.notify_all()