1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility to get tf.distribute.Strategy related contexts.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import ops 22from tensorflow.python.util.lazy_loader import LazyLoader 23from tensorflow.python.util.tf_export import tf_export 24 25 26# There is a circular dependency between this and `distribute` module. So we 27# load it lazily to workaround this. 28distribute_lib = LazyLoader( 29 "distribute_lib", globals(), 30 "tensorflow.python.distribute.distribute_lib") 31 32# ------------------------------------------------------------------------------ 33# Internal API for setting the current thread mode as being either in a 34# replica or cross-replica context for a particular tf.distribute.Strategy. 
class _ThreadMode(object):
  """Records the current strategy and replica/cross-replica context.

  The subclasses below each set exactly one of `cross_replica_context` and
  `replica_context` to `None`, which is how the public functions in this
  module tell the two kinds of context apart.

  Attributes:
    strategy: The `tf.distribute.Strategy` active in this mode.
    cross_replica_context: The strategy when in a cross-replica context,
      else `None`.
    replica_context: The `ReplicaContext` when in a replica context,
      else `None`.
  """

  def __init__(self, dist, cross, replica):
    self.strategy = dist
    self.cross_replica_context = cross
    self.replica_context = replica


class _CrossReplicaThreadMode(_ThreadMode):
  """Thread mode for the cross-replica context of `strategy`."""

  def __init__(self, strategy):
    # In cross-replica mode the strategy itself is reported as the
    # cross-replica context, and there is no replica context.
    super(_CrossReplicaThreadMode, self).__init__(strategy, strategy, None)


class _InReplicaThreadMode(_ThreadMode):
  """Thread mode for a replica context; the strategy comes from `replica_ctx`."""

  def __init__(self, replica_ctx):
    super(_InReplicaThreadMode, self).__init__(
        replica_ctx.strategy, None, replica_ctx)


def _push_per_thread_mode(context):
  """Pushes `context` (a `_ThreadMode`) onto the default graph's mode stack."""
  ops.get_default_graph()._distribution_strategy_stack.append(context)  # pylint: disable=protected-access


def _pop_per_thread_mode():
  """Pops the most recently pushed mode off the default graph's mode stack."""
  ops.get_default_graph()._distribution_strategy_stack.pop(-1)  # pylint: disable=protected-access


class _DefaultReplicaThreadMode(_ThreadMode):
  """Type of default value returned by `_get_per_thread_mode()`.

  Used when the default graph's strategy stack is empty or absent. It pairs
  the default strategy with the default replica context, so execution starts
  out in a (single-replica) replica context.
  """

  def __init__(self):
    super(_DefaultReplicaThreadMode, self).__init__(
        _get_default_strategy(), None, _get_default_replica_context())


def _get_per_thread_mode():
  """Returns the innermost active `_ThreadMode`.

  Falls back to the shared default replica mode when the default graph has
  no `_distribution_strategy_stack` attribute (AttributeError) or the stack
  is empty (IndexError).
  """
  try:
    return ops.get_default_graph()._distribution_strategy_stack[-1]  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    return _get_default_replica_mode()


# ------------------------------------------------------------------------------
# Public API for accessing the current thread mode


@tf_export("distribute.get_replica_context")
def get_replica_context():
  """Returns the current `tf.distribute.ReplicaContext` or `None`.

  Returns `None` if in a cross-replica context.

  Note that execution:

  1. starts in the default (single-replica) replica context (this function
     will return the default `ReplicaContext` object);
  2. switches to cross-replica context (in which case this will return
     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside
     `extended.call_for_each_replica(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context (and again
     this function will return `None`).

  Note that you can also go directly from step 1 to 4 to switch to a
  cross-replica context for the default `tf.distribute.Strategy`. You may
  also switch from the cross-replica context of 4 to a replica context by
  calling `extended.call_for_each_replica()`, jumping back to step 3.

  Most `tf.distribute.Strategy` methods may only be executed in
  a cross-replica context; in a replica context you should use the
  `ReplicaContext` API instead.

  Returns:
    The current `ReplicaContext` object when in a replica context scope,
    else `None`.

    Within a particular block, exactly one of these two things will be true:

    * `get_replica_context()` returns non-`None`, or
    * `tf.distribute.in_cross_replica_context()` returns True.
  """
  return _get_per_thread_mode().replica_context


