@@ -61,6 +61,7 @@ def __init__(
6161 max_retries : int = 0 ,
6262 mounts : Optional [List [str ]] = None ,
6363 rdzv_port : int = 29500 ,
64+ rdzv_backend : str = None ,
6465 scheduler_args : Optional [Dict [str , str ]] = None ,
6566 image : Optional [str ] = None ,
6667 ):
@@ -81,6 +82,7 @@ def __init__(
8182 self .max_retries = max_retries
8283 self .mounts : List [str ] = mounts if mounts is not None else []
8384 self .rdzv_port = rdzv_port
85+ self .rdzv_backend = rdzv_backend
8486 self .scheduler_args : Dict [str , str ] = (
8587 scheduler_args if scheduler_args is not None else dict ()
8688 )
@@ -104,6 +106,9 @@ def _dry_run(self, cluster: "Cluster"):
104106 env = self .env ,
105107 max_retries = self .max_retries ,
106108 rdzv_port = self .rdzv_port ,
109+ rdzv_backend = self .rdzv_backend
110+ if self .rdzv_backend is not None
111+ else "static" ,
107112 mounts = self .mounts ,
108113 ),
109114 scheduler = cluster .torchx_scheduler ,
@@ -142,6 +147,9 @@ def _dry_run_no_cluster(self):
142147 env = self .env , # should this still exist?
143148 max_retries = self .max_retries ,
144149 rdzv_port = self .rdzv_port , # should this still exist?
150+ rdzv_backend = self .rdzv_backend
151+ if self .rdzv_backend is not None
152+ else "c10d" ,
145153 mounts = self .mounts ,
146154 image = self .image
147155 if self .image is not None
0 commit comments