From fc52f8bdda1cbcb908b677359150b8e82b5eee54 Mon Sep 17 00:00:00 2001 From: listar2000 Date: Fri, 24 Oct 2025 16:05:09 -0500 Subject: [PATCH 1/3] fix agent workflow trainer and engine --- rllm/engine/agent_workflow_engine.py | 14 ++++++++++---- rllm/trainer/verl/agent_workflow_trainer.py | 11 +++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/rllm/engine/agent_workflow_engine.py b/rllm/engine/agent_workflow_engine.py index fa080f5c5..fc106d489 100644 --- a/rllm/engine/agent_workflow_engine.py +++ b/rllm/engine/agent_workflow_engine.py @@ -10,7 +10,8 @@ from tqdm import tqdm from rllm.agents.agent import Episode -from rllm.engine.rollout import ModelOutput, RolloutEngine +from rllm.engine.rollout import ModelOutput +from rllm.engine.rollout.verl_engine import VerlEngine from rllm.misc import colorful_print from rllm.workflows.workflow import TerminationReason, Workflow @@ -22,7 +23,7 @@ class AgentWorkflowEngine: - def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs): + def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: VerlEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs): """Initialize the AgentWorkflowEngine. Args: @@ -164,14 +165,19 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": Returns: DataProto: Transformed results compatible with Verl training. """ - self.rollout_engine.wake_up() + free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False + if free_cache_engine: + await self.rollout_engine.wake_up() + if batch.meta_info.get("validate", False): self.rollout_engine.validate = True tasks = batch.non_tensor_batch["extra_info"].tolist() task_ids = batch.non_tensor_batch["task_ids"].tolist() results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes self.rollout_engine.validate = False - self.rollout_engine.sleep() + + if free_cache_engine: + await self.rollout_engine.sleep() return self.transform_results_for_verl(results, task_ids) def transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndarray) -> "DataProto": diff --git a/rllm/trainer/verl/agent_workflow_trainer.py b/rllm/trainer/verl/agent_workflow_trainer.py index f0ef95f3d..296a2c9d7 100644 --- a/rllm/trainer/verl/agent_workflow_trainer.py +++ b/rllm/trainer/verl/agent_workflow_trainer.py @@ -53,23 +53,24 @@ def __init__( self.workflow_class = workflow_class self.workflow_args = workflow_args or {} + self._validate_config() self._loop = asyncio.new_event_loop() self._thread = threading.Thread(target=self._loop.run_forever, daemon=True) self._thread.start() def _validate_config(self): + assert self.workflow_class is not None, "workflow_class is required for agent workflow trainer" assert self.config.actor_rollout_ref.hybrid_engine is True, "Only hybrid engine is supported" assert self.config.actor_rollout_ref.rollout.mode == "async", "Only async rollout mode is supported" assert self.use_rm is False, "Reward models are not supported. Rewards should be assigned using a reward function in the workflow or environment." if self.config.rllm.rejection_sample.multiplier != 1: assert self.config.rllm.rejection_sample.enable is True, "rejection sampling is disabled, but rejection_sample.multiplier is not 1" + # TODO: revisit whether this is now supported by Verl if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: raise NotImplementedError("REMAX is not supported yet") - super()._validate_config() - def init_workers(self): super().init_workers() @@ -136,7 +137,7 @@ def fit_agent(self): for epoch in range(self.config.trainer.total_epochs): pprint(f"epoch {epoch}, step {self.global_steps} started") for batch_dict in self.train_dataloader: - do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False + do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.get("profile_steps") is not None else False with marked_timer("start_profile", timing_raw): self._start_profiling(do_profile) @@ -642,7 +643,9 @@ def _remove_padding(self, batch): def shutdown(self): """A cleanup method to gracefully stop the background event loop.""" - self.agent_execution_engine.shutdown() + if hasattr(self, "agent_execution_engine") and self.agent_execution_engine is not None: + self.agent_execution_engine.shutdown() + self.agent_execution_engine = None if hasattr(self, "_loop") and self._loop is not None and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) if hasattr(self, "_thread") and self._thread is not None: From 3194183099c0b02e0dc4bc71205dd898aeb581de Mon Sep 17 00:00:00 2001 From: listar2000 Date: Sun, 26 Oct 2025 07:55:04 -0500 Subject: [PATCH 2/3] add backward compatibility --- rllm/engine/agent_workflow_engine.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/rllm/engine/agent_workflow_engine.py b/rllm/engine/agent_workflow_engine.py index fc106d489..ff2aa1fbc 100644 --- a/rllm/engine/agent_workflow_engine.py +++ b/rllm/engine/agent_workflow_engine.py @@ -10,7 +10,7 @@ from tqdm import tqdm from rllm.agents.agent import Episode -from rllm.engine.rollout import ModelOutput +from rllm.engine.rollout import ModelOutput, RolloutEngine from rllm.engine.rollout.verl_engine import VerlEngine from rllm.misc import colorful_print from rllm.workflows.workflow import TerminationReason, Workflow @@ -23,7 +23,7 @@ class AgentWorkflowEngine: - def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: VerlEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs): + def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs): """Initialize the AgentWorkflowEngine. Args: @@ -167,7 +167,11 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": """ free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False if free_cache_engine: - await self.rollout_engine.wake_up() + # TODO: later probably should make the `wake_up` and `sleep` methods in base class to be async + if isinstance(self.rollout_engine, VerlEngine): + await self.rollout_engine.wake_up() + else: + self.rollout_engine.wake_up() if batch.meta_info.get("validate", False): self.rollout_engine.validate = True @@ -177,7 +181,10 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": self.rollout_engine.validate = False if free_cache_engine: - await self.rollout_engine.sleep() + if isinstance(self.rollout_engine, VerlEngine): + await self.rollout_engine.sleep() + else: + self.rollout_engine.sleep() return self.transform_results_for_verl(results, task_ids) def transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndarray) -> "DataProto": From 84507ca5fe87e896ed98147a9a49d1ca1a95a126 Mon Sep 17 00:00:00 2001 From: listar2000 Date: Sat, 8 Nov 2025 18:43:06 -0600 Subject: [PATCH 3/3] propose fix --- rllm/engine/agent_execution_engine.py | 10 ++++------ rllm/engine/agent_workflow_engine.py | 21 +++++++++------------ rllm/trainer/verl/agent_ppo_trainer.py | 4 ++++ 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index 7033a8c02..7ed14a044 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -13,6 +13,7 @@ convert_messages_to_tokens_and_masks, get_recent_assistant_user_messages, ) +from rllm.engine.rollout.rollout_engine import ModelOutput from rllm.environments.base.base_env import BaseEnv from rllm.environments.env_utils import ( compute_mc_return, @@ -126,7 +127,7 @@ def __init__( # Create a thread pool executor for environment interactions (i.e. step, reset, close) self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) - async def get_model_response(self, prompt, application_id, **kwargs) -> str: + async def get_model_response(self, prompt, application_id, **kwargs) -> ModelOutput: """ Compute model response asynchronously based on the engine type. @@ -507,12 +508,9 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text", assert all(env is not None and isinstance(env, BaseEnv) for env in self.envs), "All environments must be inheriting from BaseEnv" assert all(env.is_multithread_safe() for env in self.envs), "All environments must be multithread safe for async engine" # type: ignore max_concurrency = self.n_parallel_agents - if self.engine_name == "verl": - free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False - self.executor = ThreadPoolExecutor(max_workers=max_concurrency) - if self.engine_name == "verl" and free_cache_engine: + if self.engine_name == "verl": await self.rollout_engine.wake_up() # type: ignore async def launch_one_trajectory_task(env_idx: int): @@ -546,7 +544,7 @@ async def launch_one_trajectory_task(env_idx: int): except Exception as e: raise e - if self.engine_name == "verl" and free_cache_engine: + if self.engine_name == "verl": await self.rollout_engine.sleep() # type: ignore self.executor.shutdown(wait=False, cancel_futures=True) diff --git a/rllm/engine/agent_workflow_engine.py b/rllm/engine/agent_workflow_engine.py index ff2aa1fbc..0099c3aa4 100644 --- a/rllm/engine/agent_workflow_engine.py +++ b/rllm/engine/agent_workflow_engine.py @@ -165,13 +165,11 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": Returns: DataProto: Transformed results compatible with Verl training. """ - free_cache_engine = self.config.actor_rollout_ref.rollout.free_cache_engine if self.config else False - if free_cache_engine: - # TODO: later probably should make the `wake_up` and `sleep` methods in base class to be async - if isinstance(self.rollout_engine, VerlEngine): - await self.rollout_engine.wake_up() - else: - self.rollout_engine.wake_up() + if isinstance(self.rollout_engine, VerlEngine): + await self.rollout_engine.wake_up() + else: + # for most other engines, this simply does nothing + self.rollout_engine.wake_up() if batch.meta_info.get("validate", False): self.rollout_engine.validate = True @@ -180,11 +178,10 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes self.rollout_engine.validate = False - if free_cache_engine: - if isinstance(self.rollout_engine, VerlEngine): - await self.rollout_engine.sleep() - else: - self.rollout_engine.sleep() + if isinstance(self.rollout_engine, VerlEngine): + await self.rollout_engine.sleep() + else: + self.rollout_engine.sleep() return self.transform_results_for_verl(results, task_ids) def transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndarray) -> "DataProto": diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py index ae6650590..6ce2f4df6 100644 --- a/rllm/trainer/verl/agent_ppo_trainer.py +++ b/rllm/trainer/verl/agent_ppo_trainer.py @@ -82,6 +82,10 @@ def init_workers(self): **self.config.rllm.agent.get("engine_args", {}), ) + # If `free_cache_engine` is False, we need to manually `sleep` at the start + if self.config.actor_rollout_ref.rollout.get("free_cache_engine", False): + self.async_rollout_manager.sleep() + def init_envs_and_agents(self, batch): """ Initialize environment depending on env_class with the necessary extra_info, also set uid of the batch.