# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import abc

from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest


def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Regroups per-replica nests into a nest of PerReplica/Mirrored values.

  Args:
    values: Values to regroup, one nest per replica.
    wrap_class: Class that the leaves of `values` will be wrapped in.
    always_wrap: If `True`, always wrap the leaves in `wrap_class`, even if
      they are all the same object, except for `DistributedVariable`s.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(v0, "_make")
      return v0._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, abc.Mapping):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
                                            (set(v0keys), set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values), wrap_class, always_wrap)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a _distributed_container
  #   member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to return the
  #   DistributedVariable that contains it using the _distributed_container
  #   logic below. This case can trigger same_id when there is only one
  #   device.
  # * In any other situation, same_id means we return v0 unless `always_wrap`
  #   is true.
  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container()
    # pylint: enable=protected-access
    return distributed_container

  return wrap_class(values)
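

# A minimal sketch of `regroup` behavior (illustrative only, kept as a comment
# so it is not executed on import; `t0`/`t1` are hypothetical per-replica
# nests, and `tf.constant` is assumed to be available as in user code):
#
#   t0 = {"a": tf.constant(1.), "b": [tf.constant(2.)]}
#   t1 = {"a": tf.constant(3.), "b": [tf.constant(4.)]}
#   pr = regroup((t0, t1))
#   # pr is a dict with the same structure: pr["a"] is a PerReplica wrapping
#   # (t0["a"], t1["a"]) and pr["b"][0] wraps (t0["b"][0], t1["b"][0]).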


def select_replica(replica_id, structured):
  """Specializes a nest of regular & per-replica values for one replica."""

  def _get(x):
    # A `DistributedValues` is sliced according to replica unless it is a
    # `DistributedVariable`, because a `DistributedVariable` can be handled
    # directly in the replica context.
    if (isinstance(x, values_lib.DistributedVariable) or
        not isinstance(x, values_lib.DistributedValues)):
      return x
    else:
      return x.values[replica_id]

  return nest.map_structure(_get, structured)
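

# A minimal sketch of `select_replica` (illustrative only, not executed;
# `pr` is a hypothetical PerReplica such as one produced by `regroup` above):
#
#   pr = regroup((tf.constant(1.), tf.constant(2.)))
#   select_replica(0, {"x": pr})  # -> {"x": tf.constant(1.)}
#   select_replica(1, {"x": pr})  # -> {"x": tf.constant(2.)}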


def select_replica_mirrored(replica_id, structured):
  """Specializes a nest of regular & mirrored values for one replica."""
  assert_mirrored(structured)
  return select_replica(replica_id, structured)


def assert_mirrored(structured):
  """Raises if `structured` is not composed of mirrored or regular values."""

  def _assert_mirrored(x):
    if isinstance(x, values_lib.DistributedValues) and not is_mirrored(x):
      raise TypeError(
          "Expected value to be mirrored across replicas: %s in %s." %
          (x, structured))

  nest.map_structure(_assert_mirrored, structured)


def update_regroup(extended, updates, group):
  """Regroups for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Converts the per-replica list `values` into a grouped Mirrored value."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be
    # performing the same computation.
    if not all(tensor_util.is_tf_type(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)


def value_container(val):
  """Returns the container that this per-replica `val` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    The container that `val` belongs to. If `val` does not belong to any
    container (including the case of the container having been destroyed),
    returns `val` itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, values_lib.DistributedVariable)):
    container = val._distributed_container()  # pylint: disable=protected-access
    if container is not None:
      return container
  return val
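

# A minimal sketch of `value_container` (illustrative only, not executed;
# `mirrored_var` is a hypothetical MirroredVariable created under a
# tf.distribute strategy scope):
#
#   component = mirrored_var.values[0]  # one per-device component variable
#   value_container(component)          # -> mirrored_var
#   value_container(tf.constant(1.))    # -> the tensor itself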


def is_distributed_variable(v):
  """Determines if `v` is a `DistributedVariable` (incl. TPU mirrored ones)."""
  return isinstance(v, values_lib.DistributedVariable)


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, values_lib.DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


# Variable creation functions for sync strategies.
def _validate_synchronization(kwargs):
  """Validates the given `synchronization` value and resolves AUTO."""
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization


def _validate_aggregation(kwargs):
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))
  return aggregation


def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Creates distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace them with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
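

# A minimal sketch of the synchronization/aggregation validation (illustrative
# only, not executed; `kwargs` is a hypothetical variable-creation kwargs dict):
#
#   kwargs = {"name": "v", "synchronization": vs.VariableSynchronization.AUTO}
#   _validate_synchronization(kwargs)  # -> VariableSynchronization.ON_WRITE
#   _validate_aggregation(kwargs)      # -> VariableAggregation.NONE (default)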


# Utility functions.
# Return True if the value is Mirrored or the variable is replicated and kept
# in sync.
def is_mirrored(val):
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return val._policy._is_mirrored()  # pylint: disable=protected-access
  return isinstance(val, values_lib.Mirrored)


def is_sync_on_read(val):
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return not val._policy._is_mirrored()  # pylint: disable=protected-access
  return not isinstance(val, values_lib.Mirrored)


# The following mapping indicates the policy that you must use for a given
# variable `synchronization` and `aggregation` pair.
# OnWritePolicy is used for:
#   (synchronization=AUTO, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
#   (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
#   (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}
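
# A minimal sketch of how the mappings above are consumed by
# `create_mirrored_variable` (illustrative only, not executed):
#
#   sync = vs.VariableSynchronization.ON_READ
#   VARIABLE_POLICY_MAPPING[sync]  # -> values_lib.OnReadPolicy
#   VARIABLE_CLASS_MAPPING[sync]   # -> values_lib.SyncOnReadVariable
#   # When `strategy.extended._use_var_policy` is set, the policy class is
#   # paired with VARIABLE_CLASS_MAPPING["VariableClass"] instead.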