Skip to content

Commit e2b9240

Browse files
authored
fix controlling the n_parallel_agents and the concurrent env operations (#271)
* fix controlling the n_parallel_agents and the concurrent env operations * fix controlling the n_parallel_agents and the concurrent env operations * fix controlling the n_parallel_agents and the concurrent env operations * applied pre-commit, fixed unused-import * Added ThreadPoolExecutor in execute_tasks(); if not hasattr(self, "executor") * renaming max_workers to max_env_workers
1 parent af240fc commit e2b9240

File tree

1 file changed

+27
-24
lines changed

1 file changed

+27
-24
lines changed

rllm/engine/agent_execution_engine.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import concurrent.futures
32
import logging
43
import time
54
import traceback
@@ -31,7 +30,7 @@ def __init__(
3130
tokenizer=None,
3231
rollout_engine=None,
3332
chat_parser=None,
34-
n_parallel_agents=1,
33+
n_parallel_agents=1, # The number of active agents
3534
trajectory_timeout=None,
3635
gamma=0.2,
3736
api_retries=3,
@@ -45,7 +44,7 @@ def __init__(
4544
agent_args=None,
4645
rollout_engine_args=None,
4746
env_args=None,
48-
max_workers=64,
47+
max_workers=64, # The number of concurrent env operations
4948
enforce_max_prompt_length=False, # If enabled, applies max_prompt check per step
5049
overlong_filter=False, # Filter for overlong trajectories (i.e. TRUNCATION, MAX_STEPS, TIMEOUT)
5150
**kwargs,
@@ -61,6 +60,7 @@ def __init__(
6160
self.tokenizer = tokenizer
6261
self.engine_name = engine_name
6362
self.n_parallel_agents = n_parallel_agents
63+
self.max_env_workers = max_workers
6464
self.overlong_filter = overlong_filter
6565

6666
# For interaction
@@ -117,9 +117,6 @@ def __init__(
117117
disable_thinking=self.disable_thinking,
118118
)
119119

120-
# Create a thread pool executor for environment interactions (i.e. step, reset, close)
121-
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
122-
123120
async def get_model_response(self, prompt, application_id, **kwargs) -> str:
124121
"""
125122
Compute model response asynchronously based on the engine type.
@@ -167,7 +164,6 @@ def update_envs_and_agents(self, envs, agents):
167164
for idx, env in enumerate(envs):
168165
env.idx = idx
169166
self.agents = agents
170-
self.n_parallel_agents = len(envs)
171167

172168
async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
173169
"""Run a single agent's trajectory asynchronously"""
@@ -426,28 +422,30 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text",
426422
timing_raw = {}
427423
assert all(env is not None and isinstance(env, BaseEnv) for env in self.envs), "All environments must be inheriting from BaseEnv"
428424
assert all(env.is_multithread_safe() for env in self.envs), "All environments must be multithread safe for async engine" # type: ignore
429-
max_concurrency = self.n_parallel_agents
430-
self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
425+
if not hasattr(self, "executor") or self.executor._shutdown:
426+
self.executor = ThreadPoolExecutor(max_workers=self.max_env_workers)
427+
semaphore = asyncio.Semaphore(self.n_parallel_agents)
431428

432429
if self.engine_name == "verl":
433430
self.rollout_engine.wake_up()
434431

435432
async def launch_one_trajectory_task(env_idx: int):
436-
try:
437-
application_id = str(uuid.uuid4())
438-
result = await self.run_agent_trajectory_with_retry(
439-
idx=env_idx,
440-
application_id=application_id,
441-
seed=reset_seed,
442-
mode=mode,
443-
**kwargs,
444-
)
445-
except Exception as e:
446-
import traceback
447-
448-
traceback.print_exc()
449-
raise e
450-
return result
433+
async with semaphore:
434+
try:
435+
application_id = str(uuid.uuid4())
436+
result = await self.run_agent_trajectory_with_retry(
437+
idx=env_idx,
438+
application_id=application_id,
439+
seed=reset_seed,
440+
mode=mode,
441+
**kwargs,
442+
)
443+
except Exception as e:
444+
import traceback
445+
446+
traceback.print_exc()
447+
raise e
448+
return result
451449

452450
# Create all N conceptual tasks. Their execution will be throttled by the semaphore
453451
# and the availability of agent/env indices.
@@ -480,6 +478,8 @@ async def execute_tasks(self, tasks: list[dict]):
480478
Returns:
481479
A list of trajectories, one for each task.
482480
"""
481+
if not hasattr(self, "executor") or self.executor._shutdown:
482+
self.executor = ThreadPoolExecutor(max_workers=self.max_env_workers)
483483

484484
max_concurrent = self.n_parallel_agents
485485

@@ -521,6 +521,9 @@ async def sem_wrapper(task_id, task):
521521

522522
all_trajectories = {task_id: trajectory for task_id, trajectory in results}
523523
ordered_trajectories = [all_trajectories[i] for i in range(len(all_trajectories))]
524+
525+
self.executor.shutdown(wait=False, cancel_futures=True)
526+
524527
return ordered_trajectories
525528

526529
def shutdown(self):

0 commit comments

Comments
 (0)