Add Multi-Node Distributed Training Support for SLURM Clusters #528
Conversation
- Add multi-node training support in main.py with proper LOCAL_RANK calculation
- Add qwen3_32b.yaml config optimized for 32-node, 128 GPU training
- Add qwen3_32b.yaml config for GRPO training
- Update launcher.py with SLURM resource auto-detection from environment
- Update types.py with necessary type definitions

Key features:
- Proper multi-node LOCAL_RANK: rank % gpus_per_node (fixes cross-node issues)
- Provisioner support for SLURM multi-node orchestration
- SLURM resource inference from environment variables and scontrol
- Configurable data loading: num_shards_per_rank, num_dataloader_workers
- Optimized training config: TP=4, FSDP=32, selective AC every 2 layers
- Async checkpointing enabled for non-blocking saves
- Backward compatibility with legacy 'processes' config

Optimizations applied:
- Activation checkpointing: selective with layer frequency 2 (2-3x faster)
- Async checkpointing: non-blocking background saves
- Batch size 8 with gradient accumulation 2 for convergence
- 64 shards per rank for optimal I/O parallelism
- SLURM_SWITCHES=2 for network locality (18 nodes/block topology)

Tested on:
- 32 nodes × 4 GPUs = 128 total GPUs
- Ethernet network with SLURM block topology
- Qwen3-32B model (32B parameters)
- 5,000 training steps with WandB logging
allenwang28
left a comment
This mostly looks fine to me, but I would like @daniellepintz to take a look at the SFT pieces as well!
```python
super().__init__(job_config)
...
def _init_dist(self):
```
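For context in the discussion below, here is a minimal sketch of what a bootstrap like `_init_dist` typically sets before `torch.distributed` is initialized on a SLURM cluster; the specific NCCL variables and values are illustrative assumptions, not taken from this PR.

```python
import os


def _init_dist() -> None:
    """Sketch: export per-process env vars before torch.distributed initialization."""
    rank = int(os.environ.get("RANK", "0"))
    # SLURM may report "4" or "gpu:4"; keep only the trailing number.
    gpus_per_node = int(os.environ.get("SLURM_GPUS_PER_NODE", "4").split(":")[-1])

    # LOCAL_RANK must be the rank within the node, not the global rank.
    os.environ["LOCAL_RANK"] = str(rank % gpus_per_node)

    # Illustrative NCCL settings for an Ethernet fabric (assumptions, not from the PR).
    os.environ.setdefault("NCCL_SOCKET_IFNAME", "eth0")
    os.environ.setdefault("NCCL_IB_DISABLE", "1")
```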
If we remove `_init_dist` altogether, would this still work? I added this line in `get_proc_mesh` later, so it should not be needed anymore. Could you please try it out?
I need to test this. Right now, I pass the local rank and NCCL variables within the env there. Will keep you posted.
Maybe we should split out the SLURM-specific PR from the SFT PR?
Hmm, that's a reasonable approach. I'll think about how to separate them and raise a new one.
Add Multi-Node Distributed Training Support for SLURM Clusters
Summary
This PR adds comprehensive multi-node distributed training support to Forge, enabling scalable training on SLURM-managed GPU clusters. Successfully tested on a 32-node cluster with 128 GPUs (4 GPUs per node) training Qwen3-32B.
Motivation
The existing Forge implementation supported only single-node, multi-GPU training. This PR extends that support to multi-node environments, which is essential for training models at the scale of Qwen3-32B across large SLURM-managed clusters.
Key Changes
1. Multi-Node LOCAL_RANK Fix (`apps/sft/main.py`)
- Before: `LOCAL_RANK = RANK`, which breaks on multi-node setups (node 2 would have LOCAL_RANK=4 instead of 0)
- After: `LOCAL_RANK = RANK % gpus_per_node`, ensuring proper GPU assignment on each node
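A minimal sketch of the rank arithmetic (illustrative, not the exact code in the PR):

```python
def local_rank(global_rank: int, gpus_per_node: int) -> int:
    """Rank of a process within its own node."""
    return global_rank % gpus_per_node


# With 4 GPUs per node, global rank 4 is the first process on the second node:
assert local_rank(4, gpus_per_node=4) == 0  # the old LOCAL_RANK = RANK gave 4
assert local_rank(5, gpus_per_node=4) == 1
```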
2. SLURM Provisioner Integration (`apps/sft/main.py`)
- Supports both the `actors` and legacy `processes` configs for backward compatibility
3. Smart SLURM Resource Detection (`src/forge/controller/launcher.py`)
- Reads SLURM environment variables (`SLURM_CPUS_ON_NODE`, `SLURM_MEM_PER_NODE`, `SLURM_GPUS_PER_NODE`)
- Falls back to `scontrol show node` when the environment variables are unavailable
- Handles both `"4"` and `"gpu:4"` GPU count formats
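A sketch of that detection order for the GPU count; the helper name, the default, and the exact `scontrol` field parsed are assumptions for illustration, not the PR's implementation:

```python
import os
import re
import subprocess


def detect_gpus_per_node() -> int:
    """Read the per-node GPU count from SLURM env vars, else query scontrol."""
    raw = os.environ.get("SLURM_GPUS_PER_NODE")
    if raw:
        # SLURM may report "4" or "gpu:4"; keep only the trailing number.
        return int(raw.split(":")[-1])

    # Fallback: parse the node's Gres line, e.g. "Gres=gpu:4" or "Gres=gpu:tesla:4".
    node = os.environ.get("SLURMD_NODENAME", "")
    out = subprocess.run(
        ["scontrol", "show", "node", node],
        capture_output=True, text=True, check=True,
    ).stdout
    match = re.search(r"Gres=gpu(?::\w+)?:(\d+)", out)
    return int(match.group(1)) if match else 1
```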
4. Configurable Data Loading (`apps/sft/main.py`)
- `num_shards_per_rank` parameter (default: 64 for large datasets, 8 for small)
- `num_dataloader_workers` parameter (default: 0 to avoid CUDA fork issues)
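A sketch of how these two parameters could feed into per-rank sharding and the dataloader; the function, the HuggingFace-style `dataset.shard` call, and the defaults shown are illustrative assumptions rather than the PR's exact code:

```python
from datasets import concatenate_datasets  # assumes a HuggingFace datasets.Dataset
from torch.utils.data import DataLoader


def build_dataloader(dataset, batch_size: int, rank: int, world_size: int,
                     num_shards_per_rank: int = 64, num_dataloader_workers: int = 0):
    """Give each rank its own block of shards and build its dataloader."""
    # world_size * num_shards_per_rank shards in total; this rank takes a contiguous block.
    total_shards = world_size * num_shards_per_rank
    my_shards = [
        dataset.shard(num_shards=total_shards, index=rank * num_shards_per_rank + i)
        for i in range(num_shards_per_rank)
    ]

    # num_workers=0 keeps loading in the main process, avoiding CUDA-after-fork issues.
    return DataLoader(
        concatenate_datasets(my_shards),
        batch_size=batch_size,
        num_workers=num_dataloader_workers,
    )
```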
Testing
Successfully tested on 32 nodes × 4 GPUs (128 GPUs total), training Qwen3-32B for 5,000 steps over Ethernet with SLURM block topology and WandB logging.
Files Changed
- `apps/sft/main.py` - Multi-node support, provisioner integration, data config
- `apps/sft/qwen3_32b.yaml` - New optimized 128-GPU config (renamed from `qwen3_32b_multinode.yaml`)
- `qwen3_32b.yaml` - GRPO config for Qwen3-32B
- `src/forge/controller/launcher.py` - SLURM resource auto-detection
- `src/forge/types.py` - Type definitions for launcher config
Usage Example
YAML Config:
```yaml
provisioner:
  launcher: slurm
  cpu: 128
  memory_mb: 1655502
  gpus_per_node: 4

actors:
  trainer:
    procs: 4
    hosts: 32
    with_gpus: true
```
Run Training:
```bash
python -m apps.sft.main --config apps/sft/qwen3_32b.yaml
```