# mypy: allow-untyped-defs

from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import RRef
from torch.testing._internal.dist_utils import (
    dist_init,
    worker_name,
    wait_until_pending_futures_and_users_flushed,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    return first_arg + second_arg + first_kwarg + second_kwarg

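
# Note: with the args/kwargs used throughout JitFaultyAgentRpcTest below
# (args = ([1, 1], [2, 2]); kwargs: first_kwarg=[2, 2], second_kwarg=[3, 3]),
# the elementwise sum computed by two_args_two_kwargs is tensor([8, 8]), which
# is the value the zero-timeout assertions check against.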


@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout_future_ret(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    return fut


@torch.jit.script
def rpc_async_call_future_ret(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return fut


@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    return rref_var.to_here()


@torch.jit.script
def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
    return rref_var.to_here(timeout)


@torch.jit.script
def rpc_async_with_rref_arg(dst_worker_name: str, args: Tuple[RRef[Tensor]]) -> Tensor:
    fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
    ret = fut.wait()
    return ret


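# A minimal usage sketch (not exercised by the tests below; the worker name is
# illustrative): it assumes an RPC group has already been set up with
# rpc.init_rpc() and that dst_worker_name refers to another live worker, and
# shows how the TorchScript helpers above would be called.
def _example_usage_sketch(dst_worker_name: str = "worker1") -> Tensor:
    args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
    kwargs = {
        "first_kwarg": torch.tensor([2, 2]),
        "second_kwarg": torch.tensor([3, 3]),
    }
    # With these inputs the result is the elementwise sum tensor([8, 8]).
    return script_rpc_async_call(dst_worker_name, args, kwargs)

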
class JitFaultyAgentRpcTest(RpcAgentTestFixture):
    """
    Run tests for rpc_async in JIT under the faulty agent test fixture to
    exercise arbitrary timeouts and failed-message error handling.
    """
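    # The fault injection is configured per test through @dist_init: entries in
    # ``messages_to_delay`` hold up messages of that type for the given number
    # of seconds (to trigger timeouts), while message types listed in
    # ``faulty_messages`` fail outright (to trigger errors such as failed RRef
    # creation).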
    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_torchscript_function(self):
        # Call rpc_async + fut.wait() in a TorchScript function and ensure that
        # a timeout is raised.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()
        # Ensure that we get a timeout if we override the default timeout and
        # the RPC takes longer than that to execute.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)

        # Ensure that we time out if we don't specify a timeout but the default
        # timeout is shorter than the time the RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            script_rpc_async_call(dst_worker_name, args, kwargs)

        # Ensure that we run to completion if a zero (i.e. infinite) timeout is
        # specified.
        ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
        self.assertEqual(ret, torch.tensor([8, 8]))
        # Reset the timeout for a clean shutdown.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_python(self):
        # Ensure that timeouts are raised if we call rpc_async from within a
        # TorchScript function but wait on the future in Python.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()

        # Ensure that we get a timeout if we override the default timeout and
        # the RPC takes longer than that to execute.
        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure that we time out if we don't specify a timeout but the default
        # timeout is shorter than the time the RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure that we run to completion if a zero (i.e. infinite) timeout is
        # specified.
        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
        result = fut.wait()
        self.assertEqual(result, torch.tensor([8, 8]))
        # Reset the timeout for a clean shutdown.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_remote_timeout_to_here_in_jit(self):
        # Test that calling to_here() in JIT raises an error if the underlying
        # rpc.remote call failed.
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Ensure that the error handling callbacks have run.
        wait_until_pending_futures_and_users_flushed()
        # Call to_here() within a ScriptFunction and ensure that it raises.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref_to_here(rref)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
    def test_rref_to_here_timeout_in_jit(self):
        # Test that to_here() in JIT respects an explicit timeout when the RRef
        # fetch is delayed.
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        expected_error = self.get_timeout_error_regex()
        # A timeout shorter than the injected delay should raise.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref_to_here_with_timeout(rref, 0.01)

        # A sufficiently large timeout should succeed.
        rref_to_here_with_timeout(rref, 100)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_in_jit(self):
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Ensure that the error handling callbacks have run.
        wait_until_pending_futures_and_users_flushed()
        # Call an RPC with an RRef arg in JIT, which goes through JIT pickling,
        # and ensure that the error is raised.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc_async_with_rref_arg(dst_worker, (rref,))

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_script_func(self):
        # Similar to the test above, but calls Python RPC with a script function.
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Ensure that the error handling callbacks have run.
        wait_until_pending_futures_and_users_flushed()
        # Call an RPC with a script function that takes an RRef and ensure that
        # the RRef creation error is raised during pickling.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))