# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""This file contains tests that simulate peer failures.
16
17When a peer fails during MultiWorkerMirroredStrategy training. All workers
18should get Unavailable error.
19"""

import os

import tensorflow as tf

from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import test

RPC_PROTOCOL = "grpc"

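# The class-attribute overrides below tune MultiWorkerMirroredStrategy's
# periodic peer health check (roughly: a background thread pings the other
# workers over RPC and reports an unreachable peer as an error on the
# surviving workers). Shortening the interval and timeouts keeps failure
# detection well within the test timeouts used below.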
# Put these at the top level so they execute in the child processes as well.
mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1


def get_attempt(strategy, attempts):
  task_type = strategy.cluster_resolver.task_type
  task_id = strategy.cluster_resolver.task_id
  attempts[(task_type, task_id)] = attempts.get((task_type, task_id), 0) + 1
  return task_id, attempts[(task_type, task_id)]


quick_exit = os._exit  # pylint: disable=protected-access


class PeerFailureTest(test.TestCase):
  # Note that the recovery tests (PeerFailureRecoverTest below) use
  # auto_restart=True. Currently we rely on the assumption that an external
  # system restarts failed tasks; if that assumption does not hold, the
  # remaining tasks may hang instead of failing.
  #
  # Those tests leverage the auto restart feature of MultiProcessRunner, which
  # restarts failed workers automatically. In reality some job management
  # system, e.g. Kubernetes, needs to do the restart.
  #
  # Worker failures may cause problems if there is more than one collective and
  # the failure happens after the first collective. In this case the recovered
  # worker runs a different collective than the rest of the cluster, which
  # causes a deadlock. Note that collectives are common; e.g. when creating
  # variables the initial values are broadcast from the first worker.
  #
  # We use a multiprocessing.Manager().dict() object to track the number of
  # attempts of each worker. We take different actions on different attempts to
  # simulate real-world events; e.g. some tests make a worker fail on the first
  # attempt only and assert that it recovers.
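  #
  # As a concrete example of the deadlock: both workers finish collective #1,
  # then worker-1 dies and is restarted. The restarted worker-1 starts again
  # from collective #1 while worker-0 is already blocked on collective #2, so
  # the two workers wait on different collectives forever.
  #
  # As a concrete example of the attempt tracking: after worker-1 has failed
  # and been restarted once in a two-worker cluster, the shared dict holds
  # {("worker", 0): 1, ("worker", 1): 2}; get_attempt() returns the attempt
  # number so a test can fail a worker on its first attempt only.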

  def test_creating_variable(self):
    # This test simulates the case when a worker fails before or while creating
    # a variable. Creating a variable involves broadcasting the initial value
    # from the first replica to all replicas.

    def worker_fn():
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      with strategy.scope():
        tf.Variable(1.)
        # worker-1 dies here.
        if strategy.cluster_resolver.task_id == 1:
          quick_exit(1)
        v = tf.Variable(tf.random.uniform(()))
        return v.read_value().numpy()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn, cluster_spec, rpc_layer=RPC_PROTOCOL)
    mpr.start()
    # TODO(b/151232436): Always raise UnavailableError when a peer fails.
    with self.assertRaises(
        (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
      mpr.join(timeout=60)

  def test_reduce_small_tensor(self):
    # This test simulates the case when a worker fails before or while reducing
    # a small tensor, e.g. reading a metric.
    #
    # Note that this is written for a specific corner case that used to happen
    # only when all of the following conditions are met:
    #   - There are two workers.
    #   - They're reducing a small tensor. The definition of small varies
    #     per platform.
    #   - They're reducing a single tensor. Batched all-reduces are not
    #     affected.
    #   - It must be worker-1 that fails.
    # In this case, the all-reduce is effectively two send/recv operations: the
    # first from worker-0 to worker-1, and the second vice versa. The first one
    # blocks the second one, and in send/recv the sending party is not aware of
    # failures of the receiving party.

    def worker_fn():
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      value = tf.identity([1.])
      strategy.reduce("sum", value, axis=None)
      # worker-1 dies here.
      if strategy.cluster_resolver.task_id == 1:
        quick_exit(1)
      strategy.reduce("sum", value, axis=None)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn, cluster_spec, rpc_layer=RPC_PROTOCOL)
    mpr.start()
    # TODO(b/151232436): Always raise UnavailableError when a peer fails.
    with self.assertRaises(
        (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
      mpr.join(timeout=60)


class PeerFailureRecoverTest(test.TestCase):
  # Similar to PeerFailureTest but simulates the situation where there's some
  # external system that automatically restarts failed workers.

  def test_creating_variable(self):
    # See PeerFailureTest.test_creating_variable

    def worker_fn(attempts):
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      task_id, attempt = get_attempt(strategy, attempts)
      with strategy.scope():
        tf.Variable(1.)
        # worker-1 dies here.
        if attempt == 1 and task_id == 1:
          quick_exit(1)
        v = tf.Variable(tf.random.uniform(()))
        return v.read_value().numpy()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    attempts = multi_process_runner.manager().dict()
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn,
        cluster_spec,
        rpc_layer=RPC_PROTOCOL,
        args=(attempts,),
        auto_restart=True)
    mpr.start()
    results = mpr.join(timeout=90).return_value
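    # The variable's random initial value is broadcast from worker-0, so both
    # workers should report the same value once the restarted worker-1 has
    # rejoined the broadcast.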
    self.assertEqual(results[0], results[1])

  def test_reduce_small_tensor(self):
    # See PeerFailureTest.test_reduce_small_tensor

    def worker_fn(attempts):
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      task_id, attempt = get_attempt(strategy, attempts)
      value = tf.identity([1.])
      strategy.reduce("sum", value, axis=None)
      # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
      return strategy.reduce("sum", value, axis=None).numpy()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    attempts = multi_process_runner.manager().dict()
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn,
        cluster_spec,
        rpc_layer=RPC_PROTOCOL,
        args=(attempts,),
        auto_restart=True)
    mpr.start()
    results = mpr.join(timeout=90).return_value
    self.assertAllEqual(results, [[2.], [2.]])

  def test_quick_recover(self):
    # This test simulates the case when a worker fails but recovers quickly
    # before the next collective.
    #
    # It's not guaranteed that the cluster only restarts once when one worker
    # fails. The external job management system is expected to keep restarting
    # failed workers.

    def worker_fn(attempts):
      # Set a long health check interval to better simulate the case when a
      # worker fails and recovers within one health check interval.
      mwms_lib.CollectiveAllReduceExtended._check_health_interval = 30
      mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 30

      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      task_id, attempt = get_attempt(strategy, attempts)

      @tf.function
      def replica_fn():
        ctx = tf.distribute.get_replica_context()
        # Use a large tensor because a small tensor may hang regardless of
        # whether the worker recovers.
        value = tf.ones((64, 64))
        ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value, value])

      strategy.run(replica_fn)
      # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
      strategy.run(replica_fn)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    attempts = multi_process_runner.manager().dict()
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn,
        cluster_spec,
        rpc_layer=RPC_PROTOCOL,
        args=(attempts,),
        auto_restart=True)
    mpr.start()
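    # If the restarted worker-1 rejoins before the next collective, the second
    # strategy.run() completes on both workers and join() should return without
    # raising.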
    mpr.join(timeout=90)


if __name__ == "__main__":
  test_util.main()