11import asyncio
2- import concurrent .futures
32import logging
43import time
54import traceback
@@ -31,7 +30,7 @@ def __init__(
3130 tokenizer = None ,
3231 rollout_engine = None ,
3332 chat_parser = None ,
34- n_parallel_agents = 1 ,
33+ n_parallel_agents = 1 , # The number of active agents
3534 trajectory_timeout = None ,
3635 gamma = 0.2 ,
3736 api_retries = 3 ,
@@ -45,7 +44,7 @@ def __init__(
4544 agent_args = None ,
4645 rollout_engine_args = None ,
4746 env_args = None ,
48- max_workers = 64 ,
47+ max_workers = 64 , # The number of concurrent env operations
4948 enforce_max_prompt_length = False , # If enabled, applies max_prompt check per step
5049 overlong_filter = False , # Filter for overlong trajectories (i.e. TRUNCATION, MAX_STEPS, TIMEOUT)
5150 ** kwargs ,
@@ -61,6 +60,7 @@ def __init__(
6160 self .tokenizer = tokenizer
6261 self .engine_name = engine_name
6362 self .n_parallel_agents = n_parallel_agents
63+ self .max_env_workers = max_workers
6464 self .overlong_filter = overlong_filter
6565
6666 # For interaction
@@ -117,9 +117,6 @@ def __init__(
117117 disable_thinking = self .disable_thinking ,
118118 )
119119
120- # Create a thread pool executor for environment interactions (i.e. step, reset, close)
121- self .executor = concurrent .futures .ThreadPoolExecutor (max_workers = max_workers )
122-
123120 async def get_model_response (self , prompt , application_id , ** kwargs ) -> str :
124121 """
125122 Compute model response asynchronously based on the engine type.
@@ -167,7 +164,6 @@ def update_envs_and_agents(self, envs, agents):
167164 for idx , env in enumerate (envs ):
168165 env .idx = idx
169166 self .agents = agents
170- self .n_parallel_agents = len (envs )
171167
172168 async def run_agent_trajectory_async (self , idx , application_id , seed = 0 , mode = "Text" , ** kwargs ):
173169 """Run a single agent's trajectory asynchronously"""
@@ -426,28 +422,30 @@ async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text",
426422 timing_raw = {}
427423 assert all (env is not None and isinstance (env , BaseEnv ) for env in self .envs ), "All environments must be inheriting from BaseEnv"
428424 assert all (env .is_multithread_safe () for env in self .envs ), "All environments must be multithread safe for async engine" # type: ignore
429- max_concurrency = self .n_parallel_agents
430- self .executor = ThreadPoolExecutor (max_workers = max_concurrency )
425+ if not hasattr (self , "executor" ) or self .executor ._shutdown :
426+ self .executor = ThreadPoolExecutor (max_workers = self .max_env_workers )
427+ semaphore = asyncio .Semaphore (self .n_parallel_agents )
431428
432429 if self .engine_name == "verl" :
433430 self .rollout_engine .wake_up ()
434431
435432 async def launch_one_trajectory_task (env_idx : int ):
436- try :
437- application_id = str (uuid .uuid4 ())
438- result = await self .run_agent_trajectory_with_retry (
439- idx = env_idx ,
440- application_id = application_id ,
441- seed = reset_seed ,
442- mode = mode ,
443- ** kwargs ,
444- )
445- except Exception as e :
446- import traceback
447-
448- traceback .print_exc ()
449- raise e
450- return result
433+ async with semaphore :
434+ try :
435+ application_id = str (uuid .uuid4 ())
436+ result = await self .run_agent_trajectory_with_retry (
437+ idx = env_idx ,
438+ application_id = application_id ,
439+ seed = reset_seed ,
440+ mode = mode ,
441+ ** kwargs ,
442+ )
443+ except Exception as e :
444+ import traceback
445+
446+ traceback .print_exc ()
447+ raise e
448+ return result
451449
452450 # Create all N conceptual tasks. Their execution will be throttled by the semaphore
453451 # and the availability of agent/env indices.
@@ -480,6 +478,8 @@ async def execute_tasks(self, tasks: list[dict]):
480478 Returns:
481479 A list of trajectories, one for each task.
482480 """
481+ if not hasattr (self , "executor" ) or self .executor ._shutdown :
482+ self .executor = ThreadPoolExecutor (max_workers = self .max_env_workers )
483483
484484 max_concurrent = self .n_parallel_agents
485485
@@ -521,6 +521,9 @@ async def sem_wrapper(task_id, task):
521521
522522 all_trajectories = {task_id : trajectory for task_id , trajectory in results }
523523 ordered_trajectories = [all_trajectories [i ] for i in range (len (all_trajectories ))]
524+
525+ self .executor .shutdown (wait = False , cancel_futures = True )
526+
524527 return ordered_trajectories
525528
526529 def shutdown (self ):
0 commit comments