# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools
import threading
import weakref

from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Call `fn` on each worker device (replica).

  It's highly recommended to wrap the call to this function inside a
  `tf.function`, otherwise the performance is poor.

  Args:
    strategy: `tf.distribute.Strategy`.
    fn: function to call on each worker device.
    args: positional arguments to `fn`.
    kwargs: keyword arguments to `fn`.

  Returns:
    Wrapped returned value of `fn` from all replicas.
  """
  if args is None:
    args = ()
  if kwargs is None:
    kwargs = {}

  if isinstance(fn, def_function.Function):
    # Don't lift up the tf.function decoration if `fn` is compiled with XLA
    # and all devices are GPU. In this case we will use collectives to do
    # cross-device communication, thus no merge_call is in the path.
    if fn._jit_compile and all(  # pylint: disable=protected-access
        [_is_gpu_device(d) for d in strategy.extended.worker_devices]):
      return _call_for_each_replica(strategy, fn, args, kwargs)

    if strategy not in _cfer_fn_cache:
      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
    wrapped = _cfer_fn_cache[strategy].get(fn)
    if wrapped is None:
      # We need to wrap fn such that it triggers _call_for_each_replica inside
      # the tf.function. We use _clone() instead of @tf.function wrapped
      # call_for_each_replica() because we would like to retain the arguments
      # to the @tf.function decorator of fn.
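      # A rough sketch of the effect: for a user function decorated as
      # `fn = tf.function(f, **options)`, `wrapped` below is a new tf.function
      # carrying the same `options`, whose Python body calls
      # `call_for_each_replica(strategy, f, args, kwargs)`. The per-replica
      # fan-out therefore happens inside the traced function rather than
      # around it.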
      wrapped = fn._clone(  # pylint: disable=protected-access
          python_function=functools.partial(call_for_each_replica, strategy,
                                            fn.python_function))
      _cfer_fn_cache[strategy][fn] = wrapped
    return wrapped(args, kwargs)

  if context.executing_eagerly():
    logging.log_first_n(
        logging.WARN, "Using %s eagerly has significant "
        "overhead currently. We will be working on improving "
        "this in the future, but for now please wrap "
        "`call_for_each_replica` or `experimental_run` or "
        "`run` inside a tf.function to get "
        "the best performance." % strategy.__class__.__name__, 5)
  else:
    # When a tf.function is wrapped to trigger _call_for_each_replica (see
    # the other branch above), AutoGraph stops conversion at
    # _call_for_each_replica itself (TF library functions are allowlisted).
    # This makes sure that the Python function that was originally passed to
    # the tf.function is still converted.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())

  return _call_for_each_replica(strategy, fn, args, kwargs)


# Per strategy cache for call_for_each_replica def_function.Function objects.
_cfer_fn_cache = weakref.WeakKeyDictionary()


@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
  """Context manager for selecting a graph and maybe eager mode."""
  if eager:
    with g.as_default(), context.eager_mode():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield
  else:
    with g.as_default():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield


def _cpu_device(device):
  cpu_device = tf_device.DeviceSpec.from_string(device)
  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
  return cpu_device.to_string()


class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  pass


def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own
      thread).
    args: positional arguments for `fn`.
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: if `fn()` calls `get_replica_context().merge_call()` a
      different number of times on different replicas.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  # TODO(isaprykin): Create these threads once instead of during every call.
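  # One _MirroredReplicaThread is created per worker device. Each thread gets
  # a variable creator from `shared_variable_creator` (so replicas after the
  # first re-use the variables created by the first replica) and the
  # per-replica slice of `args`/`kwargs` picked by
  # `distribute_utils.select_replica`.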
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(distribution, coord, index, devices,
                               variable_creator_fn, fn,
                               distribute_utils.caching_scope_local,
                               distribute_utils.select_replica(index, args),
                               distribute_utils.select_replica(index, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts, the `should_run` event is set on the
  # _MirroredReplicaThread (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
  # complete, then `MRT.done` is set to True. Otherwise, the arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed. Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread once the `MRT.should_run` event
  # is set again, and execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
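          # Applying these below ensures that ops created by `merge_fn` in the
          # main thread respect the control-dependency contexts that were
          # active in each replica thread at the point of the `merge_call`.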
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)
          with ops.name_scope(mtt_captured_name_scope), \
              ops.control_dependencies(mtt_captured_control_deps), \
              variable_scope.variable_scope(mtt_captured_var_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))


class _MirroredReplicaThread(threading.Thread):
  """A thread that runs a function on a device."""

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, fn,
               caching_scope, args, kwargs):
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    try:
      self.caching_scope_entered = caching_scope.new_cache_scope_count
      self.caching_scope_exited = caching_scope.cache_scope_exited_count
    except AttributeError:
      self.caching_scope_entered = None
      self.caching_scope_exited = None

    # We use a thread.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared and is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clears the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))  # pylint: disable=protected-access
    self.graph = ops.get_default_graph()
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at the end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

  def run(self):
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_eager_context_state()
      if (self.caching_scope_entered is not None and
          self.caching_scope_exited is not None):
        distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
        distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
      # TODO(josh11b): Use current logical device instead of 0 here.
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    # TODO(slebedev): is this still relevant? the referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    self._eager_context_op_callbacks = eager_context_state.op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.


class _MirroredReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for synchronized replicas."""

  def _merge_call(self, fn, args, kwargs):
    """`merge_call()` implementation for synchronized replicas.

    This pauses the current replica thread and passes `fn` and its arguments
    to the main thread.
    The main thread waits until all replicas pause, then invokes
    `fn` with grouped arguments. The current replica thread continues
    after `fn` completes.

    See `_call_for_each_replica` for the logic in the main thread.

    Args:
      fn: a function that is called in cross-replica context with grouped
        arguments from each replica. `fn` should return grouped values.
      args: positional arguments to `fn`.
      kwargs: keyword arguments to `fn`.

    Returns:
      Return value of `fn` for the current replica.

    Raises:
      RuntimeError: when merge_call happens in a different graph, e.g. in a
        different tf.function, which is not yet supported.
      _RequestedStop: when stop is requested.
    """
    t = threading.current_thread()
    assert isinstance(t, _MirroredReplicaThread)
    t.merge_fn = fn
    t.merge_args = args
    t.merge_kwargs = kwargs
    t.captured_name_scope = t.graph.get_name_scope()
    # Adding a "/" at the end lets us re-enter this scope later.
    if t.captured_name_scope:
      t.captured_name_scope += "/"

    t.captured_var_scope = variable_scope.get_variable_scope()
    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access

    # It is problematic if `merge_call` is called under a different graph from
    # the one that `_call_for_each_replica` is called under. There are three
    # cases where this can happen:
    #
    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
    #   `tf.function` and there is a `merge_call` in `fn`. Since
    #   MirroredStrategy traces a separate function per thread (per device),
    #   and each trace takes a shared lock, the lock is never released by the
    #   first thread and subsequent replica threads cannot proceed to trace
    #   their own functions. This issue is addressed by always converting
    #   `_call_for_each_replica(tf.function(f))` to
    #   `tf.function(_call_for_each_replica(f))` in
    #   `MirroredStrategy._call_for_each_replica`.
    #
    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
    #   In this case each thread can successfully trace its own function, but
    #   since the `merge_fn` passed to `merge_call` is executed in the main
    #   thread (where `_call_for_each_replica` is executed), it can't access
    #   the tensors that come from different graphs.
    #
    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
    #   statement, there is a `merge_call` inside the control-flow body, and
    #   `fn` or `_call_for_each_replica` is decorated with `tf.function`.
    #   A control-flow statement creates a separate graph for its body, so,
    #   similar to #2, the `merge_fn` executed in the main thread can't access
    #   the tensors that come from different graphs.
    #
    # We raise an error for #2 and #3. (An illustrative sketch of the
    # unsupported and supported patterns is at the end of this module.)
    if ops.get_default_graph() != t.graph:
      raise RuntimeError(
          "`merge_call` called while defining a new graph or a tf.function."
          " This can often happen if the function `fn` passed to"
          " `strategy.run()` contains a nested `@tf.function`, and the nested "
          "`@tf.function` contains a synchronization point, such as"
          " aggregating gradients (e.g., optimizer.apply_gradients), or if the"
          " function `fn` uses a control flow statement which contains a"
          " synchronization point in the body. Such behaviors are not yet"
          " supported."
          " Instead, please avoid nested `tf.function`s or control flow"
          " statements that may potentially cross a synchronization boundary,"
          " for example, wrap the `fn` passed to `strategy.run` or the entire"
          " `strategy.run` inside a `tf.function` or move the control flow out"
          " of `fn`. If you are subclassing a `tf.keras.Model`, please avoid"
          " decorating overridden methods `test_step` and `train_step` in"
          " `tf.function`.")

    t.has_paused.set()
    t.should_run.wait()
    t.should_run.clear()
    if t.coord.should_stop():
      raise _RequestedStop()
    return t.merge_result

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    return [
        self._strategy.extended.worker_devices_by_replica[
            self._replica_id_in_sync_group]
    ]
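

# Illustrative sketch (not part of the implementation; `strategy`, `optimizer`,
# `replica_fn`, and `train_step` are hypothetical user-code names) of the
# unsupported pattern rejected by `_MirroredReplicaContext._merge_call` above,
# and of a supported alternative.
#
# Unsupported: a synchronization point inside a nested `tf.function`.
#
#   @tf.function
#   def apply_grads(grads_and_vars):
#     optimizer.apply_gradients(grads_and_vars)  # merge_call in a nested graph
#
#   def replica_fn(inputs):
#     grads_and_vars = ...
#     apply_grads(grads_and_vars)  # raises the RuntimeError above
#
#   strategy.run(replica_fn, args=(dist_inputs,))
#
# Supported: keep the synchronization point in `replica_fn` itself and wrap the
# outer step (or the entire `strategy.run` call) in a single `tf.function`.
#
#   @tf.function
#   def train_step(dist_inputs):
#     def replica_fn(inputs):
#       grads_and_vars = ...
#       optimizer.apply_gradients(grads_and_vars)
#     return strategy.run(replica_fn, args=(dist_inputs,))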