# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This file contains tests that simulate peer failures.

When a peer fails during MultiWorkerMirroredStrategy training, all workers
should get an Unavailable error.
"""

import os

import tensorflow as tf

from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import test

RPC_PROTOCOL = "grpc"

# Put these at top level so they execute in the child processes as well.
mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1


def get_attempt(strategy, attempts):
  """Records one attempt for this task and returns (task_id, attempt count)."""
  task_type = strategy.cluster_resolver.task_type
  task_id = strategy.cluster_resolver.task_id
  attempts[(task_type, task_id)] = attempts.get((task_type, task_id), 0) + 1
  return task_id, attempts[(task_type, task_id)]


# os._exit terminates the process immediately without running Python cleanup,
# which better simulates an abrupt worker failure.
quick_exit = os._exit  # pylint: disable=protected-access


class PeerFailureTest(test.TestCase):
  # Note that all the tests use auto_restart=True. Currently we rely on the
  # assumption that an external system restarts failed tasks. If the assumption
  # is not true, the remaining tasks may still hang instead of fail.
  #
  # In these tests we leverage the auto restart feature of MultiProcessRunner.
  # Failed workers are restarted automatically. In reality there needs to be
  # some job management system that does the restart, e.g. Kubernetes.
  #
  # Worker failures may cause problems if there is more than one collective and
  # the failure happens after the first collective. In this case the recovered
  # worker will be running a different collective than the rest, which causes a
  # deadlock. Note that collectives are common, e.g. when creating variables
  # the initial values are broadcast from the first worker.
  #
  # We use a multiprocessing.Manager().dict() object to track the attempts of
  # each worker. We take different actions in different attempts to simulate
  # the events in the real world. E.g. some tests make a worker fail on the
  # first attempt only, and assert that it recovers.

  def test_creating_variable(self):
    # This test simulates the case when a worker fails before or during the
    # creation of a variable. Creating a variable involves broadcasting the
    # initial value from the first replica to all replicas.
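    # If worker-1 dies before the second variable is created, worker-0 is left
    # blocking in the broadcast of that variable's initial value, and the join
    # below expects it to surface an Unavailable (or DeadlineExceeded) error.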

    def worker_fn():
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      with strategy.scope():
        tf.Variable(1.)
        # worker-1 dies here.
        if strategy.cluster_resolver.task_id == 1:
          quick_exit(1)
        v = tf.Variable(tf.random.uniform(()))
        return v.read_value().numpy()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn, cluster_spec, rpc_layer=RPC_PROTOCOL)
    mpr.start()
    # TODO(b/151232436): Always raise UnavailableError when a peer fails.
    with self.assertRaises(
        (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
      mpr.join(timeout=60)

  def test_reduce_small_tensor(self):
    # This test simulates the case when a worker fails before or during the
    # reduction of a small tensor, e.g. reading a metric.
    #
    # Note that this is written for a specific corner case that used to happen
    # only when all of the following conditions are met:
    #   - There are two workers.
    #   - They're reducing a small tensor. The definition of small varies
    #     per platform.
    #   - They're reducing a single tensor. Batched all-reduces are not
    #     affected.
    #   - It must be worker-1 that fails.
    # In this case, the all-reduce is effectively two send/recv operations,
    # the first one from worker-0 to worker-1, and the second one vice versa.
    # The first one blocks the second one. In send/recv, the sending party is
    # not aware of failures of the receiving party.

    def worker_fn():
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      value = tf.identity([1.])
      strategy.reduce("sum", value, axis=None)
      # worker-1 dies here.
      if strategy.cluster_resolver.task_id == 1:
        quick_exit(1)
      strategy.reduce("sum", value, axis=None)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn, cluster_spec, rpc_layer=RPC_PROTOCOL)
    mpr.start()
    # TODO(b/151232436): Always raise UnavailableError when a peer fails.
    with self.assertRaises(
        (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
      mpr.join(timeout=60)


class PeerFailureRecoverTest(test.TestCase):
  # Similar to PeerFailureTest but simulates the situation where there's some
  # external system that automatically restarts failed workers.

  def test_creating_variable(self):
    # See PeerFailureTest.test_creating_variable

    def worker_fn(attempts):
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      task_id, attempt = get_attempt(strategy, attempts)
      with strategy.scope():
        tf.Variable(1.)
        # worker-1 dies here.
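        # Only the first attempt of worker-1 exits; once MultiProcessRunner
        # auto-restarts it, the second attempt runs to completion.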
        if attempt == 1 and task_id == 1:
          quick_exit(1)
        v = tf.Variable(tf.random.uniform(()))
        return v.read_value().numpy()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    attempts = multi_process_runner.manager().dict()
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn,
        cluster_spec,
        rpc_layer=RPC_PROTOCOL,
        args=(attempts,),
        auto_restart=True)
    mpr.start()
    results = mpr.join(timeout=90).return_value
    self.assertEqual(results[0], results[1])

  def test_reduce_small_tensor(self):
    # See PeerFailureTest.test_reduce_small_tensor

    def worker_fn(attempts):
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      task_id, attempt = get_attempt(strategy, attempts)
      value = tf.identity([1.])
      strategy.reduce("sum", value, axis=None)
      # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
      return strategy.reduce("sum", value, axis=None).numpy()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    attempts = multi_process_runner.manager().dict()
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn,
        cluster_spec,
        rpc_layer=RPC_PROTOCOL,
        args=(attempts,),
        auto_restart=True)
    mpr.start()
    results = mpr.join(timeout=90).return_value
    self.assertAllEqual(results, [[2.], [2.]])

  def test_quick_recover(self):
    # This test simulates the case when a worker fails but recovers quickly,
    # before the next collective.
    #
    # It's not guaranteed that the cluster only restarts once when one worker
    # fails. The external job management system is expected to keep restarting
    # failed workers.

    def worker_fn(attempts):
      # Set a long check health interval to better simulate the case when a
      # worker fails and recovers during a check health interval.
      mwms_lib.CollectiveAllReduceExtended._check_health_interval = 30
      mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 30

      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      task_id, attempt = get_attempt(strategy, attempts)

      @tf.function
      def replica_fn():
        ctx = tf.distribute.get_replica_context()
        # Use large tensors because a small tensor may hang regardless of
        # whether the worker recovers.
        value = tf.ones((64, 64))
        ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value, value])

      strategy.run(replica_fn)
      # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
      strategy.run(replica_fn)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    attempts = multi_process_runner.manager().dict()
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn,
        cluster_spec,
        rpc_layer=RPC_PROTOCOL,
        args=(attempts,),
        auto_restart=True)
    mpr.start()
    mpr.join(timeout=90)


if __name__ == "__main__":
  test_util.main()