# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to get tf.distribute.Strategy related contexts."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import threading

from tensorflow.python import tf2
from tensorflow.python.framework import ops
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export


# There is a circular dependency between this and the `distribute_lib` module.
# So we load it lazily to work around this.
distribute_lib = LazyLoader(
    "distribute_lib", globals(),
    "tensorflow.python.distribute.distribute_lib")

# ------------------------------------------------------------------------------
# Internal API for setting the current thread mode as being either in a
# replica or cross-replica context for a particular tf.distribute.Strategy.

class _ThreadMode(object):
  """Bundle of (strategy, cross-replica context, replica context).

  Exactly one of `cross_replica_context` / `replica_context` is expected to be
  non-None for a given mode; `strategy` is always set.
  """

  def __init__(self, dist, cross, replica):
    self.strategy = dist
    self.cross_replica_context = cross
    self.replica_context = replica


class _CrossReplicaThreadMode(_ThreadMode):
  """Thread mode used while in a cross-replica context of `strategy`."""

  def __init__(self, strategy):
    super(_CrossReplicaThreadMode, self).__init__(strategy, strategy, None)


class _InReplicaThreadMode(_ThreadMode):
  """Thread mode used while executing inside `replica_ctx`."""

  def __init__(self, replica_ctx):
    super(_InReplicaThreadMode, self).__init__(replica_ctx.strategy, None,
                                               replica_ctx)


def _push_per_thread_mode(context):
  """Pushes `context` onto the current graph's per-thread mode stack."""
  stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  stack.append(context)


def _pop_per_thread_mode():
  """Pops the innermost mode off the current graph's per-thread mode stack."""
  stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  stack.pop(-1)


class _DefaultReplicaThreadMode(_ThreadMode):
  """Type of default value returned by `_get_per_thread_mode()`.

  Used when the thread-local stack is empty.
  """

  def __init__(self):
    super(_DefaultReplicaThreadMode, self).__init__(
        _get_default_strategy(), None, _get_default_replica_context())


def _get_per_thread_mode():
  """Returns the innermost `_ThreadMode`, or the default mode if none set."""
  try:
    stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
    return stack[-1]
  except (AttributeError, IndexError):
    return _get_default_replica_mode()


# ------------------------------------------------------------------------------
# Public API for accessing the current thread mode


@tf_export("distribute.get_replica_context")
def get_replica_context():
  """Returns the current `tf.distribute.ReplicaContext` or `None`.

  Returns `None` if in a cross-replica context.

  Note that execution:

  1. starts in the default (single-replica) replica context (this function
     will return the default `ReplicaContext` object);
  2. switches to cross-replica context (in which case this will return
     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside
     `strategy.run(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context (and again
     this function will return `None`).

  Most `tf.distribute.Strategy` methods may only be executed in
  a cross-replica context, in a replica context you should use the
  API of the `tf.distribute.ReplicaContext` object returned by this
  method instead.

  ```
  assert tf.distribute.get_replica_context() is not None  # default
  with strategy.scope():
    assert tf.distribute.get_replica_context() is None

    def f():
      replica_context = tf.distribute.get_replica_context()  # for strategy
      assert replica_context is not None
      tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
               " of ", replica_context.num_replicas_in_sync)

    strategy.run(f)
  ```

  Returns:
    The current `tf.distribute.ReplicaContext` object when in a replica context
    scope, else `None`.

    Within a particular block, exactly one of these two things will be true:

    * `get_replica_context()` returns non-`None`, or
    * `tf.distribute.is_cross_replica_context()` returns True.
  """
  mode = _get_per_thread_mode()
  return mode.replica_context


def get_cross_replica_context():
  """Returns the current tf.distribute.Strategy if in a cross-replica context.

  DEPRECATED: Please use `in_cross_replica_context()` and
  `get_strategy()` instead.

  Returns:
    Returns the current `tf.distribute.Strategy` object in a cross-replica
    context, or `None`.

    Exactly one of `get_replica_context()` and `get_cross_replica_context()`
    will return `None` in a particular block.
  """
  mode = _get_per_thread_mode()
  return mode.cross_replica_context


@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
  """Returns `True` if in a cross-replica context.

  See `tf.distribute.get_replica_context` for details.

  ```
  assert not tf.distribute.in_cross_replica_context()
  with strategy.scope():
    assert tf.distribute.in_cross_replica_context()

    def f():
      assert not tf.distribute.in_cross_replica_context()

    strategy.run(f)
  ```

  Returns:
    `True` if in a cross-replica context (`get_replica_context()` returns
    `None`), or `False` if in a replica context (`get_replica_context()`
    returns non-`None`).
  """
  mode = _get_per_thread_mode()
  return mode.cross_replica_context is not None


@tf_export("distribute.get_strategy")
def get_strategy():
  """Returns the current `tf.distribute.Strategy` object.

  Typically only used in a cross-replica context:

  ```
  if tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    ...
  ```

  Returns:
    A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block,
    it returns `strategy`, otherwise it returns the default (single-replica)
    `tf.distribute.Strategy` object.
  """
  mode = _get_per_thread_mode()
  return mode.strategy


@tf_export("distribute.has_strategy")
def has_strategy():
  """Return if there is a current non-default `tf.distribute.Strategy`.

  ```
  assert not tf.distribute.has_strategy()
  with strategy.scope():
    assert tf.distribute.has_strategy()
  ```

  Returns:
    True if inside a `with strategy.scope():`.
  """
  # The default strategy is a lazily-created singleton, so identity comparison
  # against it detects "no explicit strategy in scope".
  current = get_strategy()
  return current is not _get_default_strategy()


