• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import datetime
11import logging
12from typing import cast, Optional
13
14from torch.distributed import PrefixStore, Store, TCPStore
15from torch.distributed.elastic.rendezvous import (
16    RendezvousHandler,
17    RendezvousInfo,
18    RendezvousParameters,
19    RendezvousStoreInfo,
20)
21from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
22
23
# Names exported by ``from ... import *``: the handler class and its factory.
__all__ = [
    "StaticTCPRendezvous",
    "create_rdzv_handler",
]

# Logger scoped to this module's import path.
logger = logging.getLogger(__name__)

# Default TCPStore timeout in seconds, applied by ``create_rdzv_handler`` when
# ``params.config`` has no ``timeout`` entry.
_default_timeout_seconds = 600
29
30
31class StaticTCPRendezvous(RendezvousHandler):
32    """
33    Static rendezvous that is a wrapper around the TCPStore.
34
35    Creates TCPStore based on the input parameters with the
36    listener on the agent with group_rank=0
37    """
38
39    def __init__(
40        self,
41        master_addr: str,
42        master_port: int,
43        rank: int,
44        world_size: int,
45        run_id: str,
46        timeout: int,
47    ):
48        self.master_addr = master_addr
49        self.master_port = master_port
50        self.rank = rank
51        self.world_size = world_size
52        self.run_id = run_id
53        self.timeout = datetime.timedelta(seconds=timeout)
54        self._store: Optional[Store] = None
55
56    def get_backend(self) -> str:
57        return "static"
58
59    @property
60    def use_agent_store(self) -> bool:
61        return True
62
63    def next_rendezvous(self) -> RendezvousInfo:
64        logger.info("Creating TCPStore as the c10d::Store implementation")
65        is_master = self.rank == 0
66        if not self._store:
67            self._store = TCPStore(  # type: ignore[call-arg]
68                self.master_addr,
69                self.master_port,
70                self.world_size,
71                is_master,
72                self.timeout,
73                multi_tenant=True,
74            )
75        store = PrefixStore(self.run_id, self._store)
76        # TCPStore server instance is used by trainer code
77        bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port)
78        return RendezvousInfo(
79            store,
80            self.rank,
81            self.world_size,
82            bootstrap_store_info,
83        )
84
85    def is_closed(self):
86        return False
87
88    def set_closed(self):
89        pass
90
91    def num_nodes_waiting(self):
92        return 0
93
94    def get_run_id(self) -> str:
95        return self.run_id
96
97    def shutdown(self) -> bool:
98        return True
99
100
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Create a ``StaticTCPRendezvous`` from rendezvous parameters.

    Args:
        params: Rendezvous parameters.
            - ``params.endpoint`` must be non-empty and include a port
              (``host:port``).
            - ``params.config`` must contain ``rank`` and may contain
              ``timeout`` in seconds (defaults to ``_default_timeout_seconds``).
              Both are converted with ``int()``.
            - ``params.max_nodes`` is used as the world size.

    Returns:
        A ``StaticTCPRendezvous`` handler.

    Raises:
        ValueError: If ``rank`` is missing from ``params.config``, the endpoint
            is empty, the endpoint has no port, or ``rank``/``timeout`` is a
            non-numeric string.
    """
    if "rank" not in params.config:
        raise ValueError(
            "rank is absent in RendezvousParameters. "
            "Try add --node-rank to the cmd request"
        )
    endpoint = params.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "endpoint is absent in RendezvousParameters. "
            "Try add --master-port and --master-addr to the cmd request"
        )
    # -1 is passed as the default port, so it signals "no port in endpoint".
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    if master_port == -1:
        raise ValueError(
            f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
        )
    world_size = params.max_nodes
    # Normalize like ``timeout``. A string rank such as "0" would otherwise
    # never equal 0 in ``StaticTCPRendezvous.next_rendezvous``, so no agent
    # would host the TCPStore.
    rank = int(params.config["rank"])
    run_id = params.run_id
    timeout = int(params.config.get("timeout", _default_timeout_seconds))

    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )
129