1#!/usr/bin/env python3 2# mypy: allow-untyped-defs 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9 10import datetime 11import logging 12from typing import cast, Optional 13 14from torch.distributed import PrefixStore, Store, TCPStore 15from torch.distributed.elastic.rendezvous import ( 16 RendezvousHandler, 17 RendezvousInfo, 18 RendezvousParameters, 19 RendezvousStoreInfo, 20) 21from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint 22 23 24__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] 25 26logger = logging.getLogger(__name__) 27 28_default_timeout_seconds = 600 29 30 31class StaticTCPRendezvous(RendezvousHandler): 32 """ 33 Static rendezvous that is a wrapper around the TCPStore. 34 35 Creates TCPStore based on the input parameters with the 36 listener on the agent with group_rank=0 37 """ 38 39 def __init__( 40 self, 41 master_addr: str, 42 master_port: int, 43 rank: int, 44 world_size: int, 45 run_id: str, 46 timeout: int, 47 ): 48 self.master_addr = master_addr 49 self.master_port = master_port 50 self.rank = rank 51 self.world_size = world_size 52 self.run_id = run_id 53 self.timeout = datetime.timedelta(seconds=timeout) 54 self._store: Optional[Store] = None 55 56 def get_backend(self) -> str: 57 return "static" 58 59 @property 60 def use_agent_store(self) -> bool: 61 return True 62 63 def next_rendezvous(self) -> RendezvousInfo: 64 logger.info("Creating TCPStore as the c10d::Store implementation") 65 is_master = self.rank == 0 66 if not self._store: 67 self._store = TCPStore( # type: ignore[call-arg] 68 self.master_addr, 69 self.master_port, 70 self.world_size, 71 is_master, 72 self.timeout, 73 multi_tenant=True, 74 ) 75 store = PrefixStore(self.run_id, self._store) 76 # TCPStore server instance is used by trainer code 77 bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port) 78 return RendezvousInfo( 79 store, 80 self.rank, 81 self.world_size, 82 bootstrap_store_info, 83 ) 84 85 def is_closed(self): 86 return False 87 88 def set_closed(self): 89 pass 90 91 def num_nodes_waiting(self): 92 return 0 93 94 def get_run_id(self) -> str: 95 return self.run_id 96 97 def shutdown(self) -> bool: 98 return True 99 100 101def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: 102 if "rank" not in params.config: 103 raise ValueError( 104 "rank is absent in RendezvousParameters." 105 "Try add --node-rank to the cmd request" 106 ) 107 endpoint = params.endpoint.strip() 108 if not endpoint: 109 raise ValueError( 110 "endpoint is absent in RendezvousParameters" 111 "Try add --master-port and --master-addr to the cmd request" 112 ) 113 master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1) 114 if master_port == -1: 115 raise ValueError( 116 f"Port is absent in endpoint: {endpoint}. Try launching with --master-port" 117 ) 118 world_size = params.max_nodes 119 rank = cast(int, params.config.get("rank")) 120 run_id = params.run_id 121 if "timeout" in params.config: 122 timeout = int(params.config["timeout"]) 123 else: 124 timeout = _default_timeout_seconds 125 126 return StaticTCPRendezvous( 127 master_addr, master_port, rank, world_size, run_id, timeout 128 ) 129