def get_strategy_and_replica_context():
  """Returns `(strategy, replica_context)` for the current thread mode."""
  per_thread_mode = _get_per_thread_mode()
  return (per_thread_mode.strategy, per_thread_mode.replica_context)


@tf_export("distribute.experimental_set_strategy")
def experimental_set_strategy(strategy):
  """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`.

  ```
  tf.distribute.experimental_set_strategy(strategy1)
  f()
  tf.distribute.experimental_set_strategy(strategy2)
  g()
  tf.distribute.experimental_set_strategy(None)
  h()
  ```

  is equivalent to:

  ```
  with strategy1.scope():
    f()
  with strategy2.scope():
    g()
  h()
  ```

  In general, you should use the `with strategy.scope():` API, but this
  alternative may be convenient in notebooks where you would have to put
  each cell in a `with strategy.scope():` block.

  Note: This should only be called outside of any TensorFlow scope to
  avoid improper nesting.

  Args:
    strategy: A `tf.distribute.Strategy` object or None.

  Raises:
    RuntimeError: If called inside a `with strategy.scope():`.
  """
  # Tear down the scope installed by a previous call, if any.
  old_scope = ops.get_default_graph()._global_distribute_strategy_scope  # pylint: disable=protected-access
  if old_scope is not None:
    old_scope.__exit__(None, None, None)
    ops.get_default_graph()._global_distribute_strategy_scope = None  # pylint: disable=protected-access
  # A strategy scope entered via `with strategy.scope():` cannot be exited
  # here, so mixing the two styles is an error.
  if has_strategy():
    raise RuntimeError(
        "Must not be called inside a `tf.distribute.Strategy` scope.")
  if strategy is not None:
    # Enter the new scope and remember it so a later call can exit it.
    new_scope = strategy.scope()
    new_scope.__enter__()
    ops.get_default_graph()._global_distribute_strategy_scope = new_scope  # pylint: disable=protected-access


# ------------------------------------------------------------------------------
# Internal helpers.


@contextlib.contextmanager
def enter_or_assert_strategy(strategy):
  """Enters `strategy.scope()` unless already inside that same strategy."""
  if has_strategy():
    _assert_strategy(strategy)
    yield
  else:
    with strategy.scope():
      yield


# ------------------------------------------------------------------------------
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# We create them lazily in a function so that we can workaround the circular
# dependency on distribute_lib.
# See lazy loader at the top of this file.

_defaults = {
    "strategy": None,
    "replica_context": None,
    "replica_mode": None
}
# Note: These need to be different locks since _get_default_replica_context
# calls _get_default_strategy inside its lock, and having them share a single
# lock can lead to deadlock.
_default_strategy_lock = threading.Lock()
_default_replica_context_lock = threading.Lock()
_default_replica_mode_lock = threading.Lock()


def _assert_strategy(strategy):
  """Raises unless the current scope's strategy is exactly `strategy`.

  Args:
    strategy: The `tf.distribute.Strategy` expected to be current.

  Raises:
    RuntimeError: If no non-default strategy is in scope, or a different
      strategy object is in scope.
  """
  if not has_strategy():
    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
                       (strategy,))
  current_strategy = get_strategy()
  # Identity (not equality) comparison: two distinct strategy objects must not
  # be mixed even if they are configured identically.
  if current_strategy is not strategy:
    raise RuntimeError(
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (current_strategy, strategy))


def _get_default_strategy():
  """Lazily creates and returns the default strategy singleton."""
  if _defaults["strategy"] is None:
    # Avoid race condition causing two defaults to be created
    with _default_strategy_lock:
      if _defaults["strategy"] is None:
        # pylint: disable=protected-access
        # Make sure distribute_lib module is loaded by accessing some member.
        _ = distribute_lib._creating_default_strategy_singleton
        distribute_lib._creating_default_strategy_singleton = True
        try:
          if tf2.enabled():
            _defaults["strategy"] = (
                distribute_lib._DefaultDistributionStrategy())
          else:
            _defaults["strategy"] = (
                distribute_lib._DefaultDistributionStrategyV1())
        finally:
          # Reset the flag even if construction raises; otherwise a later
          # (retried) call would run with the singleton-creation flag stuck
          # on True.
          distribute_lib._creating_default_strategy_singleton = False
        # pylint: enable=protected-access
  return _defaults["strategy"]


def _get_default_replica_context():
  """Lazily creates and returns the default replica context singleton."""
  if _defaults["replica_context"] is None:
    # Avoid race condition causing two defaults to be created
    with _default_replica_context_lock:
      if _defaults["replica_context"] is None:
        # pylint: disable=protected-access
        _defaults["replica_context"] = distribute_lib._DefaultReplicaContext(
            _get_default_strategy(), replica_id_in_sync_group=0)
        # pylint: enable=protected-access
  return _defaults["replica_context"]


def _get_default_replica_mode():
  """Lazily creates and returns the default `_ThreadMode` singleton."""
  if _defaults["replica_mode"] is None:
    # Avoid race condition causing two defaults to be created
    with _default_replica_mode_lock:
      if _defaults["replica_mode"] is None:
        _defaults["replica_mode"] = _DefaultReplicaThreadMode()
  return _defaults["replica_mode"]


# Aliases for compatibility with old names.
get_distribution_strategy = get_strategy
has_distribution_strategy = has_strategy