From a4a93995a1b3c11119e7543e8cd0c1bd70593921 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 21 Oct 2025 13:31:14 -0700 Subject: [PATCH 1/4] create task to ensure parallel execution --- src/forge/actors/generator.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 6c2efd5e6..dbdd5abf2 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -270,11 +270,16 @@ def split_keys(keys): return [keys[i::n_fetchers] for i in range(n_fetchers)] futures = [] - for i, names in enumerate(split_keys(hf_param_names)): - fut = self.weight_fetchers.slice(procs=i).fetch.call_one( - version=version, param_names=names - ) - futures.append(fut) + async with asyncio.TaskGroup() as tg: + for i, names in enumerate(split_keys(hf_param_names)): + + async def fetch_coro(): + return self.weight_fetchers.slice(procs=i).fetch.call_one( + version=version, param_names=names + ) + + fut = tg.create_task(fetch_coro()) + futures.append(fut) sub_state_dicts = [await fut for fut in futures] From 838a4a9c138fe328671054d282b5134ddc336640 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 21 Oct 2025 13:37:37 -0700 Subject: [PATCH 2/4] remove task group since it doesn't work with python 3.10 --- src/forge/actors/generator.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index dbdd5abf2..369380fba 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -270,16 +270,15 @@ def split_keys(keys): return [keys[i::n_fetchers] for i in range(n_fetchers)] futures = [] - async with asyncio.TaskGroup() as tg: - for i, names in enumerate(split_keys(hf_param_names)): + for i, names in enumerate(split_keys(hf_param_names)): - async def fetch_coro(): - return self.weight_fetchers.slice(procs=i).fetch.call_one( - version=version, param_names=names - ) + async def fetch_coro(): + return self.weight_fetchers.slice(procs=i).fetch.call_one( + version=version, param_names=names + ) - fut = tg.create_task(fetch_coro()) - futures.append(fut) + fut = asyncio.create_task(fetch_coro()) + futures.append(fut) sub_state_dicts = [await fut for fut in futures] From 6d5df6a26cc2968be307dbc02f908fa8b1854441 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 21 Oct 2025 15:07:05 -0700 Subject: [PATCH 3/4] set flag instead --- src/forge/actors/generator.py | 8 +++----- src/forge/env.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 369380fba..9262c9aac 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -272,12 +272,10 @@ def split_keys(keys): futures = [] for i, names in enumerate(split_keys(hf_param_names)): - async def fetch_coro(): - return self.weight_fetchers.slice(procs=i).fetch.call_one( - version=version, param_names=names - ) + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( + version=version, param_names=names + ) - fut = asyncio.create_task(fetch_coro()) futures.append(fut) sub_state_dicts = [await fut for fut in futures] diff --git a/src/forge/env.py b/src/forge/env.py index b698b8013..989cacba5 100644 --- a/src/forge/env.py +++ b/src/forge/env.py @@ -105,6 +105,17 @@ def get_value(self) -> Any: description="Whether or not to use RDMA in TorchStore.", ) +MONARCH_OLD_ASYNC_WORKAROUND = EnvVar( + name="MONARCH_OLD_ASYNC_WORKAROUND", + default=1, + description=( + "If enabled, monarch messages will be sent immediately even it's not" + " awaited. This is needed for parallel fetching of weights, as using" + " create_task creates race condition. This is a temporary workaround" + " and will be removed once we have a better solution." + ), +) + def all_env_vars() -> list[EnvVar]: """Retrieves all registered environment variable names.""" From 9c0068243e9574f96a6203aacbe22976b40892d3 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 21 Oct 2025 15:08:05 -0700 Subject: [PATCH 4/4] remove empty lines --- src/forge/actors/generator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 9262c9aac..6c2efd5e6 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -271,11 +271,9 @@ def split_keys(keys): futures = [] for i, names in enumerate(split_keys(hf_param_names)): - fut = self.weight_fetchers.slice(procs=i).fetch.call_one( version=version, param_names=names ) - futures.append(fut) sub_state_dicts = [await fut for fut in futures]