@@ -76,6 +76,13 @@ def reset(cls):
         cls.NODE_ID = None


+def _get_backend(kwargs: dict):
+    backend = kwargs.pop('backend', None)
+    if backend is None:
+        backend = 'cpu:gloo,cuda:nccl' if dist.is_nccl_available() and torch.cuda.is_available() else 'gloo'
+    return backend
+
+
 def _initialize_via_tcp(
     ip: str,
     port: int,
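Note on the new helper: `_get_backend` pops `'backend'` out of `kwargs`, so callers can forward the remaining kwargs to `dist.init_process_group` without passing the backend twice; the old `kwargs.get` in `_init_process_group_dummy` (removed below) did not have that property. A minimal sketch of the contract; the kwargs dict here is made up for illustration:

```python
# Sketch only: exercising _get_backend as defined in this diff.
kwargs = {'backend': 'gloo', 'timeout': None}   # hypothetical caller kwargs
backend = _get_backend(kwargs)
assert backend == 'gloo'
assert 'backend' not in kwargs                  # popped, not just read

# With no explicit backend, the default depends on the environment:
# 'cpu:gloo,cuda:nccl' if NCCL and CUDA are available, otherwise 'gloo'.
default_backend = _get_backend({})
```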
@@ -95,6 +102,9 @@ def _initialize_via_tcp(
         local_world_size=local_world_size,
         node_id=node_id,
     )
+
+    backend = _get_backend(kwargs)
+
     msg = f'Connecting via {method} and TCPStore:'
     msg += f'\n rank: {_WorkerInfo.RANK}'
     msg += f'\n world size: {_WorkerInfo.WORLD_SIZE}'
@@ -103,6 +113,7 @@ def _initialize_via_tcp(
     msg += f'\n node id: {_WorkerInfo.NODE_ID}'
     msg += f'\n master ip: {ip}'
     msg += f'\n master port: {port}'
+    msg += f'\n backend: {backend}'
     print(msg, flush=True)

     # TODO: Add check that ip == rank0 host
@@ -118,6 +129,7 @@ def _initialize_via_tcp(
         store=store,
         world_size=_WorkerInfo.WORLD_SIZE,
         rank=_WorkerInfo.RANK,
+        backend=backend,
         **kwargs,
     )
     if is_root():
@@ -218,6 +230,7 @@ def _init_process_group_env(**kwargs):
         local_rank=int(os.environ['LOCAL_RANK']),
         local_world_size=int(os.environ['LOCAL_WORLD_SIZE']),
         node_id=int(os.environ['GROUP_RANK']),
+        **kwargs,
     )

@@ -235,11 +248,9 @@ def _init_process_group_dummy(**kwargs):
         local_world_size=1,
         node_id=0,
     )
+    backend = _get_backend(kwargs)

-    backend = kwargs.get('backend', None)
-    if backend is None:
-        backend = 'cpu:gloo,cuda:nccl' if dist.is_nccl_available() and torch.cuda.is_available() else 'gloo'
-    print(f'Initializing dummy process group with a single process via HashStore (backend: {backend})', flush=True)
+    print(f'Initializing dummy process group with a single process via HashStore (backend="{backend}")', flush=True)
     store = dist.HashStore()
     dist.init_process_group(store=store, rank=0, world_size=1, backend=backend, **kwargs)
@@ -281,12 +292,11 @@ def _init_process_group_slurm(port=DEFAULT_PORT, **kwargs):
     )


-def _init_process_group_auto(verbose=True, **kwargs):
+def _init_process_group_auto(**kwargs):
     """
     Tries to initialize torch.distributed in the following order:
     1. If the MASTER_PORT environment variable is set, use environment variable initialization
     2. If srun (slurm) was used to launch this program, use slurm's environment variables
-    2. If MPI is available, use MPI to exchange ip addresses (see init_process_group_MPI)
     3. Otherwise, a dummy process group with a single process is used (no distributed training)
     """

@@ -325,13 +335,13 @@ def init(kind='auto'):
         raise ValueError(f"Invalid kind: {kind}. Must be one of 'auto', 'dummy', 'slurm', 'env'")

     if kind == 'auto':
-        _init_process_group_auto(backend='cpu:gloo,cuda:nccl')
+        _init_process_group_auto()
     elif kind == 'dummy':
-        _init_process_group_dummy(backend='cpu:gloo,cuda:nccl')
+        _init_process_group_dummy()
     elif kind == 'slurm':
-        _init_process_group_slurm(backend='cpu:gloo,cuda:nccl')
+        _init_process_group_slurm()
     elif kind == 'env':
-        _init_process_group_env(backend='cpu:gloo,cuda:nccl')
+        _init_process_group_env()

     atexit.register(deinitialize_torch_distributed, fail_silently=True)

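End to end, the call site stays unchanged; only the backend selection moved into `_get_backend`. Assuming this module is importable as `dist_utils` (a hypothetical name):

```python
from dist_utils import init   # hypothetical module name for the file above

init('auto')   # env -> slurm -> dummy; backend resolved by _get_backend
# ... distributed work ...
# Teardown (deinitialize_torch_distributed) is registered via atexit,
# so no explicit destroy call is needed here.
```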