
Commit 9604f48

Update
[ghstack-poisoned]
1 parent 8b25d90 commit 9604f48

File tree

8 files changed: +125 -53 lines changed

.gitignore

Lines changed: 8 additions & 0 deletions

@@ -189,3 +189,11 @@ log
 Roms
 
 scratch/*
+
+# Large directories from git history that should not be committed
+dev/
+main/
+*.html
+
+# Additional cache directories
+.ruff_cache/

sota-implementations/expert-iteration/expert-iteration-async.py

Lines changed: 3 additions & 13 deletions

@@ -12,7 +12,7 @@
 
 import hydra
 
-from torchrl import torchrl_logger
+from torchrl import merge_ray_runtime_env, torchrl_logger
 from torchrl.data.llm.history import History
 from torchrl.record.loggers.wandb import WandbLogger
 from torchrl.weight_update.llm import get_model_metadata
@@ -397,19 +397,9 @@ def main(cfg):
         if not k.startswith("_")
     }
 
-    # Add computed GPU configuration
+    # Add computed GPU configuration and merge with default runtime_env
     ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
-    # Ensure runtime_env and env_vars exist
-    if "runtime_env" not in ray_init_config:
-        ray_init_config["runtime_env"] = {}
-    if not isinstance(ray_init_config["runtime_env"], dict):
-        ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
-    if "env_vars" not in ray_init_config["runtime_env"]:
-        ray_init_config["runtime_env"]["env_vars"] = {}
-    if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
-        ray_init_config["runtime_env"]["env_vars"] = dict(
-            ray_init_config["runtime_env"]["env_vars"]
-        )
+    ray_init_config = merge_ray_runtime_env(ray_init_config)
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
     ray.init(**ray_init_config)
 

sota-implementations/expert-iteration/expert-iteration-sync.py

Lines changed: 3 additions & 13 deletions

@@ -12,7 +12,7 @@
 
 import hydra
 
-from torchrl import torchrl_logger
+from torchrl import merge_ray_runtime_env, torchrl_logger
 from torchrl.data.llm.history import History
 from torchrl.record.loggers.wandb import WandbLogger
 from torchrl.weight_update.llm import get_model_metadata
@@ -398,19 +398,9 @@ def main(cfg):
         if not k.startswith("_")
     }
 
-    # Add computed GPU configuration
+    # Add computed GPU configuration and merge with default runtime_env
     ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
-    # Ensure runtime_env and env_vars exist
-    if "runtime_env" not in ray_init_config:
-        ray_init_config["runtime_env"] = {}
-    if not isinstance(ray_init_config["runtime_env"], dict):
-        ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
-    if "env_vars" not in ray_init_config["runtime_env"]:
-        ray_init_config["runtime_env"]["env_vars"] = {}
-    if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
-        ray_init_config["runtime_env"]["env_vars"] = dict(
-            ray_init_config["runtime_env"]["env_vars"]
-        )
+    ray_init_config = merge_ray_runtime_env(ray_init_config)
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
     ray.init(**ray_init_config)
 

sota-implementations/grpo/grpo-async.py

Lines changed: 3 additions & 13 deletions

@@ -13,7 +13,7 @@
 
 import hydra
 
-from torchrl import torchrl_logger
+from torchrl import merge_ray_runtime_env, torchrl_logger
 from torchrl.data.llm.history import History
 from torchrl.record.loggers.wandb import WandbLogger
 from torchrl.weight_update.llm import get_model_metadata
@@ -319,19 +319,9 @@ def main(cfg):
         if not k.startswith("_")
     }
 
-    # Add computed GPU configuration
+    # Add computed GPU configuration and merge with default runtime_env
     ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
-    # Ensure runtime_env and env_vars exist
-    if "runtime_env" not in ray_init_config:
-        ray_init_config["runtime_env"] = {}
-    if not isinstance(ray_init_config["runtime_env"], dict):
-        ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
-    if "env_vars" not in ray_init_config["runtime_env"]:
-        ray_init_config["runtime_env"]["env_vars"] = {}
-    if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
-        ray_init_config["runtime_env"]["env_vars"] = dict(
-            ray_init_config["runtime_env"]["env_vars"]
-        )
+    ray_init_config = merge_ray_runtime_env(ray_init_config)
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
     ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     if ray_managed_externally:

sota-implementations/grpo/grpo-sync.py

Lines changed: 3 additions & 13 deletions

@@ -12,7 +12,7 @@
 
 import hydra
 
-from torchrl import torchrl_logger
+from torchrl import merge_ray_runtime_env, torchrl_logger
 from torchrl.data.llm.history import History
 from torchrl.record.loggers.wandb import WandbLogger
 from torchrl.weight_update.llm import get_model_metadata
@@ -319,19 +319,9 @@ def main(cfg):
         if not k.startswith("_")
     }
 
-    # Add computed GPU configuration
+    # Add computed GPU configuration and merge with default runtime_env
    ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
-    # Ensure runtime_env and env_vars exist
-    if "runtime_env" not in ray_init_config:
-        ray_init_config["runtime_env"] = {}
-    if not isinstance(ray_init_config["runtime_env"], dict):
-        ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"])
-    if "env_vars" not in ray_init_config["runtime_env"]:
-        ray_init_config["runtime_env"]["env_vars"] = {}
-    if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict):
-        ray_init_config["runtime_env"]["env_vars"] = dict(
-            ray_init_config["runtime_env"]["env_vars"]
-        )
+    ray_init_config = merge_ray_runtime_env(ray_init_config)
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
     ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     if ray_managed_externally:

test/llm/test_updaters.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def get_open_port():
 if _has_ray:
     import ray
 
-    from .ray_helpers import (
+    from ray_helpers import (
         WorkerTransformerDoubleBuffer,
         WorkerTransformerNCCL,
         WorkerVLLMDoubleBuffer,

torchrl/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -54,8 +54,10 @@
 from torchrl._utils import (
     auto_unwrap_transformed_env,
     compile_with_warmup,
+    get_ray_default_runtime_env,
     implement_for,
     logger,
+    merge_ray_runtime_env,
     set_auto_unwrap_transformed_env,
     timeit,
 )
@@ -113,7 +115,9 @@ def _inv(self):
 __all__ = [
     "auto_unwrap_transformed_env",
     "compile_with_warmup",
+    "get_ray_default_runtime_env",
     "implement_for",
+    "merge_ray_runtime_env",
     "set_auto_unwrap_transformed_env",
     "timeit",
     "logger",

torchrl/_utils.py

Lines changed: 100 additions & 0 deletions

@@ -962,3 +962,103 @@ def as_remote(cls, remote_config: dict[str, Any] | None = None):
     remote_collector = ray.remote(**remote_config)(cls)
     remote_collector.is_remote = True
     return remote_collector
+
+
+def get_ray_default_runtime_env() -> dict[str, Any]:
+    """Get the default Ray runtime environment configuration for TorchRL.
+
+    This function returns a runtime environment configuration that excludes
+    large directories and files that should not be uploaded to Ray workers.
+    This helps prevent issues with Ray's working_dir size limits (512MB default).
+
+    Returns:
+        dict: A dictionary containing the default runtime_env configuration with
+            excludes for common large directories.
+
+    Examples:
+        >>> import ray
+        >>> from torchrl._utils import get_ray_default_runtime_env
+        >>> ray_init_config = {"num_cpus": 4}
+        >>> ray_init_config["runtime_env"] = get_ray_default_runtime_env()
+        >>> ray.init(**ray_init_config)
+
+    Note:
+        The excludes list includes:
+        - Virtual environments (.venv/, venv/, etc.)
+        - Test files and caches
+        - Documentation builds
+        - Benchmarks
+        - Examples and tutorials
+        - CI/CD configurations
+        - IDE configurations
+
+    """
+    return {
+        "excludes": [
+            ".venv/",
+            "venv/",
+            "env/",
+            "ENV/",
+            "env.bak/",
+            "venv.bak/",
+            "test/",
+            "tests/",
+            "docs/",
+            "benchmarks/",
+            "tutorials/",
+            "examples/",
+            ".github/",
+            ".pytest_cache/",
+            ".mypy_cache/",
+            ".ruff_cache/",
+            "__pycache__/",
+            "*.pyc",
+            "*.pyo",
+            "*.egg-info/",
+            ".idea/",
+            ".vscode/",
+            "dev/",
+            "main/",
+            "*.html",
+        ]
+    }
+
+
+def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]:
+    """Merge user-provided ray_init_config with default runtime_env excludes.
+
+    This function ensures that the default TorchRL runtime_env excludes are applied
+    to prevent large directories from being uploaded to Ray workers, while preserving
+    any user-provided configuration.
+
+    Args:
+        ray_init_config (dict): The ray init configuration dictionary to merge.
+
+    Returns:
+        dict: The merged configuration with default runtime_env excludes applied.
+
+    Examples:
+        >>> from torchrl._utils import merge_ray_runtime_env
+        >>> ray_init_config = {"num_cpus": 4}
+        >>> ray_init_config = merge_ray_runtime_env(ray_init_config)
+        >>> ray.init(**ray_init_config)
+
+    """
+    default_runtime_env = get_ray_default_runtime_env()
+    runtime_env = ray_init_config.setdefault("runtime_env", {})
+
+    if not isinstance(runtime_env, dict):
+        runtime_env = dict(runtime_env)
+        ray_init_config["runtime_env"] = runtime_env
+
+    # Merge excludes lists
+    excludes = runtime_env.get("excludes", [])
+    runtime_env["excludes"] = list(set(default_runtime_env["excludes"] + excludes))
+
+    # Ensure env_vars exists
+    if "env_vars" not in runtime_env:
+        runtime_env["env_vars"] = {}
+    elif not isinstance(runtime_env["env_vars"], dict):
+        runtime_env["env_vars"] = dict(runtime_env["env_vars"])
+
+    return ray_init_config
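
A hypothetical usage sketch (not part of the diff) of how the new helper behaves when the caller already supplies its own runtime_env; the custom exclude "data/" and the env var are made-up illustration values:

# Sketch, assuming a torchrl build containing this commit.
from torchrl import merge_ray_runtime_env

ray_init_config = {
    "num_cpus": 4,
    "runtime_env": {
        "excludes": ["data/"],  # user-specific exclude, kept after the merge
        "env_vars": {"TOKENIZERS_PARALLELISM": "false"},
    },
}
merged = merge_ray_runtime_env(ray_init_config)

# The TorchRL defaults (".venv/", "docs/", ...) are unioned with "data/" via set(),
# so duplicates collapse and user entries survive; env_vars is left untouched.
assert "data/" in merged["runtime_env"]["excludes"]
assert ".venv/" in merged["runtime_env"]["excludes"]
assert merged["runtime_env"]["env_vars"] == {"TOKENIZERS_PARALLELISM": "false"}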
