# mypy: allow-untyped-defs

from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import RRef
from torch.testing._internal.dist_utils import (
    dist_init,
    worker_name,
    wait_until_pending_futures_and_users_flushed,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    return first_arg + second_arg + first_kwarg + second_kwarg


@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout_future_ret(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    return fut


@torch.jit.script
def rpc_async_call_future_ret(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return fut


@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    return rref_var.to_here()


@torch.jit.script
def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
    return rref_var.to_here(timeout)


@torch.jit.script
def rpc_async_with_rref_arg(dst_worker_name: str, args: Tuple[RRef[Tensor]]) -> Tensor:
    fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
    ret = fut.wait()
    return ret


class JitFaultyAgentRpcTest(RpcAgentTestFixture):
    """
    Run tests for rpc_async in JIT under the faulty agent test fixture to test
    arbitrary timeouts.
    """
    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_torchscript_function(self):
        # Call rpc_async + fut.wait() in torchscript function and ensure that
        # timeout is raised.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()
        # Ensure that we get a timeout if we override the default timeout and
        # the RPC takes longer to execute.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)

        # Ensure that we timeout if we don't specify a timeout but the default
        # is less than the RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            script_rpc_async_call(dst_worker_name, args, kwargs)

        # Ensure that we run to completion if zero timeout is specified.
        ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
        self.assertEqual(ret, torch.tensor([8, 8]))
        # reset for clean shutdown
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_python(self):
        # Ensures timeouts are raised if we call rpc_async from within a
        # torchscript function, but wait on the future in python.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()

        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure timeout if we don't specify but the default is less than the
        # RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure run to completion if zero timeout is specified
        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
        result = fut.wait()
        self.assertEqual(result, torch.tensor([8, 8]))
        # reset for clean shutdown
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_remote_timeout_to_here_in_jit(self):
        # Test that calling to_here() in JIT will raise timeout error if
        # rpc.remote failed.
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call to_here() within a ScriptFunction and ensure it raises
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref_to_here(rref)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
    def test_rref_to_here_timeout_in_jit(self):
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref_to_here_with_timeout(rref, 0.01)

        rref_to_here_with_timeout(rref, 100)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_in_jit(self):
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call RPC with RRef arg in JIT, which will go through JIT pickling and
        # ensure error is raised.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc_async_with_rref_arg(dst_worker, (rref,))

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_script_func(self):
        # Similar to above test, but calls python rpc with script function.
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call RPC with script function that takes RRef, ensure timeout during pickling
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))