Commit 8b25d90

Update
[ghstack-poisoned]
1 parent e34f0f9 commit 8b25d90

4 files changed: +336 −297 lines

Lines changed: 11 additions & 13 deletions
@@ -1,13 +1,11 @@
-torch==2.7.0
-transformers==4.52.4
-peft==0.15.2
-bitsandbytes==0.46.0
-datasets==3.6.0
-wandb==0.19.11
-hydra-core==1.3.2
-ray==2.46.0
-tqdm==4.67.1
-tensordict==0.9.0
-vllm==0.9.0.1
-accelerate==1.7.0
-xformers==0.0.30
+vllm==0.11.0
+peft
+bitsandbytes
+datasets
+wandb
+hydra-core
+ray
+tqdm
+tensordict
+accelerate
+xformers
Lines changed: 16 additions & 16 deletions
@@ -1,16 +1,16 @@
-torch==2.7.0
-transformers==4.52.4
-peft==0.15.2
-bitsandbytes==0.46.0
-datasets==3.6.0
-wandb==0.19.11
-hydra-core==1.3.2
-ray==2.46.0
-tqdm==4.67.1
-tensordict==0.9.0
-vllm==0.9.0.1
-accelerate==1.7.0
-xformers==0.0.30
-nltk==3.9.1
-langdetect==1.0.9
-immutabledict==4.2.1
+vllm==0.11.0
+torch
+transformers
+peft
+bitsandbytes
+datasets
+wandb
+hydra-core
+ray
+tqdm
+tensordict
+accelerate
+xformers
+nltk
+langdetect
+immutabledict

test/llm/ray_helpers.py

Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Helper classes for Ray-based weight synchronization tests.

This module contains Ray actor classes that need to be importable by Ray workers.
These classes are used in test_updaters.py but must be defined at module level
so Ray can serialize and import them on remote workers.
"""

import torch
from torchrl._utils import logger


class WorkerVLLMNCCL:
    """Ray actor for vLLM inference worker (receiver) using NCCL collective communication."""

    def __init__(
        self,
        scheme_config: dict,
        model_name: str = "Qwen/Qwen2.5-0.5B",
        trainer_actor_name: str = "Trainer",
    ):
        # Store config for deferred initialization
        self.scheme_config = scheme_config
        self.model_name = model_name
        self.trainer_actor_name = trainer_actor_name
        self.wrapper = None
        self.engine = None
        self.receiver = None
        self.scheme = None
        self.trainer = None
        self.model_metadata = None

    def setup(self):
        """Set up vLLM engine (deferred from __init__ to avoid blocking)."""
        from torchrl.modules.llm.backends import AsyncVLLM
        from torchrl.modules.llm.policies import vLLMWrapper

        # Create vLLM wrapper
        async_engine = AsyncVLLM.from_pretrained(
            self.model_name,
            num_replicas=2,  # Number of engine replicas
        )
        self.wrapper = vLLMWrapper(async_engine, input_mode="history")
        self.engine = self.wrapper.model

        # Create scheme from config
        from torchrl.weight_update.llm.vllm_nccl import VLLMWeightSyncScheme

        self.scheme = VLLMWeightSyncScheme(**self.scheme_config)

        # Create receiver (engine handles rank assignment automatically)
        self.receiver = self.scheme.create_receiver(self.engine)
        return "setup_complete"

    def init_metadata(self):
        """Initialize the receiver by fetching metadata from trainer."""
        import ray

        if self.receiver is None:
            raise RuntimeError("Must call setup() before init_metadata()")

        # Get trainer actor by name
        logger.info(f"Getting trainer actor by name {self.trainer_actor_name}")
        self.trainer = ray.get_actor(self.trainer_actor_name)

        # Fetch model metadata from trainer
        logger.info("Fetching model metadata from trainer (requires max_concurrency>1)")
        self.model_metadata = ray.get(self.trainer.get_model_metadata.remote())

    def init(self):
        if self.model_metadata is None:
            raise RuntimeError("Must call init_metadata() before init()")

        # Initialize receiver with metadata
        logger.info("Initializing receiver...")
        self.receiver.init_all_workers_group(self.model_metadata)
        self.initialized = True
        logger.info("Receiver initialized")
        return "initialized"

    def get_engine(self):
        """Get the vLLM engine reference for RPC coordination."""
        if self.engine is None:
            raise RuntimeError("Must call setup() first")
        return self.engine

    def get_sample_output(self):
        """Get a sample output to verify model works."""
        # Simple inference test
        return "vllm_ready"

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        # No GPUs needed for the actor itself - vLLM workers manage their own placement group (2 GPUs)
        # AsyncVLLM service doesn't act as NCCL rank 0 when used with external trainer
        return ray.remote(num_cpus=4, num_gpus=0, max_concurrency=4)(cls)


class WorkerTransformerNCCL:
    """Ray actor for transformer trainer (sender) using NCCL collective communication."""

    def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
        from torchrl.weight_update.llm.vllm_nccl import (
            get_model_metadata,
            VLLMWeightSyncScheme,
        )
        from transformers import AutoModelForCausalLM

        # Create transformer model
        transformer = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=torch.float16,
        )
        self.transformer = transformer.cuda()

        # Create scheme from config
        self.scheme = VLLMWeightSyncScheme(**scheme_config)

        # Create sender
        self.sender = self.scheme.create_sender()
        self.sender.register_model(self.transformer)

        # Extract and store model metadata
        self.model_metadata = get_model_metadata(self.transformer)

    def init(self, vllm_engine=None):
        """Initialize sender with optional vLLM engine for RPC coordination.

        Args:
            vllm_engine: Optional vLLM engine reference for calling collective_rpc
        """
        if self.model_metadata is None:
            raise RuntimeError("Must call init_metadata() before init()")

        self.sender.init_all_workers_group(self.model_metadata, vllm_engine=vllm_engine)
        self.initialized = True
        logger.info("Trainer initialized")
        return "initialized"

    def get_model_metadata(self):
        """Get model metadata to share with receiver."""
        return self.model_metadata

    def update_weights(self, modify_weights: bool = False):
        """Trigger a weight update broadcast.

        Args:
            modify_weights: If True, modifies weights before broadcasting
                for verification purposes.

        Returns:
            str: "updated" status message
        """
        # Optionally modify weights for testing
        if modify_weights:
            with torch.no_grad():
                first_param = next(self.transformer.parameters())
                first_param.add_(0.01)

        # Broadcast weights to all vLLM workers
        self.sender.update_weights()
        return "updated"

    def get_first_param_sum(self):
        """Get sum of first parameter for verification."""
        return next(self.transformer.parameters()).sum().item()

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        return ray.remote(num_cpus=4, num_gpus=1, max_concurrency=4)(cls)


class WorkerVLLMDoubleBuffer:
    """Ray actor for vLLM inference worker (receiver) using double-buffered storage."""

    def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
        # Store config for deferred initialization
        self.scheme_config = scheme_config
        self.model_name = model_name
        self.wrapper = None
        self.engine = None
        self.receiver = None
        self.scheme = None

    def setup(self):
        """Set up vLLM engine and receiver."""
        from torchrl.modules.llm.backends import AsyncVLLM
        from torchrl.modules.llm.policies import vLLMWrapper

        # Create vLLM wrapper
        async_engine = AsyncVLLM.from_pretrained(
            self.model_name,
            num_replicas=1,  # Single replica for simplicity
        )
        self.wrapper = vLLMWrapper(async_engine, input_mode="history")
        self.engine = self.wrapper.model

        # Create scheme from config
        from torchrl.weight_update.llm.vllm_double_buffer import (
            VLLMDoubleBufferSyncScheme,
        )

        self.scheme = VLLMDoubleBufferSyncScheme(**self.scheme_config)

        # Create receiver
        self.receiver = self.scheme.create_receiver(self.engine)
        logger.info("Receiver setup complete")
        return "setup_complete"

    def poll_and_apply_weights(self):
        """Poll for new weights and apply them to the engine."""
        if self.receiver is None:
            raise RuntimeError("Must call setup() first")

        success = self.receiver.poll_and_apply()
        return success

    def get_sample_output(self):
        """Get a sample output to verify model works."""
        return "vllm_ready"

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        # vLLM worker needs 1 GPU
        return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls)


class WorkerTransformerDoubleBuffer:
    """Ray actor for transformer trainer (sender) using double-buffered storage."""

    def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
        from torchrl.weight_update.llm.vllm_double_buffer import (
            VLLMDoubleBufferSyncScheme,
        )
        from transformers import AutoModelForCausalLM

        # Create transformer model
        transformer = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=torch.float16,
        )
        self.transformer = transformer.cuda()

        # Create scheme from config
        self.scheme = VLLMDoubleBufferSyncScheme(**scheme_config)

        # Create sender
        self.sender = self.scheme.create_sender()
        self.sender.register_model(self.transformer)
        logger.info("Trainer setup complete")

    def update_weights(self, modify_weights: bool = False):
        """Trigger a weight update by writing to shared storage.

        Args:
            modify_weights: If True, modifies weights before writing
                for verification purposes.

        Returns:
            str: "updated" status message
        """
        # Optionally modify weights for testing
        if modify_weights:
            with torch.no_grad():
                first_param = next(self.transformer.parameters())
                first_param.add_(0.01)

        # Write weights to shared storage
        self.sender.update_weights()
        return "updated"

    def get_first_param_sum(self):
        """Get sum of first parameter for verification."""
        return next(self.transformer.parameters()).sum().item()

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls)
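
For context, here is a minimal sketch (not part of this commit) of how a test such as test_updaters.py might wire the NCCL pair together. The import path, the scheme_config keys, and the exact coordination pattern are assumptions based on the class docstrings above, not the actual test code:

import ray

from ray_helpers import WorkerTransformerNCCL, WorkerVLLMNCCL

# Hypothetical scheme configuration; the real keyword arguments accepted by
# VLLMWeightSyncScheme may differ.
scheme_config = {"gpus_per_replica": 1, "num_replicas": 2}

# The trainer actor must be registered under the name the receiver looks up
# via ray.get_actor(trainer_actor_name) ("Trainer" by default).
trainer = WorkerTransformerNCCL.as_remote().options(name="Trainer").remote(scheme_config)
receiver = WorkerVLLMNCCL.as_remote().remote(scheme_config)

ray.get(receiver.setup.remote())          # start the vLLM engine and create the receiver
ray.get(receiver.init_metadata.remote())  # fetch model metadata from the trainer actor

# NCCL group initialization is a collective call: both sides must enter it
# concurrently, which is why the actors are created with max_concurrency > 1.
engine = ray.get(receiver.get_engine.remote())
ray.get([trainer.init.remote(vllm_engine=engine), receiver.init.remote()])

ray.get(trainer.update_weights.remote(modify_weights=True))  # broadcast updated weights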

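The double-buffered pair follows a simpler pull-based flow: the sender writes weights to shared storage and the receiver polls for them. Again a hedged sketch, with the scheme_config contents assumed for illustration:

import ray

from ray_helpers import WorkerTransformerDoubleBuffer, WorkerVLLMDoubleBuffer

# Hypothetical config pointing both sides at the same storage location;
# the real VLLMDoubleBufferSyncScheme arguments may differ.
scheme_config = {"remote_addr": "/tmp/weights"}

trainer = WorkerTransformerDoubleBuffer.as_remote().remote(scheme_config)
receiver = WorkerVLLMDoubleBuffer.as_remote().remote(scheme_config)

ray.get(receiver.setup.remote())                             # start the vLLM engine
ray.get(trainer.update_weights.remote(modify_weights=True))  # write weights to storage
applied = ray.get(receiver.poll_and_apply_weights.remote())  # pick them up on the vLLM side
assert applied
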
0 commit comments
