Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions rllm/engine/agent_execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
convert_messages_to_tokens_and_masks,
get_recent_assistant_user_messages,
)
from rllm.engine.rollout.rollout_engine import ModelOutput
from rllm.environments.base.base_env import BaseEnv
from rllm.environments.env_utils import (
compute_mc_return,
Expand Down Expand Up @@ -126,7 +127,7 @@ def __init__(
# Create a thread pool executor for environment interactions (i.e. step, reset, close)
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

async def get_model_response(self, prompt, application_id, **kwargs) -> str:
async def get_model_response(self, prompt, application_id, **kwargs) -> ModelOutput:
"""
Compute model response asynchronously based on the engine type.

Expand Down Expand Up @@ -507,12 +508,9 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text",
assert all(env is not None and isinstance(env, BaseEnv) for env in self.envs), "All environments must be inheriting from BaseEnv"
assert all(env.is_multithread_safe() for env in self.envs), "All environments must be multithread safe for async engine" # type: ignore
max_concurrency = self.n_parallel_agents
if self.engine_name == "verl":
free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False

self.executor = ThreadPoolExecutor(max_workers=max_concurrency)

if self.engine_name == "verl" and free_cache_engine:
if self.engine_name == "verl":
await self.rollout_engine.wake_up() # type: ignore

async def launch_one_trajectory_task(env_idx: int):
Expand Down Expand Up @@ -546,7 +544,7 @@ async def launch_one_trajectory_task(env_idx: int):
except Exception as e:
raise e

if self.engine_name == "verl" and free_cache_engine:
if self.engine_name == "verl":
await self.rollout_engine.sleep() # type: ignore

self.executor.shutdown(wait=False, cancel_futures=True)
Expand Down
21 changes: 9 additions & 12 deletions rllm/engine/agent_workflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,11 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
Returns:
DataProto: Transformed results compatible with Verl training.
"""
free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False
if free_cache_engine:
# TODO: later probably should make the `wake_up` and `sleep` methods in base class to be async
if isinstance(self.rollout_engine, VerlEngine):
await self.rollout_engine.wake_up()
else:
self.rollout_engine.wake_up()
if isinstance(self.rollout_engine, VerlEngine):
await self.rollout_engine.wake_up()
else:
# for most other engines, this simply does nothing
self.rollout_engine.wake_up()

if batch.meta_info.get("validate", False):
self.rollout_engine.validate = True
Expand All @@ -180,11 +178,10 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes
self.rollout_engine.validate = False

if free_cache_engine:
if isinstance(self.rollout_engine, VerlEngine):
await self.rollout_engine.sleep()
else:
self.rollout_engine.sleep()
if isinstance(self.rollout_engine, VerlEngine):
await self.rollout_engine.sleep()
else:
self.rollout_engine.sleep()
return self.transform_results_for_verl(results, task_ids)

def transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndarray) -> "DataProto":
Expand Down
4 changes: 4 additions & 0 deletions rllm/trainer/verl/agent_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def init_workers(self):
**self.config.rllm.agent.get("engine_args", {}),
)

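# If `free_cache_engine` is enabled, manually put the rollout manager to `sleep` right after initialization.
# NOTE(review): this comment previously said "is False", which contradicts the truthy check below — confirm which one is intended.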
if self.config.actor_rollout_ref.rollout.get("free_cache_engine", False):
self.async_rollout_manager.sleep()

def init_envs_and_agents(self, batch):
"""
Initialize environment depending on env_class with the necessary extra_info, also set uid of the batch.
Expand Down