# Owner(s): ["oncall: r2p"] # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy import os import pickle import socket import threading import time from abc import ABC, abstractmethod from base64 import b64encode from datetime import datetime, timedelta from typing import Callable, cast, Optional, Tuple from unittest import TestCase from unittest.mock import call, MagicMock, Mock, patch, PropertyMock import torch.distributed as dist from torch.distributed import HashStore, Store from torch.distributed.elastic.rendezvous import ( RendezvousClosedError, RendezvousError, RendezvousInfo, RendezvousParameters, RendezvousStateError, RendezvousStoreInfo, RendezvousTimeoutError, ) from torch.distributed.elastic.rendezvous.dynamic_rendezvous import ( _Action, _BackendRendezvousStateHolder, _DistributedRendezvousOpExecutor, _NodeDesc, _NodeDescGenerator, _RendezvousCloseOp, _RendezvousContext, _RendezvousExitOp, _RendezvousJoinOp, _RendezvousKeepAliveOp, _RendezvousState, _RendezvousStateHolder, create_handler, DynamicRendezvousHandler, RendezvousBackend, RendezvousSettings, RendezvousTimeout, Token, ) class CustomAssertMixin: assertDictEqual: Callable def assert_state_equal( self, actual: _RendezvousState, expected: _RendezvousState ) -> None: self.assertDictEqual(vars(actual), vars(expected)) def assert_state_empty(self, actual: _RendezvousState) -> None: self.assertDictEqual(vars(actual), vars(_RendezvousState())) class RendezvousTimeoutTest(TestCase): def test_init_initializes_timeout(self) -> None: timeout = RendezvousTimeout( timedelta(seconds=50), timedelta(seconds=60), timedelta(seconds=70), timedelta(seconds=80), ) self.assertEqual(timeout.join, timedelta(seconds=50)) self.assertEqual(timeout.last_call, timedelta(seconds=60)) self.assertEqual(timeout.close, timedelta(seconds=70)) self.assertEqual(timeout.heartbeat, timedelta(seconds=80)) def test_init_initializes_timeout_if_no_timeout_is_specified(self) -> None: timeout = RendezvousTimeout() self.assertEqual(timeout.join, timedelta(seconds=600)) self.assertEqual(timeout.last_call, timedelta(seconds=30)) self.assertEqual(timeout.close, timedelta(seconds=30)) self.assertEqual(timeout.heartbeat, timedelta(seconds=5)) def test_init_raises_error_if_timeout_is_not_positive(self) -> None: join_timeouts = [timedelta(seconds=0), timedelta(seconds=-1)] for join_timeout in join_timeouts: with self.subTest(join_timeout=join_timeout): with self.assertRaisesRegex( ValueError, rf"^The join timeout \({join_timeout}\) must be positive.$", ): timeout = RendezvousTimeout(join_timeout) class NodeDescTest(TestCase): def test_repr(self) -> None: desc = _NodeDesc("dummy_fqdn", 3, 5) self.assertEqual(repr(desc), "dummy_fqdn_3_5") def test_hash(self) -> None: desc1 = _NodeDesc("dummy_fqdn", 2, 4) desc2 = _NodeDesc("dummy_fqdn", 3, 5) descs = {desc1, desc2} self.assertIn(desc1, descs) self.assertIn(desc2, descs) class NodeDescGeneratorTest(TestCase): def test_generate(self) -> None: desc_generator = _NodeDescGenerator() fqdn = socket.getfqdn() pid = os.getpid() for local_id in range(4): with self.subTest(fqdn=fqdn, pid=pid, local_id=local_id): desc = desc_generator.generate() self.assertEqual(repr(desc), f"{fqdn}_{pid}_{local_id}") class RendezvousStateTest(TestCase): def test_encoded_size_is_within_expected_limit(self) -> None: state = _RendezvousState() state.round = 1 state.complete = True state.deadline = datetime.utcnow() state.closed = True # fmt: off expected_max_sizes = ( ( 5, 2 * (2 ** 10),), # 10 machines <= 2KB # noqa: E201, E241, E262 ( 50, 16 * (2 ** 10),), # 100 machines <= 16KB # noqa: E201, E241, E262 ( 500, 160 * (2 ** 10),), # 1000 machines <= 160KB # noqa: E201, E241, E262 (5000, 1600 * (2 ** 10),), # 10000 machines <= 1.6MB # noqa: E201, E241, E262 ) # fmt: on for num_nodes, max_byte_size in expected_max_sizes: with self.subTest(num_nodes=num_nodes, max_byte_size=max_byte_size): for i in range(num_nodes): node_running = _NodeDesc( f"dummy{i}.dummy1-dummy1-dummy1-dummy1.com", 12345, i ) node_waiting = _NodeDesc( f"dummy{i}.dummy2-dummy2-dummy2-dummy2.com", 67890, i ) state.participants[node_running] = i state.wait_list.add(node_waiting) state.last_heartbeats[node_running] = datetime.utcnow() state.last_heartbeats[node_waiting] = datetime.utcnow() bits = pickle.dumps(state) base64_bits = b64encode(bits) self.assertLessEqual(len(base64_bits), max_byte_size) class FakeRendezvousBackend(RendezvousBackend): _state: Optional[bytes] _token: int def __init__(self) -> None: self._state = None self._token = 0 @property def name(self) -> str: return "fake_backend" def get_state(self) -> Optional[Tuple[bytes, Token]]: if self._token == 0: return None return self._state, self._token # type: ignore[return-value] def set_state( self, state: bytes, token: Optional[Token] = None ) -> Optional[Tuple[bytes, Token, bool]]: if token is None: token = 0 if token == self._token: self._state = state self._token += 1 has_set = True else: has_set = False return self._state, self._token, has_set # type: ignore[return-value] def get_state_internal(self) -> _RendezvousState: return pickle.loads(cast(bytes, self._state)) def set_state_internal(self, state: _RendezvousState) -> None: self._state = pickle.dumps(state) self._token += 1 def corrupt_state(self) -> None: self._state = b"corrupt_state" self._token += 1 class BackendRendezvousStateHolderTest(TestCase, CustomAssertMixin): def setUp(self) -> None: self._backend = FakeRendezvousBackend() mock_get_state = MagicMock(wraps=self._backend.get_state) mock_set_state = MagicMock(wraps=self._backend.set_state) self._mock_backend = Mock() self._mock_backend.get_state = mock_get_state self._mock_backend.set_state = mock_set_state setattr(self._backend, "get_state", mock_get_state) # noqa: B010 setattr(self._backend, "set_state", mock_set_state) # noqa: B010 self._settings = RendezvousSettings( run_id="dummy_run_id", min_nodes=1, max_nodes=1, timeout=RendezvousTimeout(), keep_alive_interval=timedelta(seconds=30), keep_alive_max_attempt=3, ) self._cache_duration = 0 self._now = datetime(2000, 1, 1, hour=0, minute=0) self._datetime_patch = patch( "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime" ) mock_datetime = self._datetime_patch.start() mock_datetime.utcnow.return_value = self._now def tearDown(self) -> None: self._datetime_patch.stop() def _create_state(self) -> _RendezvousState: state = _RendezvousState() state.round = 999 state.complete = True state.deadline = self._now state.closed = True state.participants = { _NodeDesc("dummy1", 1, 1): 0, _NodeDesc("dummy2", 1, 1): 1, _NodeDesc("dummy3", 1, 1): 2, } state.wait_list = { _NodeDesc("dummy4", 1, 1), _NodeDesc("dummy5", 1, 1), } state.last_heartbeats = { _NodeDesc("dummy1", 1, 1): self._now, _NodeDesc("dummy2", 1, 1): self._now - timedelta(seconds=15), _NodeDesc("dummy3", 1, 1): self._now - timedelta(seconds=30), _NodeDesc("dummy4", 1, 1): self._now - timedelta(seconds=60), _NodeDesc("dummy5", 1, 1): self._now - timedelta(seconds=90), } return state def _create_state_holder(self) -> _BackendRendezvousStateHolder: return _BackendRendezvousStateHolder( self._backend, self._settings, self._cache_duration ) def test_init_initializes_state_holder(self) -> None: state_holder = self._create_state_holder() self.assert_state_empty(state_holder.state) self._mock_backend.assert_not_called() def test_sync_gets_empty_state_if_backend_state_does_not_exist(self) -> None: state_holder = self._create_state_holder() has_set = state_holder.sync() self.assertIsNone(has_set) self.assert_state_empty(state_holder.state) self.assertEqual(self._mock_backend.get_state.call_count, 1) self.assertEqual(self._mock_backend.set_state.call_count, 0) def test_sync_gets_backend_state_if_local_state_is_clean(self) -> None: state_holder = self._create_state_holder() expected_state = self._create_state() for attempt in range(1, 4): with self.subTest(attempt=attempt): expected_state.round = attempt self._backend.set_state_internal(expected_state) has_set = state_holder.sync() self.assertIsNone(has_set) self.assert_state_equal(state_holder.state, expected_state) self.assertEqual(self._mock_backend.get_state.call_count, 1) self.assertEqual(self._mock_backend.set_state.call_count, 0) self._mock_backend.reset_mock() def test_sync_gets_backend_state_if_local_state_is_old_and_dirty(self) -> None: state_holder = self._create_state_holder() expected_state = self._create_state() for attempt in range(1, 4): with self.subTest(attempt=attempt): self._backend.set_state_internal(expected_state) # Increment token. state_holder.state.round = attempt state_holder.mark_dirty() has_set = state_holder.sync() self.assertFalse(has_set) self.assert_state_equal(state_holder.state, expected_state) self.assertEqual(self._mock_backend.get_state.call_count, 0) self.assertEqual(self._mock_backend.set_state.call_count, 1) self._mock_backend.reset_mock() def test_sync_sets_backend_state_if_local_state_is_new_and_dirty(self) -> None: state_holder = self._create_state_holder() for attempt in range(1, 4): with self.subTest(attempt=attempt): state_holder.state.round = attempt state_holder.mark_dirty() has_set = state_holder.sync() self.assertTrue(has_set) expected_state = self._backend.get_state_internal() self.assert_state_equal(state_holder.state, expected_state) self.assertEqual(self._mock_backend.get_state.call_count, 0) self.assertEqual(self._mock_backend.set_state.call_count, 1) self._mock_backend.reset_mock() def test_sync_uses_cached_state_if_cache_duration_is_specified(self) -> None: state = self._create_state() self._backend.set_state_internal(state) with patch( "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time" ) as mock_time: for cache_duration in [1, 5, 10]: with self.subTest(cache_duration=cache_duration): self._cache_duration = cache_duration state_holder = self._create_state_holder() mock_time.monotonic.return_value = 5 state_holder.sync() has_set = state_holder.sync() self.assertIsNone(has_set) self.assertEqual(self._mock_backend.get_state.call_count, 1) self.assertEqual(self._mock_backend.set_state.call_count, 0) mock_time.monotonic.return_value = 5 + self._cache_duration state_holder.sync() has_set = state_holder.sync() self.assertIsNone(has_set) self.assertEqual(self._mock_backend.get_state.call_count, 1) self.assertEqual(self._mock_backend.set_state.call_count, 0) self._mock_backend.get_state.reset_mock() def test_sync_gets_backend_state_if_cached_state_has_expired(self) -> None: state = self._create_state() self._backend.set_state_internal(state) with patch( "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time" ) as mock_time: self._cache_duration = 1 state_holder = self._create_state_holder() mock_time.monotonic.return_value = 5 state_holder.sync() has_set = state_holder.sync() self.assertIsNone(has_set) self.assertEqual(self._mock_backend.get_state.call_count, 1) self.assertEqual(self._mock_backend.set_state.call_count, 0) mock_time.monotonic.return_value = 5 + self._cache_duration + 0.01 state_holder.sync() has_set = state_holder.sync() self.assertIsNone(has_set) self.assertEqual(self._mock_backend.get_state.call_count, 2) self.assertEqual(self._mock_backend.set_state.call_count, 0) def test_sync_sanitizes_state(self) -> None: state = self._create_state() expected_state = copy.deepcopy(state) dead_node1 = _NodeDesc("dead1", 1, 1) dead_node2 = _NodeDesc("dead2", 1, 1) dead_node3 = _NodeDesc("dead3", 1, 1) dead_node4 = _NodeDesc("dead4", 1, 1) dead_node5 = _NodeDesc("dead5", 1, 1) state.last_heartbeats[dead_node1] = self._now - timedelta(seconds=91) state.last_heartbeats[dead_node2] = self._now - timedelta(seconds=100) state.last_heartbeats[dead_node3] = self._now - timedelta(seconds=110) state.last_heartbeats[dead_node4] = self._now - timedelta(seconds=120) state.last_heartbeats[dead_node5] = self._now - timedelta(seconds=130) state.participants[dead_node1] = 0 state.participants[dead_node2] = 0 state.participants[dead_node3] = 0 state.wait_list.add(dead_node4) state.wait_list.add(dead_node5) self._backend.set_state_internal(state) state_holder = self._create_state_holder() state_holder.sync() self.assert_state_equal(state_holder.state, expected_state) def test_sync_sanitizes_state_if_no_participants_is_left(self) -> None: state = self._create_state() expected_state = copy.deepcopy(state) for node in state.last_heartbeats: state.last_heartbeats[node] = self._now - timedelta(seconds=100) expected_state.complete = False expected_state.round = 1000 expected_state.participants = {} expected_state.wait_list = set() expected_state.last_heartbeats = {} self._backend.set_state_internal(state) state_holder = self._create_state_holder() state_holder.sync() self.assert_state_equal(state_holder.state, expected_state) def test_sync_raises_error_if_backend_state_is_corrupt(self) -> None: self._backend.corrupt_state() state_holder = self._create_state_holder() with self.assertRaisesRegex( RendezvousStateError, r"^The rendezvous state is corrupt. See inner exception for details.$", ): state_holder.sync() class FakeRendezvousStateHolder(_RendezvousStateHolder): _state: _RendezvousState _dirty: Optional[bool] def __init__(self) -> None: self._state = _RendezvousState() self._dirty = None @property def state(self) -> _RendezvousState: return self._state @state.setter def state(self, value) -> None: self._state = value def sync(self) -> Optional[bool]: self._dirty, dirty = None, self._dirty return dirty def mark_dirty(self) -> None: self._dirty = True class DistributedRendezvousOpExecutorTest(TestCase, CustomAssertMixin): def setUp(self) -> None: self._node = _NodeDesc("this_node", 1, 1) self._state_holder = FakeRendezvousStateHolder() mock_sync = MagicMock(wraps=self._state_holder.sync) mock_mark = MagicMock(wraps=self._state_holder.mark_dirty) self._mock_state_holder = Mock() self._mock_state_holder.sync = mock_sync self._mock_state_holder.mark = mock_mark setattr(self._state_holder, "sync", mock_sync) # noqa: B010 setattr(self._state_holder, "mark_dirty", mock_mark) # noqa: B010 self._state = self._state_holder.state self._min_nodes = 1 self._max_nodes = 1 self._timeout = RendezvousTimeout() self._now = datetime(2000, 1, 1, hour=0, minute=0) self._datetime_patch = patch( "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime" ) mock_datetime = self._datetime_patch.start() mock_datetime.utcnow.return_value = self._now def tearDown(self) -> None: self._datetime_patch.stop() def _create_settings(self) -> RendezvousSettings: return RendezvousSettings( run_id="dummy_run_id", min_nodes=self._min_nodes, max_nodes=self._max_nodes, timeout=self._timeout, keep_alive_interval=timedelta(seconds=30), keep_alive_max_attempt=3, ) def _create_op_executor( self, settings: Optional[RendezvousSettings] = None ) -> _DistributedRendezvousOpExecutor: self._state_holder.state = self._state if settings is None: settings = self._create_settings() return _DistributedRendezvousOpExecutor( self._node, self._state_holder, settings ) def _run_action(self, action: _Action) -> None: op_executor = self._create_op_executor() op = MagicMock(side_effect=[action, _Action.FINISH]) op_executor.run(op, deadline=1) def _assert_action(self, action: _Action, expected_state: _RendezvousState) -> None: self._run_action(action) self.assert_state_equal(self._state, expected_state) self.assertListEqual( self._mock_state_holder.mock_calls, [call.sync(), call.mark(), call.sync()] ) def test_run_passes_expected_context_and_deadline_to_state_handler(self) -> None: settings = self._create_settings() op_executor = self._create_op_executor(settings) op = MagicMock(return_value=_Action.FINISH) op_executor.run(op, deadline=3) ctx, deadline = op.call_args[0] # args self.assertIs(ctx.node, self._node) self.assertIs(ctx.state, self._state) self.assertIs(ctx.settings, settings) self.assertEqual(deadline, 3) def test_run_keeps_alive(self) -> None: expected_state = _RendezvousState() expected_state.last_heartbeats[self._node] = self._now self._assert_action(_Action.KEEP_ALIVE, expected_state) def test_run_adds_to_participants(self) -> None: expected_state = _RendezvousState() expected_state.participants[self._node] = 0 expected_state.last_heartbeats[self._node] = self._now self._min_nodes = 2 self._max_nodes = 2 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) def test_run_adds_to_participants_if_node_was_in_waitlist(self) -> None: self._state.wait_list.add(self._node) expected_state = _RendezvousState() expected_state.participants[self._node] = 0 expected_state.last_heartbeats[self._node] = self._now self._min_nodes = 2 self._max_nodes = 2 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) def _add_participants( self, num_participants: int, state: _RendezvousState, ranked: bool = False ) -> None: for i in range(num_participants): if ranked: node = _NodeDesc(f"dummy{i}", 1, 1) rank = i else: node = _NodeDesc( f"dummy{num_participants - i - 1}", 1, 1 ) # Add in reverse. rank = 0 state.participants[node] = rank state.last_heartbeats[node] = self._now def test_run_adds_to_participants_and_starts_last_call_if_min_nodes_is_reached( self, ) -> None: for num_participants in range(3): self._state = _RendezvousState() self._add_participants(num_participants, self._state) self._state.wait_list.add(self._node) expected_state = _RendezvousState() self._add_participants(num_participants, expected_state) expected_state.participants[self._node] = 0 expected_state.last_heartbeats[self._node] = self._now expected_state.deadline = self._now + self._timeout.last_call with self.subTest(num_participants=num_participants): self._min_nodes = num_participants + 1 self._max_nodes = num_participants + 2 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) self._mock_state_holder.reset_mock() def test_run_adds_to_participants_and_completes_rendezvous_if_max_nodes_is_reached( self, ) -> None: for min_max_nodes_equal in [False, True]: for num_participants in range(3): rank = num_participants self._state = _RendezvousState() self._add_participants(num_participants, self._state) self._state.wait_list.add(self._node) self._state.deadline = self._now + self._timeout.last_call expected_state = _RendezvousState() self._add_participants(num_participants, expected_state, ranked=True) expected_state.participants[self._node] = rank expected_state.last_heartbeats[self._node] = self._now expected_state.complete = True expected_state.deadline = None with self.subTest(num_participants=num_participants): self._min_nodes = num_participants + 1 if min_max_nodes_equal else 0 self._max_nodes = num_participants + 1 self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state) self._mock_state_holder.reset_mock() def test_run_adds_to_waitlist(self) -> None: expected_state = _RendezvousState() expected_state.wait_list.add(self._node) expected_state.last_heartbeats[self._node] = self._now self._assert_action(_Action.ADD_TO_WAIT_LIST, expected_state) def test_run_removes_from_participants(self) -> None: for complete, last_call_deadline in [(False, self._now), (True, None)]: self._state = _RendezvousState() self._add_participants(2, self._state) self._state.participants[self._node] = 0 self._state.last_heartbeats[self._node] = self._now self._state.complete = complete self._state.deadline = last_call_deadline self._state.round = 1 expected_state = _RendezvousState() self._add_participants(2, expected_state) expected_state.complete = complete expected_state.deadline = last_call_deadline expected_state.round = 1 with self.subTest(complete=complete): self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state) self._mock_state_holder.reset_mock() def test_run_removes_from_participants_and_moves_to_next_round_if_node_is_last_participant( self, ) -> None: self._state.participants[self._node] = 0 self._state.last_heartbeats[self._node] = self._now self._state.complete = True self._state.round = 1 expected_state = _RendezvousState() expected_state.complete = False expected_state.round = 2 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state) def test_run_removes_from_participants_and_clears_last_call_if_rendezvous_has_less_than_min_nodes( self, ) -> None: self._add_participants(2, self._state) self._state.participants[self._node] = 0 self._state.last_heartbeats[self._node] = self._now self._state.deadline = self._now expected_state = _RendezvousState() self._add_participants(2, expected_state) self._min_nodes = 3 self._max_nodes = 4 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected_state) def test_run_removes_from_waitlist(self) -> None: self._state.wait_list.add(self._node) self._state.last_heartbeats[self._node] = self._now expected_state = _RendezvousState() self._assert_action(_Action.REMOVE_FROM_WAIT_LIST, expected_state) def test_run_marks_rendezvous_closed(self) -> None: expected_state = _RendezvousState() expected_state.closed = True self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED, expected_state) def test_run_raises_error_if_rendezvous_is_closed(self) -> None: with self.assertRaises(RendezvousClosedError): self._run_action(_Action.ERROR_CLOSED) self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()]) def test_run_raises_error_if_operation_timed_out(self) -> None: with self.assertRaises(RendezvousTimeoutError): self._run_action(_Action.ERROR_TIMEOUT) self.assertListEqual(self._mock_state_holder.mock_calls, [call.sync()]) def test_run_delays_execution_if_sync_requested(self) -> None: with patch( "torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay" ) as mock_delay: self._run_action(_Action.SYNC) mock_delay.assert_called_once_with(seconds=1) self.assertListEqual( self._mock_state_holder.mock_calls, [call.sync(), call.sync()] ) class AbstractTestRendezvousOp(ABC): assertEqual: Callable def setUp(self) -> None: self._node = _NodeDesc("this_node", 1, 1) self._min_nodes = 1 self._max_nodes = 2 self._keep_alive_interval = timedelta(seconds=30) self._state = _RendezvousState() self._state.participants[_NodeDesc("dummy1", 1, 1)] = 1 self._now = datetime(2000, 1, 1, hour=0, minute=0) self._deadline = 10 self._datetime_patch = patch( "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime" ) mock_datetime = self._datetime_patch.start() mock_datetime.utcnow.return_value = self._now self._time_patch = patch( "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time" ) mock_time = self._time_patch.start() mock_time.monotonic.return_value = self._deadline def tearDown(self) -> None: self._time_patch.stop() self._datetime_patch.stop() def _get_next_action(self) -> _Action: op = self._create_op() settings = RendezvousSettings( run_id="dummy_run_id", min_nodes=self._min_nodes, max_nodes=self._max_nodes, timeout=RendezvousTimeout(), keep_alive_interval=self._keep_alive_interval, keep_alive_max_attempt=3, ) ctx = _RendezvousContext(self._node, self._state, settings) return op(ctx, self._deadline) @abstractmethod def _create_op(self) -> Callable: pass def _assert_action(self, expected_action) -> None: action = self._get_next_action() self.assertEqual(action, expected_action) class TestRendezvousExitOp(AbstractTestRendezvousOp, TestCase): def _create_op(self) -> Callable: return _RendezvousExitOp() def test_removes_from_participants_if_node_is_participant(self) -> None: self._state.participants[self._node] = 1 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS) def test_raises_timeout_if_deadline_exceeded(self) -> None: self._deadline = 0 self._state.participants[self._node] = 1 self._assert_action(_Action.ERROR_TIMEOUT) def test_finishes_if_node_is_not_participant(self) -> None: self._assert_action(_Action.FINISH) class TestRendezvousJoinOp(AbstractTestRendezvousOp, TestCase): def _create_op(self) -> Callable: return _RendezvousJoinOp() def test_raises_closed_if_rendezvous_is_closed(self) -> None: self._state.closed = True self._assert_action(_Action.ERROR_CLOSED) def test_finishes_if_rendezvous_is_complete_and_node_is_participant(self) -> None: self._state.participants[self._node] = 0 self._state.complete = True self._assert_action(_Action.FINISH) def _assert_waits_rendezvous_completion(self) -> None: keep_alive_time = self._now - self._keep_alive_interval for delta, expected_action in [ (timedelta(seconds=0), _Action.KEEP_ALIVE), (timedelta(seconds=1), _Action.SYNC), ]: self._state.last_heartbeats[self._node] = keep_alive_time + delta self._assert_action(expected_action) def test_treat_as_redundancy_for_next_rendezvous_if_rendezvous_is_complete( self, ) -> None: self._max_nodes = 1 self._state.complete = True self._assert_action(_Action.ADD_TO_REDUNDANCY_LIST) def test_waits_next_round_if_rendezvous_is_complete_and_node_is_redundant( self, ) -> None: self._state.redundancy_list.add(self._node) self._max_nodes = 1 self._state.complete = True self._assert_waits_rendezvous_completion() def test_remove_from_rednundancy_list(self) -> None: self._state.redundancy_list.add(self._node) self._max_nodes = 2 self._state.complete = True self._assert_action(_Action.REMOVE_FROM_REDUNDANCY_LIST) def test_waits_next_round_if_rendezvous_is_complete_and_node_is_in_wait_list( self, ) -> None: self._state.wait_list.add(self._node) self._state.complete = True self._assert_waits_rendezvous_completion() def test_adds_to_wait_list_if_rendezvous_is_complete_and_num_nodes_is_less_than_max_nodes( self, ) -> None: self._state.complete = True self._assert_action(_Action.ADD_TO_WAIT_LIST) def test_waits_rendezvous_to_complete_if_node_is_participant(self) -> None: self._max_nodes = 3 self._state.participants[self._node] = 0 self._state.deadline = self._now self._assert_waits_rendezvous_completion() def test_marks_rendezvous_complete_if_node_is_participant_and_last_call_deadline_exceeded( self, ) -> None: self._max_nodes = 3 self._state.participants[self._node] = 0 self._state.deadline = self._now - timedelta(seconds=1) self._assert_action(_Action.MARK_RENDEZVOUS_COMPLETE) def test_adds_to_participants(self) -> None: self._assert_action(_Action.ADD_TO_PARTICIPANTS) def test_raises_timeout_if_deadline_exceeded(self) -> None: self._deadline = 0 self._assert_action(_Action.ERROR_TIMEOUT) def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_participant( self, ) -> None: self._deadline = 0 self._state.participants[self._node] = 0 self._assert_action(_Action.ERROR_TIMEOUT) def test_raises_timeout_if_rollback_deadline_exceeded_and_node_is_in_wait_list( self, ) -> None: self._deadline = 0 self._state.wait_list.add(self._node) self._assert_action(_Action.ERROR_TIMEOUT) def test_removes_from_participants_if_timed_out_but_rollback_deadline_is_not_reached( self, ) -> None: self._deadline = 5 self._state.participants[self._node] = 0 self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS) def test_removes_from_wait_list_if_timed_out_but_rollback_deadline_is_not_reached( self, ) -> None: self._deadline = 5 self._state.wait_list.add(self._node) self._assert_action(_Action.REMOVE_FROM_WAIT_LIST) def test_no_timeout_for_redundant_node(self) -> None: self._max_nodes = 1 self._deadline = 0 self._state.complete = True self._state.redundancy_list.add(self._node) self._assert_action(_Action.SYNC) def test_keep_alive_for_redundant_node(self) -> None: self._deadline = 0 self._max_nodes = 1 self._state.complete = True self._state.redundancy_list.add(self._node) keep_alive_time = self._now - self._keep_alive_interval self._state.last_heartbeats[self._node] = keep_alive_time self._assert_action(_Action.KEEP_ALIVE) class TestRendezvousCloseOp(AbstractTestRendezvousOp, TestCase): def _create_op(self) -> Callable: return _RendezvousCloseOp() def test_finishes_if_rendezvous_is_closed(self) -> None: self._state.closed = True self._assert_action(_Action.FINISH) def test_raises_timeout_if_deadline_exceeded(self) -> None: self._deadline = 0 self._assert_action(_Action.ERROR_TIMEOUT) def test_marks_rendezvous_closed(self) -> None: self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED) class TestRendezvousKeepAliveOp(AbstractTestRendezvousOp, TestCase): def _create_op(self) -> Callable: return _RendezvousKeepAliveOp() def test_updates_keep_alive_if_needed(self) -> None: keep_alive_time = self._now - self._keep_alive_interval for delta in [timedelta(seconds=0), timedelta(seconds=-1)]: with self.subTest(delta=delta): self._state.last_heartbeats[self._node] = keep_alive_time + delta self._assert_action(_Action.KEEP_ALIVE) def test_raises_timeout_if_deadlined_exceeded(self) -> None: self._deadline = 0 self._state.last_heartbeats[self._node] = self._now - self._keep_alive_interval self._assert_action(_Action.ERROR_TIMEOUT) def test_finishes_if_no_keep_alive_update_is_needed(self) -> None: delta = timedelta(seconds=1) self._state.last_heartbeats[self._node] = ( self._now - self._keep_alive_interval + delta ) self._assert_action(_Action.FINISH) class DummyStore(Store): pass class DynamicRendezvousHandlerTest(TestCase): def setUp(self) -> None: self._node = _NodeDesc("this_node", 1, 1) self._min_nodes = 1 self._max_nodes = 1 self._join_timeout: Optional[timedelta] = None self._close_timeout: Optional[timedelta] = None self._heartbeat_timeout: Optional[timedelta] = None self._keep_alive_interval = timedelta(seconds=30) self._store = DummyStore() self._mock_store_get = MagicMock(return_value=b"123") self._mock_store_set = MagicMock() setattr(self._store, "get", self._mock_store_get) # noqa: B010 setattr(self._store, "set", self._mock_store_set) # noqa: B010 self._state_holder = FakeRendezvousStateHolder() self._mock_sync = MagicMock(wraps=self._state_holder.sync) setattr(self._state_holder, "sync", self._mock_sync) # noqa: B010 self._state = self._state_holder.state self._tcp_store_mock = DummyStore() patcher = patch.object( DynamicRendezvousHandler, "_create_tcp_store_server", return_value=self._tcp_store_mock, ) patcher.start() self.addCleanup(patcher.stop) def _create_handler(self) -> DynamicRendezvousHandler: settings = RendezvousSettings( run_id="dummy_run_id", min_nodes=self._min_nodes, max_nodes=self._max_nodes, timeout=RendezvousTimeout( join=self._join_timeout, close=self._close_timeout, heartbeat=self._heartbeat_timeout, ), keep_alive_interval=self._keep_alive_interval, keep_alive_max_attempt=3, ) self._state_holder.state = self._state return DynamicRendezvousHandler( self._node, settings, "dummy_backend", self._store, self._state_holder ) def test_share_store_creates_tcp_store(self): handler = self._create_handler() shared_store_info = RendezvousStoreInfo("host", 54321) with patch.object(RendezvousStoreInfo, "build", return_value=shared_store_info): rdzv_info = handler.next_rendezvous() self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host") self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321) self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock) rdzv_info = handler.next_rendezvous() self.assertEqual(handler._shared_tcp_store_server, self._tcp_store_mock) def test_share_store_when_tcp_store(self): handler = self._create_handler() with patch.object(dist, "PrefixStore", new=Mock): handler._store = Mock(spec=dist.TCPStore) type(handler._store).host = PropertyMock(return_value="host") type(handler._store).port = PropertyMock(return_value=54321) rdzv_info = handler.next_rendezvous() self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host") self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321) self.assertEqual(handler._shared_tcp_store_server, handler._store) rdzv_info = handler.next_rendezvous() self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "host") self.assertEqual(rdzv_info.bootstrap_store_info.master_port, 54321) self.assertEqual(handler._shared_tcp_store_server, handler._store) @patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous._delay") def test_next_rendezvous_skews_the_first_join_attempt(self, mock_delay) -> None: for round, expected_call_count in [(0, True), (1, False)]: with self.subTest(round=round): self._state.round = round handler = self._create_handler() handler.next_rendezvous() self.assertEqual(mock_delay.call_count, expected_call_count) mock_delay.reset_mock() def test_next_rendezvous_returns_expected_value(self) -> None: self._state.participants[_NodeDesc("dummy1", 1, 1)] = 0 self._state.participants[_NodeDesc("dummy2", 1, 1)] = 0 self._max_nodes = 3 handler = self._create_handler() rdzv_info = handler.next_rendezvous() self.assertEqual(rdzv_info.rank, 2) self.assertEqual(rdzv_info.world_size, 3) _ = rdzv_info.store.get("dummy_key") self._mock_store_get.assert_called_with( "torch.rendezvous.dummy_run_id.0/dummy_key" ) def test_next_rendezvous_respects_the_requested_timeout(self) -> None: self._mock_sync.side_effect = lambda: time.sleep(0.3) self._join_timeout = timedelta(seconds=0.2) handler = self._create_handler() with self.assertRaises(RendezvousTimeoutError): handler.next_rendezvous() def test_next_rendezvous_moves_to_next_round_if_called_repeatedly(self) -> None: handler = self._create_handler() for i in range(4): handler.next_rendezvous() self.assertEqual(self._state.round, i) def test_is_closed_returns_expected_value(self) -> None: for closed in [False, True]: with self.subTest(closed=closed): self._state.closed = closed handler = self._create_handler() self.assertEqual(handler.is_closed(), closed) self._mock_sync.assert_called_once() self._mock_sync.reset_mock() @patch("torch.distributed.elastic.events.record_rdzv_event") def test_is_closed_records_and_raises_exceptions(self, record_mock) -> None: self._mock_sync.side_effect = RendezvousError("test error") handler = self._create_handler() with self.assertRaises(RendezvousError): handler.is_closed() record_mock.assert_called_once() def test_set_closed_closes_rendezvous(self) -> None: handler = self._create_handler() handler.set_closed() self.assertTrue(self._state.closed) def test_set_closed_respects_the_requested_timeout(self) -> None: self._mock_sync.side_effect = lambda: time.sleep(0.3) self._close_timeout = timedelta(seconds=0.2) handler = self._create_handler() with self.assertRaises(RendezvousTimeoutError): handler.set_closed() def test_set_closed_can_be_called_multiple_times(self) -> None: handler = self._create_handler() handler.set_closed() handler.set_closed() self.assertTrue(self._state.closed) @patch("torch.distributed.elastic.events.record_rdzv_event") def test_set_closed_records_and_raises_exceptions(self, record_mock) -> None: with patch.object(DynamicRendezvousHandler, "_close") as close_mock: close_mock.side_effect = RendezvousError("test error") handler = self._create_handler() with self.assertRaises(RendezvousError): handler.set_closed() record_mock.assert_called_once() def test_num_nodes_waiting_returns_expected_value(self) -> None: self._state.wait_list.add(_NodeDesc("dummy1", 1, 1)) self._state.wait_list.add(_NodeDesc("dummy2", 1, 1)) handler = self._create_handler() self.assertEqual(handler.num_nodes_waiting(), 2) self._mock_sync.assert_called_once() @patch("torch.distributed.elastic.events.record_rdzv_event") def test_num_nodes_waiting_records_and_raises_exceptions(self, record_mock) -> None: self._mock_sync.side_effect = RendezvousError("test error") handler = self._create_handler() with self.assertRaises(RendezvousError): handler.num_nodes_waiting() record_mock.assert_called_once() def test_shutdown_closes_rendezvous_and_returns_true(self) -> None: handler = self._create_handler() result = handler.shutdown() self.assertTrue(result) self.assertTrue(self._state.closed) def test_shutdown_returns_false_if_rendezvous_cannot_be_closed(self) -> None: self._mock_sync.side_effect = [RendezvousError] handler = self._create_handler() result = handler.shutdown() self.assertFalse(result) def test_shutdown_can_be_called_multiple_times(self) -> None: handler = self._create_handler() handler.shutdown() handler.shutdown() self.assertTrue(self._state.closed) @patch("torch.distributed.elastic.events.record_rdzv_event") def test_shutdown_records_and_raises_exceptions(self, record_mock) -> None: with patch.object(DynamicRendezvousHandler, "_close") as close_mock: close_mock.side_effect = RuntimeError("test error") handler = self._create_handler() with self.assertRaises(RuntimeError): handler.shutdown() record_mock.assert_called_once() @patch("torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime") def test_keep_alive_updates_last_heartbeat(self, mock_datetime) -> None: now = datetime(2000, 1, 1, hour=0, minute=0) mock_datetime.utcnow.return_value = now self._state.last_heartbeats[self._node] = now - (self._keep_alive_interval * 2) handler = self._create_handler() handler._keep_alive() self.assertEqual(self._state.last_heartbeats[self._node], now) def _assert_keep_alive_swallows_rendezvous_errors(self) -> None: last_heartbeat_time = datetime.utcnow() - (self._keep_alive_interval * 2) self._state.last_heartbeats[self._node] = last_heartbeat_time handler = self._create_handler() handler._keep_alive() self.assertEqual(self._state.last_heartbeats[self._node], last_heartbeat_time) def test_keep_alive_swallows_rendezvous_errors(self) -> None: self._mock_sync.side_effect = [RendezvousError] self._assert_keep_alive_swallows_rendezvous_errors() def test_keep_alive_respects_the_requested_timeout(self) -> None: self._mock_sync.side_effect = lambda: time.sleep(0.3) self._heartbeat_timeout = timedelta(seconds=0.2) self._assert_keep_alive_swallows_rendezvous_errors() def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_shutdown( self, ) -> None: self._node = _NodeDesc("this_node", 1, 2) name = "RendezvousKeepAliveTimer_2" handler = self._create_handler() self.assertTrue(all(t.name != name for t in threading.enumerate())) handler.next_rendezvous() self.assertTrue(any(t.name == name for t in threading.enumerate())) handler.shutdown() self.assertTrue(all(t.name != name for t in threading.enumerate())) def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_finalizer( self, ) -> None: self._node = _NodeDesc("this_node", 1, 3) name = "RendezvousKeepAliveTimer_3" handler = self._create_handler() self.assertTrue(all(t.name != name for t in threading.enumerate())) handler.next_rendezvous() self.assertTrue(any(t.name == name for t in threading.enumerate())) del handler self.assertTrue(all(t.name != name for t in threading.enumerate())) class DummyRendezvousBackend(RendezvousBackend): @property def name(self): return "dummy_backend" def get_state(self): return None def set_state(self, state, token): return None class DynamicRendezvousHandlerFromBackendTest(TestCase): def setUp(self) -> None: self._run_id = "dummy_run_id" self._store = DummyStore() self._backend = DummyRendezvousBackend() self._min_nodes = 3 self._max_nodes = 6 self._timeout: Optional[RendezvousTimeout] = RendezvousTimeout() def _create_handler(self) -> DynamicRendezvousHandler: return DynamicRendezvousHandler.from_backend( run_id=self._run_id, store=self._store, backend=self._backend, min_nodes=self._min_nodes, max_nodes=self._max_nodes, timeout=self._timeout, ) def test_init_initializes_handler(self) -> None: handler = self._create_handler() self.assertEqual(handler.get_backend(), self._backend.name) self.assertEqual(handler.get_run_id(), self._run_id) self.assertEqual(handler.settings.run_id, self._run_id) self.assertEqual(handler.settings.min_nodes, self._min_nodes) self.assertEqual(handler.settings.max_nodes, self._max_nodes) if self._timeout is None: self.assertIsNotNone(handler.settings.timeout) else: self.assertIs(handler.settings.timeout, self._timeout) def test_init_initializes_handler_if_timeout_is_not_specified(self) -> None: self._timeout = None self.test_init_initializes_handler() def test_init_initializes_handler_if_min_and_max_nodes_are_equal(self) -> None: self._min_nodes = 3 self._max_nodes = 3 self.test_init_initializes_handler() def test_init_raises_error_if_min_nodes_is_not_positive(self) -> None: for num in [0, -10]: with self.subTest(min_nodes=num): self._min_nodes = num with self.assertRaisesRegex( ValueError, rf"^The minimum number of nodes \({num}\) must be greater than zero.$", ): self._create_handler() def test_init_raises_error_if_max_nodes_is_less_than_min(self) -> None: self._min_nodes = 3 self._max_nodes = 2 with self.assertRaisesRegex( ValueError, rf"^The maximum number of nodes \({self._max_nodes}\) must be greater than or equal to " "the minimum number of nodes " rf"\({self._min_nodes}\).$", ): self._create_handler() class CreateHandlerTest(TestCase): def setUp(self) -> None: self._store = DummyStore() self._backend = DummyRendezvousBackend() self._params = RendezvousParameters( backend=self._backend.name, endpoint="dummy_endpoint", run_id="dummy_run_id", min_nodes=3, max_nodes=6, join_timeout="50", last_call_timeout="60", close_timeout="70", ) self._expected_timeout = RendezvousTimeout( timedelta(seconds=50), timedelta(seconds=60), timedelta(seconds=70) ) def test_create_handler_returns_handler(self) -> None: handler = create_handler(self._store, self._backend, self._params) self.assertEqual(handler.get_backend(), self._backend.name) self.assertEqual(handler.get_run_id(), self._params.run_id) self.assertEqual(handler.settings.min_nodes, self._params.min_nodes) self.assertEqual(handler.settings.max_nodes, self._params.max_nodes) self.assertEqual(handler.settings.timeout.join, self._expected_timeout.join) self.assertEqual( handler.settings.timeout.last_call, self._expected_timeout.last_call ) self.assertEqual(handler.settings.timeout.close, self._expected_timeout.close) def test_create_handler_returns_handler_if_timeout_is_not_specified(self) -> None: del self._params.config["join_timeout"] del self._params.config["last_call_timeout"] del self._params.config["close_timeout"] self._expected_timeout = RendezvousTimeout() self.test_create_handler_returns_handler() @patch("torch.distributed.elastic.events.record_rdzv_event") def test_create_handler_records_and_raises_exceptions(self, record_mock) -> None: with patch.object(DynamicRendezvousHandler, "from_backend") as from_mock: from_mock.side_effect = RendezvousError("test error") with self.assertRaises(RendezvousError): create_handler(self._store, self._backend, self._params) record_mock.assert_called_once() def test_create_handler_rdzv_local_addr(self) -> None: params = RendezvousParameters( backend=self._backend.name, endpoint="dummy_endpoint", run_id="dummy_run_id", min_nodes=1, max_nodes=1, join_timeout="50", last_call_timeout="60", close_timeout="70", local_addr="127.0.0.2", ) store = HashStore() handler = create_handler(store, self._backend, params) rdzv_info = handler.next_rendezvous() self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "127.0.0.2") def _ignore_exception(exception_type: Exception, fn: Callable): try: fn() except exception_type as e: pass def _wait_for(condition, timeout=10, interval=1, name=None): def _wait_while(): while True: if condition(): break else: time.sleep(interval) wait_thread = threading.Thread(target=_wait_while, name=name) wait_thread.start() wait_thread.join(timeout=timeout) class _CapturingThread(threading.Thread): def __init__(self, target=None, name=None, args=None, kwargs=None): if args is None: args = () if kwargs is None: kwargs = {} threading.Thread.__init__( self, target=target, args=args, kwargs=kwargs, name=name ) self._result = None def run(self): if self._target is not None: self._result = self._target(*self._args, **self._kwargs) def join(self, *args): threading.Thread.join(self, *args) return self._result class IntegrationTest(TestCase): def setUp(self) -> None: self._store = HashStore() self._handlers = [] self._backend = _InMemoryRendezvousBackend() def tearDown(self) -> None: for handler in self._handlers: handler._stop_heartbeats() def _create_handler(self, **kwargs) -> DynamicRendezvousHandler: params = { "backend": self._backend.name, "endpoint": "dummy_endpoint", "run_id": "dummy_run_id", "min_nodes": 2, "max_nodes": 2, "join_timeout": "5", "local_addr": f"127.0.0.{len(self._handlers)}", } params.update(**kwargs) rzdv_params = RendezvousParameters(**params) handler = create_handler(self._store, self._backend, rzdv_params) self._handlers.append(handler) return handler def test_all_nodes_join_rendezvous(self) -> None: handler1 = self._create_handler(min_nodes=2, max_nodes=2) handler2 = self._create_handler(min_nodes=2, max_nodes=2) handler1_thread = _CapturingThread(target=handler1.next_rendezvous) handler2_thread = _CapturingThread(target=handler2.next_rendezvous) handler1_thread.start() handler2_thread.start() rdzv_info1: RendezvousInfo = handler1_thread.join() rdzv_info2: RendezvousInfo = handler2_thread.join() self.assertEqual(rdzv_info1.store.underlying_store, self._store) self.assertEqual(rdzv_info2.store.underlying_store, self._store) self.assertNotEqual(rdzv_info1.rank, rdzv_info2.rank) self.assertEqual(rdzv_info1.world_size, 2) self.assertEqual(rdzv_info2.world_size, 2) def test_redundancy_list(self) -> None: handler1 = self._create_handler(min_nodes=2, max_nodes=2) handler2 = self._create_handler(min_nodes=2, max_nodes=2) handler3 = self._create_handler(min_nodes=2, max_nodes=2) handler1_thread = _CapturingThread(target=handler1.next_rendezvous) handler2_thread = _CapturingThread(target=handler2.next_rendezvous) handler3_thread = _CapturingThread( target=_ignore_exception, args=(RendezvousTimeoutError, lambda: handler3.next_rendezvous()), ) handler1_thread.start() handler2_thread.start() # establish successful rendezvous handler1_thread.join() handler2_thread.join() # expect to register in redundancy list handler3_thread.start() # wait until the handler3 is registered in the redundancy list _wait_for(lambda: pickle.loads(self._backend.get_state()[0]).redundancy_list) state_and_token = self._backend.get_state() state = pickle.loads(state_and_token[0]) addresses = [node.addr for node in state.redundancy_list] self.assertListEqual(addresses, ["127.0.0.2"]) def test_redundancy_transition_to_wait_list_then_join_rendezvous(self) -> None: handler1 = self._create_handler( min_nodes=1, max_nodes=2, ) handler2 = self._create_handler( min_nodes=1, max_nodes=2, keep_alive_interval=timedelta(seconds=1), ) handler3 = self._create_handler( min_nodes=1, max_nodes=2, ) handler1_thread = _CapturingThread(target=handler1.next_rendezvous) handler2_thread = _CapturingThread(target=handler2.next_rendezvous) handler3_thread = _CapturingThread( target=_ignore_exception, args=(RendezvousTimeoutError, lambda: handler3.next_rendezvous()), ) handler1_thread.start() handler2_thread.start() # establish successful rendezvous handler1_thread.join() handler2_thread.join() handler3_thread.start() _wait_for(lambda: pickle.loads(self._backend.get_state()[0]).redundancy_list) handler2._stop_heartbeats() _wait_for( lambda: len(pickle.loads(self._backend.get_state()[0]).participants) == 1 ) _wait_for( lambda: len(pickle.loads(self._backend.get_state()[0]).wait_list) == 1 ) def test_use_agent_store_is_true_by_default(self): handler = self._create_handler( min_nodes=1, max_nodes=2, ) self.assertTrue(handler.use_agent_store) @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"}) def test_use_agent_store_is_disabled(self): handler = self._create_handler( min_nodes=1, max_nodes=2, ) self.assertFalse(handler.use_agent_store) @patch.object(dist, "PrefixStore") def test_share_tcp_store_from_backend(self, prefix_store_class_mock): prefix_store = Mock(spec=dist.PrefixStore) prefix_store_class_mock.return_value = prefix_store tcp_store = Mock(spec=dist.TCPStore) expected_addr = "expected_address" expected_port = 54321 type(tcp_store).host = PropertyMock(return_value=expected_addr) type(tcp_store).port = PropertyMock(return_value=expected_port) # this will be injected self._store = tcp_store handler1 = self._create_handler(min_nodes=2, max_nodes=2) handler2 = self._create_handler(min_nodes=2, max_nodes=2) handler1_thread = _CapturingThread(target=handler1.next_rendezvous) handler2_thread = _CapturingThread(target=handler2.next_rendezvous) handler1_thread.start() handler2_thread.start() rdzv_info1: RendezvousInfo = handler1_thread.join() rdzv_info2: RendezvousInfo = handler2_thread.join() self.assertEqual(rdzv_info1.store, prefix_store) self.assertEqual(rdzv_info2.store, prefix_store) prefix_store_class_mock.assert_called_with( "torch.rendezvous.dummy_run_id.0", tcp_store ) self.assertEqual( rdzv_info1.bootstrap_store_info, rdzv_info2.bootstrap_store_info ) self.assertEqual(rdzv_info1.bootstrap_store_info.master_addr, expected_addr) self.assertEqual(rdzv_info1.bootstrap_store_info.master_port, expected_port) @patch.dict(os.environ, {"TORCH_DISABLE_SHARE_RDZV_TCP_STORE": "1"}) @patch.object(dist, "PrefixStore") def test_share_tcp_store_is_disabled(self, prefix_store_class_mock): prefix_store = Mock() prefix_store_class_mock.return_value = prefix_store prefix_store.set.return_value = None prefix_store.get.return_value = b"123" tcp_store = Mock(spec=dist.TCPStore) # this will be injected self._store = tcp_store handler1 = self._create_handler(min_nodes=2, max_nodes=2) handler2 = self._create_handler(min_nodes=2, max_nodes=2) handler1_thread = _CapturingThread(target=handler1.next_rendezvous) handler2_thread = _CapturingThread(target=handler2.next_rendezvous) handler1_thread.start() handler2_thread.start() rdzv_info1: RendezvousInfo = handler1_thread.join() rdzv_info2: RendezvousInfo = handler2_thread.join() self.assertEqual(rdzv_info1.store, prefix_store) self.assertEqual(rdzv_info2.store, prefix_store) prefix_store_class_mock.assert_called_with( "torch.rendezvous.dummy_run_id.0", self._store ) self.assertEqual(rdzv_info1.bootstrap_store_info.master_port, 123) self.assertEqual(rdzv_info2.bootstrap_store_info.master_port, 123) class _InMemoryRendezvousBackend(RendezvousBackend): def __init__(self) -> None: self._lock = threading.Lock() self._state = None self._token = None @property def name(self): return "_in_memory_backend" def get_state(self): with self._lock: if self._state is None: return None return (self._state, self._token) return self._state def set_state(self, state, token): if state is None: raise ValueError("State cannot be None.") with self._lock: if token is None and self._token is not None: return None if self._token != token: return None self._state = state self._token = self._token + 1 if self._token is not None else 0