• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class MirroredStrategy implementing tf.distribute.Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import functools
23import threading
24import weakref
25
26from tensorflow.python import pywrap_tfe
27from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
28from tensorflow.python.autograph.impl import api as autograph
29from tensorflow.python.distribute import distribute_lib
30from tensorflow.python.distribute import distribute_utils
31from tensorflow.python.distribute import shared_variable_creator
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import device as tf_device
35from tensorflow.python.framework import ops
36from tensorflow.python.ops import summary_ops_v2
37from tensorflow.python.ops import variable_scope
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.training import coordinator
40
41
42def _is_gpu_device(device):
43  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
44
45
46def call_for_each_replica(strategy, fn, args=None, kwargs=None):
47  """Call `fn` on each worker devices(replica).
48
49  It's highly recommended to wrap the call to this function inside a
50  `tf.function`, otherwise the performance is poor.
51
52  Args:
53    strategy: `tf.distribute.Strategy`.
54    fn: function to call on each worker devices.
55    args: positional arguments to `fn`.
56    kwargs: keyword arguments to `fn`.
57
58  Returns:
59    Wrapped returned value of `fn` from all replicas.
60  """
61  if args is None:
62    args = ()
63  if kwargs is None:
64    kwargs = {}
65
66  if isinstance(fn, def_function.Function):
67    # Don't lift up the tf.function decoration if `fn` is compiled with XLA
68    # and all devices are GPU. In this case we will use collectives to do
69    # cross-device communication, thus no merge_call is in the path.
70    if fn._jit_compile and all(  # pylint: disable=protected-access
71        [_is_gpu_device(d) for d in strategy.extended.worker_devices]):
72      return _call_for_each_replica(strategy, fn, args, kwargs)
73
74    if strategy not in _cfer_fn_cache:
75      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
76    wrapped = _cfer_fn_cache[strategy].get(fn)
77    if wrapped is None:
78      # We need to wrap fn such that it triggers _call_for_each_replica inside
79      # the tf.function. We use _clone() instead of @tf.function wrapped
80      # call_for_each_replica() because we would like to retain the arguments to
81      # the @tf.function decorator of fn.
82      wrapped = fn._clone(  # pylint: disable=protected-access
83          python_function=functools.partial(call_for_each_replica, strategy,
84                                            fn.python_function))
85      _cfer_fn_cache[strategy][fn] = wrapped
86    return wrapped(args, kwargs)
87
88  if context.executing_eagerly():
89    logging.log_first_n(
90        logging.WARN, "Using %s eagerly has significant "
91        "overhead currently. We will be working on improving "
92        "this in the future, but for now please wrap "
93        "`call_for_each_replica` or `experimental_run` or "
94        "`run` inside a tf.function to get "
95        "the best performance." % strategy.__class__.__name__, 5)
96  else:
97    # When a tf.function is wrapped to trigger _call_for_each_replica (see
98    # the other branch above), AutoGraph stops conversion at
99    # _call_for_each_replica itself (TF library functions are allowlisted).
100    # This makes sure that the Python function that originally passed to
101    # the tf.function is still converted.
102    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
103
104  return _call_for_each_replica(strategy, fn, args, kwargs)
105
106
107# Per strategy cache for call_for_each_replica def_function.Function objects.
108_cfer_fn_cache = weakref.WeakKeyDictionary()
109
110
111@contextlib.contextmanager
112def _enter_graph(g, eager, creator_stack=None):
113  """Context manager for selecting a graph and maybe eager mode."""
114  if eager:
115    with g.as_default(), context.eager_mode():
116      if creator_stack is not None:
117        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
118      yield
119  else:
120    with g.as_default():
121      if creator_stack is not None:
122        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
123      yield
124
125
126def _cpu_device(device):
127  cpu_device = tf_device.DeviceSpec.from_string(device)
128  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
129  return cpu_device.to_string()
130
131
132class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
133  pass
134
135
136def _call_for_each_replica(distribution, fn, args, kwargs):
137  """Run `fn` in separate threads, once per replica/worker device.
138
139  Args:
140    distribution: the DistributionStrategy object.
141    fn: function to run (will be run once per replica, each in its own thread).
142    args: positional arguments for `fn`
143    kwargs: keyword arguments for `fn`.
144
145  Returns:
146    Merged return value of `fn` across all replicas.
147
148  Raises:
149    RuntimeError: If fn() calls get_replica_context().merge_call() a different
150        number of times from the available devices.
151  """
152  # TODO(josh11b): Add this option once we add synchronization to variable
153  # creation. Until then, this is pretty unsafe to use.
154  run_concurrently = False
155  if not context.executing_eagerly():
156    # Needed for per-thread device, etc. contexts in graph mode.
157    ops.get_default_graph().switch_to_thread_local()
158
159  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
160
161  shared_variable_store = {}
162  devices = distribution.extended.worker_devices
163
164  # TODO(isaprykin): Create these threads once instead of during every call.
165  threads = []
166  for index in range(len(devices)):
167    variable_creator_fn = shared_variable_creator.make_fn(
168        shared_variable_store, index)
169    t = _MirroredReplicaThread(distribution, coord, index, devices,
170                               variable_creator_fn, fn,
171                               distribute_utils.caching_scope_local,
172                               distribute_utils.select_replica(index, args),
173                               distribute_utils.select_replica(index, kwargs))
174    threads.append(t)
175
176  for t in threads:
177    t.start()
178
179  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
180  # (`MRT`) threads. The execution waits until
181  # `MRT.has_paused` is set, which indicates that either `fn` is
182  # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
183  # complete, then `MRT.done` is set to True.  Otherwise, arguments
184  # of `get_replica_context().merge_call` from all paused threads are grouped
185  # and the `merge_fn` is performed.  Results of the
186  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
187  # Each such `get_replica_context().merge_call` call returns the
188  # `MRT.merge_result` for that thread when `MRT.should_run` event
189  # is reset again. Execution of `fn` resumes.
190
191  try:
192    with coord.stop_on_exception():
193      all_done = False
194      while not all_done and not coord.should_stop():
195        done = []
196        if run_concurrently:
197          for t in threads:
198            t.should_run.set()
199          for t in threads:
200            t.has_paused.wait()
201            t.has_paused.clear()
202            if coord.should_stop():
203              return None
204            done.append(t.done)
205        else:
206          for t in threads:
207            t.should_run.set()
208            t.has_paused.wait()
209            t.has_paused.clear()
210            if coord.should_stop():
211              return None
212            done.append(t.done)
213        if coord.should_stop():
214          return None
215        all_done = all(done)
216        if not all_done:
217          if any(done):
218            raise RuntimeError("Some replicas made a different number of "
219                               "replica_context().merge_call() calls.")
220          # get_replica_context().merge_call() case
221          merge_args = distribute_utils.regroup(
222              tuple(t.merge_args for t in threads))
223          merge_kwargs = distribute_utils.regroup(
224              tuple(t.merge_kwargs for t in threads))
225          # We capture the name_scope of the MRT when we call merge_fn
226          # to ensure that if we have opened a name scope in the MRT,
227          # it will be respected when executing the merge function. We only
228          # capture the name_scope from the first MRT and assume it is
229          # the same for all other MRTs.
230          mtt_captured_name_scope = threads[0].captured_name_scope
231          mtt_captured_var_scope = threads[0].captured_var_scope
232          # Capture and merge the control dependencies from all the threads.
233          mtt_captured_control_deps = set()
234          for t in threads:
235            mtt_captured_control_deps.update(t.captured_control_deps)
236          with ops.name_scope(mtt_captured_name_scope),\
237              ops.control_dependencies(mtt_captured_control_deps), \
238              variable_scope.variable_scope(mtt_captured_var_scope):
239            merge_result = threads[0].merge_fn(distribution, *merge_args,
240                                               **merge_kwargs)
241          for r, t in enumerate(threads):
242            t.merge_result = distribute_utils.select_replica(r, merge_result)
243  finally:
244    for t in threads:
245      t.should_run.set()
246    coord.join(threads)
247
248  return distribute_utils.regroup(tuple(t.main_result for t in threads))
249
250
251class _MirroredReplicaThread(threading.Thread):
252  """A thread that runs() a function on a device."""
253
254  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, fn,
255               caching_scope, args, kwargs):
256    super(_MirroredReplicaThread, self).__init__()
257    self.coord = coord
258    self.distribution = dist
259    self.devices = devices
260    self.replica_id = replica_id
261    self.replica_id_in_sync_group = (
262        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access
263
264    self.variable_creator_fn = variable_creator_fn
265    # State needed to run and return the results of `fn`.
266    self.main_fn = fn
267    self.main_args = args
268    self.main_kwargs = kwargs
269    self.main_result = None
270    self.done = False
271    # State needed to run the next merge_call() (if any) requested via
272    # ReplicaContext.
273    self.merge_fn = None
274    self.merge_args = None
275    self.merge_kwargs = None
276    self.merge_result = None
277    self.captured_name_scope = None
278    self.captured_var_scope = None
279    try:
280      self.caching_scope_entered = caching_scope.new_cache_scope_count
281      self.caching_scope_exited = caching_scope.cache_scope_exited_count
282    except AttributeError:
283      self.caching_scope_entered = None
284      self.caching_scope_exited = None
285
286    # We use a thread.Event for the main thread to signal when this
287    # thread should start running (`should_run`), and another for
288    # this thread to transfer control back to the main thread
289    # (`has_paused`, either when it gets to a
290    # `get_replica_context().merge_call` or when `fn` returns). In
291    # either case the event starts cleared, is signaled by calling
292    # set(). The receiving thread waits for the signal by calling
293    # wait() and then immediately clearing the event using clear().
294    self.should_run = threading.Event()
295    self.has_paused = threading.Event()
296    # These fields have to do with inheriting various contexts from the
297    # parent thread:
298    context.ensure_initialized()
299    ctx = context.context()
300    self.in_eager = ctx.executing_eagerly()
301    self.record_thread_local_summary_state()
302    self.record_thread_local_eager_context_state()
303    self.context_device_policy = (
304        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
305            ctx._context_handle))  # pylint: disable=protected-access
306    self.graph = ops.get_default_graph()
307    with ops.init_scope():
308      self._init_in_eager = context.executing_eagerly()
309      self._init_graph = ops.get_default_graph()
310    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
311    self._var_scope = variable_scope.get_variable_scope()
312    # Adding a "/" at end lets us re-enter this scope later.
313    self._name_scope = self.graph.get_name_scope()
314    if self._name_scope:
315      self._name_scope += "/"
316    if self.replica_id > 0:
317      if not self._name_scope:
318        self._name_scope = ""
319      self._name_scope += "replica_%d/" % self.replica_id
320
321  def run(self):
322    self.should_run.wait()
323    self.should_run.clear()
324    try:
325      if self.coord.should_stop():
326        return
327      self.restore_thread_local_summary_state()
328      self.restore_thread_local_eager_context_state()
329      if (self.caching_scope_entered is not None and
330          self.caching_scope_exited is not None):
331        distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
332        distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
333      # TODO(josh11b): Use current logical device instead of 0 here.
334      with self.coord.stop_on_exception(), \
335          _enter_graph(self._init_graph, self._init_in_eager), \
336          _enter_graph(self.graph, self.in_eager,
337                       self._variable_creator_stack), \
338          context.device_policy(self.context_device_policy), \
339          _MirroredReplicaContext(self.distribution,
340                                  self.replica_id_in_sync_group), \
341          ops.device(self.devices[self.replica_id]), \
342          ops.name_scope(self._name_scope), \
343          variable_scope.variable_scope(
344              self._var_scope, reuse=self.replica_id > 0), \
345          variable_scope.variable_creator_scope(self.variable_creator_fn):
346        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
347        self.done = True
348    finally:
349      self.has_paused.set()
350
351  def record_thread_local_summary_state(self):
352    """Record the thread local summary state in self."""
353    # TODO(slebedev): is this still relevant? the referenced bug is closed.
354    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
355    self._summary_step = summary_state.step
356    self._summary_writer = summary_state.writer
357    self._summary_recording = summary_state.is_recording
358    self._summary_recording_distribution_strategy = (
359        summary_state.is_recording_distribution_strategy)
360
361  def restore_thread_local_summary_state(self):
362    """Restore thread local summary state from self."""
363    # TODO(slebedev): is this still relevant? the referenced bug is closed.
364    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
365    summary_state.step = self._summary_step
366    summary_state.writer = self._summary_writer
367    summary_state.is_recording = self._summary_recording
368    summary_state.is_recording_distribution_strategy = (
369        self._summary_recording_distribution_strategy)
370
371  def record_thread_local_eager_context_state(self):
372    ctx = context.context()
373    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
374    self._eager_context_op_callbacks = eager_context_state.op_callbacks
375    # TODO(b/125892694): record other fields in EagerContext.
376
377  def restore_thread_local_eager_context_state(self):
378    ctx = context.context()
379    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
380    eager_context_state.op_callbacks = self._eager_context_op_callbacks
381    # TODO(b/125892694): record other fields in EagerContext.
382
383
384class _MirroredReplicaContext(distribute_lib.ReplicaContext):
385  """ReplicaContext for synchronized replica."""
386
387  def _merge_call(self, fn, args, kwargs):
388    """`merge_call()` implementation for synchronized replica.
389
390    This pauses the current replica thread and passes `fn` and its arguments to
391    the main thread. The main thread will wait until all replicas pause, then
392    invoke `fn` with grouped arguments. The current replica thread will continue
393    after `fn` completes.
394
395    See `_call_for_each_replica` for the logic in the main thread.
396
397    Args:
398      fn: a function that is called in cross replica context with grouped
399        arguments from each replica. `fn` should returns grouped values.
400      args: positional arguments to `fn`.
401      kwargs: keyward arguments to `fn`.
402
403    Returns:
404      Return value of `fn` for the current replica.
405
406    Raises:
407      RuntimeError: when merge_call happens in a different graph, e.g. in a
408        different tf.function, which is not supported now.
409      _RequestedStop: when stop is requested.
410
411    """
412    t = threading.current_thread()
413    assert isinstance(t, _MirroredReplicaThread)
414    t.merge_fn = fn
415    t.merge_args = args
416    t.merge_kwargs = kwargs
417    t.captured_name_scope = t.graph.get_name_scope()
418    # Adding a "/" at end lets us re-enter this scope later.
419    if t.captured_name_scope:
420      t.captured_name_scope += "/"
421
422    t.captured_var_scope = variable_scope.get_variable_scope()
423    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access
424
425    # It is problematic if `merge_call` is called under a different graph other
426    # than the one that `_call_for_each_replica` is called under, there are
427    # 3 cases this can happen:
428    #
429    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
430    #   `tf.function` and there is a `merge_call` in `fn`. Since
431    #   MirroredStrategy traces a separate function per thread (per device),
432    #   and each trace takes a shared lock, the lock is never released by the
433    #   first thread and subsequent replica threads cannot proceed to trace
434    #   their own functions. This issue is addressed by always converting
435    #   `_call_for_each_replica(tf.function(f))` to
436    #   ``tf.function(_call_for_each_replica(f))`.` in
437    #   `MirroredStrategy._call_for_each_replica`.
438    #
439    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
440    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
441    #   In this case each thread can successfully trace its own function, but
442    #   since the `merge_fn` passed to `merge_call` is executed in the main
443    #   thread (where `_call_for_each_replica` is executed), it can't access
444    #   the tensors that come from different graphs.
445    #
446    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
447    #   statement, and there is a `merge_call` inside the control-flow body,
448    #   `fn` or `_call_for_each_replica` is decorated with `tf.function`.
449    #   Control flow statement creates a separate graph for its body, similar
450    #   to #2, `merge_fn` executed in the main thread can't access the
451    #   tensors that come from different graphs.
452    #
453    #   We raise an error for #2 and #3.
454    if ops.get_default_graph() != t.graph:
455      raise RuntimeError(
456          "`merge_call` called while defining a new graph or a tf.function."
457          " This can often happen if the function `fn` passed to"
458          " `strategy.run()` contains a nested `@tf.function`, and the nested "
459          "`@tf.function` contains a synchronization point, such as aggregating"
460          " gradients (e.g, optimizer.apply_gradients), or if the function `fn`"
461          " uses a control flow statement which contains a synchronization"
462          " point in the body. Such behaviors are not yet supported. Instead,"
463          " please avoid nested `tf.function`s or control flow statements that"
464          " may potentially cross a synchronization boundary, for example,"
465          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
466          " inside a `tf.function` or move the control flow out of `fn`. If"
467          " you are subclassing a `tf.keras.Model`, please avoid decorating"
468          " overridden methods `test_step` and `train_step` in `tf.function`.")
469
470    t.has_paused.set()
471    t.should_run.wait()
472    t.should_run.clear()
473    if t.coord.should_stop():
474      raise _RequestedStop()
475    return t.merge_result
476
477  @property
478  def devices(self):
479    distribute_lib.require_replica_context(self)
480    return [
481        self._strategy.extended.worker_devices_by_replica[
482            self._replica_id_in_sync_group]
483    ]
484