# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools
import threading
import weakref

from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator


def call_for_each_replica(strategy, fn, args=None, kwargs=None):
  """Call `fn` on each worker device (replica).

  It's highly recommended to wrap the call to this function inside a
  `tf.function`, otherwise the performance is poor.

  Args:
    strategy: `tf.distribute.Strategy`.
    fn: function to call on each worker device.
    args: positional arguments to `fn`.
    kwargs: keyword arguments to `fn`.

  Returns:
    Wrapped returned value of `fn` from all replicas.
  """
  if args is None:
    args = ()
  if kwargs is None:
    kwargs = {}

  if isinstance(fn, def_function.Function):
    if strategy not in _cfer_fn_cache:
      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
    wrapped = _cfer_fn_cache[strategy].get(fn)
    if wrapped is None:
      # We need to wrap fn such that it triggers _call_for_each_replica inside
      # the tf.function. We use _clone() instead of @tf.function wrapped
      # call_for_each_replica() because we would like to retain the arguments
      # to the @tf.function decorator of fn.
      wrapped = fn._clone(  # pylint: disable=protected-access
          python_function=functools.partial(call_for_each_replica, strategy,
                                            fn.python_function))
      _cfer_fn_cache[strategy][fn] = wrapped
    return wrapped(args, kwargs)

  if context.executing_eagerly():
    logging.log_first_n(
        logging.WARN, "Using %s eagerly has significant "
        "overhead currently. We will be working on improving "
        "this in the future, but for now please wrap "
        "`call_for_each_replica` or `experimental_run` or "
        "`run` inside a tf.function to get "
        "the best performance." % strategy.__class__.__name__, 5)
  else:
    # When a tf.function is wrapped to trigger _call_for_each_replica (see
    # the other branch above), AutoGraph stops conversion at
    # _call_for_each_replica itself (TF library functions are allowlisted).
    # This makes sure that the Python function that was originally passed to
    # the tf.function is still converted.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())

  return _call_for_each_replica(strategy, fn, args, kwargs)


# Per strategy cache for call_for_each_replica def_function.Function objects.
_cfer_fn_cache = weakref.WeakKeyDictionary()

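# A minimal usage sketch of the recommendation in `call_for_each_replica`'s
# docstring: wrap the per-replica computation in a single `tf.function` so it
# is traced rather than dispatched eagerly. Kept as a comment because this is
# library code; `replica_fn` and `inputs` are hypothetical, and
# `import tensorflow as tf` is assumed.
#
#   strategy = tf.distribute.MirroredStrategy()
#
#   @tf.function
#   def train_step(inputs):
#     return strategy.run(replica_fn, args=(inputs,))
#
#   train_step(inputs)
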
@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
  """Context manager for selecting a graph and maybe eager mode."""
  if eager:
    with g.as_default(), context.eager_mode():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield
  else:
    with g.as_default():
      if creator_stack is not None:
        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
      yield


def _cpu_device(device):
  cpu_device = tf_device.DeviceSpec.from_string(device)
  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
  return cpu_device.to_string()

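# For reference, a sketch of what `_cpu_device` produces, based on the public
# `tf.DeviceSpec` behavior that `tf_device.DeviceSpec` backs (the device
# string below is illustrative):
#
#   _cpu_device("/job:worker/replica:0/task:0/device:GPU:1")
#   # -> "/job:worker/replica:0/task:0/device:CPU:0"
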
class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  pass


def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`.
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: if `fn()` calls `get_replica_context().merge_call()` a
      different number of times on different replicas (devices).
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(
        distribution, coord, index, devices, variable_creator_fn, fn,
        distribute_utils.select_replica(index, args),
        distribute_utils.select_replica(index, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts, the `should_run` event is set on the
  # _MirroredReplicaThread (`MRT`) threads. The main thread then waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
  # complete, then `MRT.done` is set to True. Otherwise, the arguments of
  # `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed. Results of the
  # `get_replica_context().merge_call` are then stored in `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread once the `MRT.should_run` event
  # is set again, and execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)
          with ops.name_scope(mtt_captured_name_scope), \
              ops.control_dependencies(mtt_captured_control_deps), \
              variable_scope.variable_scope(mtt_captured_var_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))

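# A standalone sketch of the Event-based handshake described in
# `_call_for_each_replica` above, using plain `threading` and no TensorFlow
# (`worker` and its pause point are hypothetical):
#
#   import threading
#
#   should_run = threading.Event()
#   has_paused = threading.Event()
#
#   def worker():
#     should_run.wait()
#     should_run.clear()
#     # ... run until a pause point (merge_call) or completion ...
#     has_paused.set()
#
#   t = threading.Thread(target=worker)
#   t.start()
#   should_run.set()   # let the worker run
#   has_paused.wait()  # block until it pauses or finishes
#   has_paused.clear()
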
class _MirroredReplicaThread(threading.Thread):
  """A thread that runs() a function on a device."""

  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn,
               fn, args, kwargs):
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.devices = devices
    self.replica_id = replica_id
    self.replica_id_in_sync_group = (
        dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    # We use a threading.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared and is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clears the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_summary_state()
    self.record_thread_local_eager_context_state()
    self.context_device_policy = (
        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))  # pylint: disable=protected-access
    self.graph = ops.get_default_graph()
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()
    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at the end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id

  def run(self):
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      self.restore_thread_local_summary_state()
      self.restore_thread_local_eager_context_state()
      # TODO(josh11b): Use current logical device instead of 0 here.
      with self.coord.stop_on_exception(), \
          _enter_graph(self._init_graph, self._init_in_eager), \
          _enter_graph(self.graph, self.in_eager,
                       self._variable_creator_stack), \
          context.device_policy(self.context_device_policy), \
          _MirroredReplicaContext(self.distribution,
                                  self.replica_id_in_sync_group), \
          ops.device(self.devices[self.replica_id]), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._var_scope, reuse=self.replica_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      self.has_paused.set()

  def record_thread_local_summary_state(self):
    """Record the thread local summary state in self."""
    # TODO(slebedev): is this still relevant? The referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._summary_step = summary_state.step
    self._summary_writer = summary_state.writer
    self._summary_recording = summary_state.is_recording
    self._summary_recording_distribution_strategy = (
        summary_state.is_recording_distribution_strategy)

  def restore_thread_local_summary_state(self):
    """Restore thread local summary state from self."""
    # TODO(slebedev): is this still relevant? The referenced bug is closed.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.step = self._summary_step
    summary_state.writer = self._summary_writer
    summary_state.is_recording = self._summary_recording
    summary_state.is_recording_distribution_strategy = (
        self._summary_recording_distribution_strategy)

  def record_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    self._eager_context_op_callbacks = eager_context_state.op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

  def restore_thread_local_eager_context_state(self):
    ctx = context.context()
    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
    eager_context_state.op_callbacks = self._eager_context_op_callbacks
    # TODO(b/125892694): record other fields in EagerContext.

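# A minimal sketch of the capture-then-restore pattern used by
# `record_thread_local_summary_state` and friends above: thread-local values
# are read on the parent thread in `__init__` and written back on the replica
# thread in `run()` (plain Python; `_state` is hypothetical):
#
#   import threading
#
#   _state = threading.local()
#   _state.step = 7  # visible only on the thread that set it
#
#   class _Child(threading.Thread):
#
#     def __init__(self):
#       super().__init__()
#       # Captured on the parent thread, where `_state.step` exists.
#       self._captured_step = _state.step
#
#     def run(self):
#       # `_state` starts empty on this new thread; restore the captured value.
#       _state.step = self._captured_step
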
class _MirroredReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for synchronized replicas."""

  def _merge_call(self, fn, args, kwargs):
    """`merge_call()` implementation for synchronized replicas.

    This pauses the current replica thread and passes `fn` and its arguments
    to the main thread. The main thread will wait until all replicas pause,
    then invoke `fn` with grouped arguments. The current replica thread will
    continue after `fn` completes.

    See `_call_for_each_replica` for the logic in the main thread.

    Args:
      fn: a function that is called in cross-replica context with grouped
        arguments from each replica. `fn` should return grouped values.
      args: positional arguments to `fn`.
      kwargs: keyword arguments to `fn`.

    Returns:
      Return value of `fn` for the current replica.

    Raises:
      RuntimeError: when merge_call happens in a different graph, e.g. in a
        different tf.function, which is not currently supported.
      _RequestedStop: when stop is requested.
    """
    t = threading.current_thread()
    assert isinstance(t, _MirroredReplicaThread)
    t.merge_fn = fn
    t.merge_args = args
    t.merge_kwargs = kwargs
    t.captured_name_scope = t.graph.get_name_scope()
    # Adding a "/" at the end lets us re-enter this scope later.
    if t.captured_name_scope:
      t.captured_name_scope += "/"

    t.captured_var_scope = variable_scope.get_variable_scope()
    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access

    # It is problematic if `merge_call` is called under a different graph than
    # the one that `_call_for_each_replica` is called under. There are 3 cases
    # in which this can happen:
    #
    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
    #      `tf.function` and there is a `merge_call` in `fn`. Since
    #      MirroredStrategy traces a separate function per thread (per device),
    #      and each trace takes a shared lock, the lock is never released by
    #      the first thread and subsequent replica threads cannot proceed to
    #      trace their own functions. This issue is addressed by always
    #      converting `_call_for_each_replica(tf.function(f))` to
    #      `tf.function(_call_for_each_replica(f))` in
    #      `MirroredStrategy._call_for_each_replica`.
    #
    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
    #      `tf.function`, and there is a `merge_call` in the nested
    #      `tf.function`. In this case each thread can successfully trace its
    #      own function, but since the `merge_fn` passed to `merge_call` is
    #      executed in the main thread (where `_call_for_each_replica` is
    #      executed), it can't access the tensors that come from different
    #      graphs.
    #
    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
    #      statement, there is a `merge_call` inside the control-flow body,
    #      and `fn` or `_call_for_each_replica` is decorated with
    #      `tf.function`. A control-flow statement creates a separate graph
    #      for its body, so as in #2, the `merge_fn` executed in the main
    #      thread can't access the tensors that come from different graphs.
    #
    # We raise an error for #2 and #3.
    if ops.get_default_graph() != t.graph:
      raise RuntimeError(
          "`merge_call` called while defining a new graph or a tf.function."
          " This can often happen if the function `fn` passed to"
          " `strategy.run()` contains a nested `@tf.function`, and the nested"
          " `@tf.function` contains a synchronization point, such as"
          " aggregating gradients (e.g., optimizer.apply_gradients), or if the"
          " function `fn` uses a control flow statement which contains a"
          " synchronization point in the body. Such behaviors are not yet"
          " supported. Instead, please avoid nested `tf.function`s or control"
          " flow statements that may potentially cross a synchronization"
          " boundary, for example, wrap the `fn` passed to `strategy.run` or"
          " the entire `strategy.run` inside a `tf.function` or move the"
          " control flow out of `fn`.")

    t.has_paused.set()
    t.should_run.wait()
    t.should_run.clear()
    if t.coord.should_stop():
      raise _RequestedStop()
    return t.merge_result

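  # Roughly the user-visible shape of a `merge_call` handled by the method
  # above (a sketch; `x` is hypothetical and `import tensorflow as tf` is
  # assumed). The per-replica `x`s are regrouped into a single distributed
  # value for `merge_fn`, and its return value is scattered back so each
  # replica sees its own component:
  #
  #   def replica_fn(x):
  #     ctx = tf.distribute.get_replica_context()
  #     return ctx.merge_call(lambda strategy, v: v, args=(x,))
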
  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    return [
        self._strategy.extended.worker_devices_by_replica[
            self._replica_id_in_sync_group]
    ]
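

# An illustrative sketch of case #2 from `_MirroredReplicaContext._merge_call`
# above (hypothetical names; assumes `import tensorflow as tf` plus a
# `strategy`, `optimizer`, `model` and `compute_loss` defined elsewhere):
#
#   @tf.function  # nested function: traced into its own graph
#   def apply_grads(grads):
#     # Keras optimizers aggregate gradients via a merge_call, which now
#     # happens inside the nested tf.function's graph -> RuntimeError above.
#     optimizer.apply_gradients(zip(grads, model.trainable_variables))
#
#   def replica_fn(inputs):
#     with tf.GradientTape() as tape:
#       loss = compute_loss(model(inputs))
#     grads = tape.gradient(loss, model.trainable_variables)
#     apply_grads(grads)
#
#   strategy.run(replica_fn, args=(inputs,))
#
# Removing the inner `@tf.function`, or wrapping the entire `strategy.run`
# call in one outer `tf.function`, keeps the merge_call in the same graph and
# is supported.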