diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index 7033a8c02..7ed14a044 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -13,6 +13,7 @@ convert_messages_to_tokens_and_masks, get_recent_assistant_user_messages, ) +from rllm.engine.rollout.rollout_engine import ModelOutput from rllm.environments.base.base_env import BaseEnv from rllm.environments.env_utils import ( compute_mc_return, @@ -126,7 +127,7 @@ def __init__( # Create a thread pool executor for environment interactions (i.e. step, reset, close) self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) - async def get_model_response(self, prompt, application_id, **kwargs) -> str: + async def get_model_response(self, prompt, application_id, **kwargs) -> ModelOutput: """ Compute model response asynchronously based on the engine type. @@ -507,12 +508,9 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text", assert all(env is not None and isinstance(env, BaseEnv) for env in self.envs), "All environments must be inheriting from BaseEnv" assert all(env.is_multithread_safe() for env in self.envs), "All environments must be multithread safe for async engine" # type: ignore max_concurrency = self.n_parallel_agents - if self.engine_name == "verl": - free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False - self.executor = ThreadPoolExecutor(max_workers=max_concurrency) - if self.engine_name == "verl" and free_cache_engine: + if self.engine_name == "verl": await self.rollout_engine.wake_up() # type: ignore async def launch_one_trajectory_task(env_idx: int): @@ -546,7 +544,7 @@ async def launch_one_trajectory_task(env_idx: int): except Exception as e: raise e - if self.engine_name == "verl" and free_cache_engine: + if self.engine_name == "verl": await self.rollout_engine.sleep() # type: ignore self.executor.shutdown(wait=False, 
cancel_futures=True) diff --git a/rllm/engine/agent_workflow_engine.py b/rllm/engine/agent_workflow_engine.py index ff2aa1fbc..0099c3aa4 100644 --- a/rllm/engine/agent_workflow_engine.py +++ b/rllm/engine/agent_workflow_engine.py @@ -165,13 +165,11 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": Returns: DataProto: Transformed results compatible with Verl training. """ - free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False - if free_cache_engine: - # TODO: later probably should make the `wake_up` and `sleep` methods in base class to be async - if isinstance(self.rollout_engine, VerlEngine): - await self.rollout_engine.wake_up() - else: - self.rollout_engine.wake_up() + if isinstance(self.rollout_engine, VerlEngine): + await self.rollout_engine.wake_up() + else: + # for most other engines, this simply does nothing + self.rollout_engine.wake_up() if batch.meta_info.get("validate", False): self.rollout_engine.validate = True @@ -180,11 +178,10 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes self.rollout_engine.validate = False - if free_cache_engine: - if isinstance(self.rollout_engine, VerlEngine): - await self.rollout_engine.sleep() - else: - self.rollout_engine.sleep() + if isinstance(self.rollout_engine, VerlEngine): + await self.rollout_engine.sleep() + else: + self.rollout_engine.sleep() return self.transform_results_for_verl(results, task_ids) def transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndarray) -> "DataProto": diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py index ae6650590..6ce2f4df6 100644 --- a/rllm/trainer/verl/agent_ppo_trainer.py +++ b/rllm/trainer/verl/agent_ppo_trainer.py @@ -82,6 +82,10 @@ def init_workers(self): **self.config.rllm.agent.get("engine_args", {}), ) + # If 
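`free_cache_engine` is False, we need to manually `sleep` at the start (NOTE(review): the guard below calls `sleep` when the flag is truthy, which contradicts this sentence; confirm whether a `not` is missing or the comment should say True) + if self.config.actor_rollout_ref.rollout.get("free_cache_engine", False): + self.async_rollout_manager.sleep() + def init_envs_and_agents(self, batch): """ Initialize environment depending on env_class with the necessary extra_info, also set uid of the batch.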