def get_cross_replica_context():
  """Returns the current tf.distribute.Strategy if in a cross-replica context.

  DEPRECATED: Please use `in_cross_replica_context()` and
  `get_strategy()` instead.

  Note that execution:

  1. starts in the default (single-replica) replica context;
  2. switches to cross-replica context when entering a
     `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside
     `call_for_each_replica(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context.

  Note that you can also go directly from step 1 to 4 to switch to a
  cross-replica context for the default `tf.distribute.Strategy`. You may
  also switch from the cross-replica context of 4 to a replica context by
  calling `call_for_each_replica()`, jumping back to step 3.

  Most `tf.distribute.Strategy` methods may only be executed in
  a cross-replica context.

  Returns:
    The current `tf.distribute.Strategy` object in a cross-replica
    context, or `None`.

    Exactly one of `get_replica_context()` and `get_cross_replica_context()`
    will return `None` in a particular block.
  """
  return _get_per_thread_mode().cross_replica_context


@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
  """Returns True if in a cross-replica context.

  See `tf.distribute.get_replica_context` for details.

  Returns:
    True if in a cross-replica context (`get_replica_context()` returns
    `None`), or False if in a replica context (`get_replica_context()` returns
    non-`None`).
  """
  return _get_per_thread_mode().cross_replica_context is not None


@tf_export("distribute.get_strategy")
def get_strategy():
  """Returns the current `tf.distribute.Strategy` object.

  Typically only used in a cross-replica context:

  ```
  if tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    ...
  ```

  Returns:
    A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block,
    it returns `strategy`, otherwise it returns the default (single-replica)
    `tf.distribute.Strategy` object.
  """
  return _get_per_thread_mode().strategy


@tf_export("distribute.has_strategy")
def has_strategy():
  """Returns whether there is a current non-default `tf.distribute.Strategy`.

  Returns:
    True if the current strategy (see `tf.distribute.get_strategy`) is not the
    default strategy, e.g. inside a `with strategy.scope():` block; False
    otherwise.
  """
  return get_strategy() is not _get_default_strategy()


def get_strategy_and_replica_context():
  """Returns `(strategy, replica_context)` for the current thread mode.

  Both values come from a single lookup of the current mode, so they are
  consistent with each other. `replica_context` is `None` in a cross-replica
  context.
  """
  per_thread_mode = _get_per_thread_mode()
  return (per_thread_mode.strategy, per_thread_mode.replica_context)


# ------------------------------------------------------------------------------
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# We create them lazily in a function so that we can work around the circular
# dependency on distribute_lib. See lazy loader at the top of this file.

# Lazily-populated singletons; each entry stays `None` until first requested.
_defaults = {
    "strategy": None,
    "replica_context": None,
    "replica_mode": None
}


def _get_default_strategy():
  """Returns the cached default strategy, creating it on first use."""
  if _defaults["strategy"] is None:
    _defaults["strategy"] = distribute_lib._DefaultDistributionStrategy()  # pylint: disable=protected-access
  return _defaults["strategy"]


def _get_default_replica_context():
  """Returns the cached replica context (replica 0) of the default strategy."""
  if _defaults["replica_context"] is None:
    _defaults["replica_context"] = distribute_lib.ReplicaContext(
        _get_default_strategy(), replica_id_in_sync_group=0)
  return _defaults["replica_context"]


def _get_default_replica_mode():
  """Returns the cached `_DefaultReplicaThreadMode`, creating it on first use."""
  if _defaults["replica_mode"] is None:
    _defaults["replica_mode"] = _DefaultReplicaThreadMode()
  return _defaults["replica_mode"]


# Aliases for compatibility with old names.
get_distribution_strategy = get_strategy
has_distribution_strategy = has_strategy