Commit f55bac8

Add a simple weight sync sandbox (#531)

Authored by Allen Wang (allenwang28)
Co-authored-by: Allen Wang <allencwang@fb.com>
1 parent fa456c7 commit f55bac8

File tree

tests/sandbox/weight_sync/main.py
tests/sandbox/weight_sync/qwen3_1_7b.yaml

2 files changed: +279, -0 lines


tests/sandbox/weight_sync/main.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Weight Sync Sandbox

A minimal test environment focused exclusively on testing the weight synchronization
mechanism between RLTrainer and Generator.

Usage:
    python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml
"""

import asyncio
import time

import torch
import torchstore as ts
from forge.actors._torchstore_utils import rdma_enabled
from forge.actors.generator import Generator
from forge.actors.trainer import RLTrainer
from forge.controller.provisioner import init_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.config import parse
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer


def generate_random_batch(
    local_batch_size: int,
    request_len: int,
    response_len: int,
    vocab_size: int = 32000,
    device: str = "cuda",
    dp_size: int = 1,
):
    """
    Generate random input and target tensors for a single training step.
    Creates one batch per data parallel rank.
    """
    inputs = []
    targets = []

    # Create one batch for each data parallel rank
    for _ in range(dp_size):
        request = torch.randint(
            1,
            vocab_size,
            (local_batch_size, request_len),
            dtype=torch.long,
            device=device,
        )
        response = torch.randint(
            1,
            vocab_size,
            (local_batch_size, response_len),
            dtype=torch.long,
            device=device,
        )

        # Create padding mask
        padding_mask = torch.rand((local_batch_size, response_len), device=device) > 0.1

        ref_logprobs = (
            -torch.abs(torch.randn((local_batch_size, response_len), device=device))
            - 1.0
        )
        advantages = torch.randn((local_batch_size, 1), device=device)
        input_tokens = torch.cat([request, response], dim=1)
        inputs.append({"tokens": input_tokens})
        targets.append(
            {
                "response": response,
                "ref_logprobs": ref_logprobs,
                "advantages": advantages,
                "padding_mask": padding_mask,
            }
        )

    return inputs, targets


async def main(cfg: DictConfig):
    local_batch_size = cfg.get("local_batch_size", None)
    assert local_batch_size is not None, "local_batch_size must be specified"

    request_len = cfg.get("max_req_tokens", 64)
    response_len = cfg.get("max_res_tokens", 64)
    model_name = cfg.get("model")

    print(f"Loading tokenizer for model: {model_name}")
    tokenizer = get_tokenizer(model_name)
    vocab_size = tokenizer.vocab_size
    print(f"Detected vocab size: {vocab_size}")

    trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
    dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1

    # ---- Global setups ---- #
    provisioner = None
    if cfg.get("provisioner", None) is not None:
        provisioner = await init_provisioner(
            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
        )
    else:
        provisioner = await init_provisioner()

    metric_logging_cfg = cfg.get("metric_logging", {})
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)

    # Initialize torchstore
    await ts.initialize(strategy=ts.ControllerStorageVolumes())

    print("=" * 80)
    print(f"Model: {model_name}")
    print(f"Local batch size: {local_batch_size}")
    print(
        f"Sequence length: {request_len + response_len} ({request_len} + {response_len})"
    )
    print(f"Data parallel size: {dp_size}")
    print(f"Is RDMA available? {rdma_enabled()}")
    print("=" * 80 + "\n")

    # Initialize trainer and generator
    print("Initializing trainer and generator...")
    init_start = time.time()

    trainer, policy = await asyncio.gather(
        RLTrainer.options(**cfg.actors.trainer).as_actor(
            **cfg.trainer,
            loss=lambda *args, **kwargs: torch.tensor(
                1.0, requires_grad=True, device="cuda"
            ),
        ),
        Generator.options(**cfg.actors.policy).as_actor(**cfg.policy),
    )

    init_time = time.time() - init_start
    print(f"Finished initialization in ({init_time:.2f}s)")

    # Run one training step to create weight delta
    print("Running single training step...")
    step_start = time.time()

    inputs, targets = generate_random_batch(
        local_batch_size=local_batch_size,
        request_len=request_len,
        response_len=response_len,
        vocab_size=vocab_size,
        dp_size=dp_size,
    )

    await trainer.train_step.call(inputs, targets)
    step_time = time.time() - step_start
    print(f"Finished train step in ({step_time:.2f}s)\n")

    # Test push_weights
    print("Pushing weights from trainer to store...")
    push_start = time.time()

    await trainer.push_weights.call(policy_version=1)

    push_time = time.time() - push_start
    print(f"Finished weights push in ({push_time:.2f}s)\n")

    # Test update_weights
    print("Updating generator weights from store...")
    update_start = time.time()

    await policy.update_weights.call(version=1)

    update_time = time.time() - update_start
    print(f"Updated generator weights ({update_time:.2f}s)\n")

    # TODO - ideally we have the capability to check forward passes between
    # the trainer/generator to verify correctness. This would require adding
    # forward capabilities to both trainer/generator actors.

    # Summary
    print("=" * 80)
    print("Results")
    print("=" * 80)
    print(f"Push time: {push_time:.2f}s")
    print(f"Update time: {update_time:.2f}s")
    print(f"Total sync time: {push_time + update_time:.2f}s")
    print("=" * 80 + "\n")

    # Cleanup
    print("Shutting down...")
    await shutdown()
    print("Shutdown complete.")


if __name__ == "__main__":

    @parse
    def _main(cfg):
        asyncio.run(main(cfg))

    _main()
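Note on the TODO in main(): forward-pass verification between trainer and generator is not implemented in this commit. A minimal sketch of what such a check might look like, placed right after update_weights inside main(), assuming hypothetical forward endpoints on both actors (the endpoint names below are assumptions, not part of this change):

    # Hypothetical sketch: neither RLTrainer nor Generator exposes a `forward`
    # endpoint in this commit; the calls below are assumed, not real APIs.
    prompt = torch.randint(1, vocab_size, (1, request_len), device="cuda")
    trainer_logits = await trainer.forward.call(prompt)    # assumed endpoint
    generator_logits = await policy.forward.call(prompt)   # assumed endpoint
    # After a successful weight sync the two models should agree within tolerance.
    torch.testing.assert_close(trainer_logits, generator_logits, rtol=1e-2, atol=1e-2)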
tests/sandbox/weight_sync/qwen3_1_7b.yaml

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Weight Sync Sandbox Configuration
# >>> python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml

model: "Qwen/Qwen3-1.7B"
local_batch_size: 4
max_req_tokens: 64
max_res_tokens: 64

metric_logging:
  console:
    logging_mode: global_reduce

policy:
  prefetch_weights_to_shm: false  # Disable to avoid shared memory warnings in test
  engine_args:
    model: ${model}
    tensor_parallel_size: 1
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_params:
    n: 1
    max_tokens: 32  # Just for verification forward pass
    temperature: 1.0
    top_p: 1.0

trainer:
  model:
    name: qwen3
    flavor: 1.7B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 1e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${local_batch_size}
    seq_len: 128  # max_req_tokens + max_res_tokens
    max_norm: 1.0
    steps: 1  # We only run 1 step
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1  # Single GPU, no FSDP
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    folder: ./checkpoint
    initial_load_path: hf://${model}
    initial_load_in_hf: true
    last_save_in_hf: true
    async_mode: "disabled"
  activation_checkpoint:
    mode: selective
    selective_ac_option: op

# Resource allocation - both as actors
actors:
  policy:
    procs: 1  # Single process for generator
    with_gpus: true
    mesh_name: policy
  trainer:
    procs: 1  # Single process for trainer
    with_gpus: true
    mesh_name: trainer
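The config leans on OmegaConf interpolation (${model}, ${local_batch_size}) to keep the policy and trainer sections consistent with the top-level values. A quick sanity check that the interpolations resolve as expected, assuming the file sits at the path shown in the usage line above:

    from omegaconf import OmegaConf

    # Load the sandbox config and resolve ${model} / ${local_batch_size} in place.
    cfg = OmegaConf.load("tests/sandbox/weight_sync/qwen3_1_7b.yaml")
    OmegaConf.resolve(cfg)
    print(cfg.policy.engine_args.model)           # Qwen/Qwen3-1.7B
    print(cfg.trainer.training.local_batch_size)  # 4
    print(cfg.trainer.model.hf_assets_path)       # hf://Qwen/Qwen3-1.7B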
