# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions used by values.py and ps_values.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training.saving import saveable_object


def write_object_proto(var, proto, options):
  """Update a SavedObject proto for the caller.

  If a DistributedVariable object supports this method, it will be called when
  saving with a pre-built `SavedObject` proto representing the object, plus an
  instance of `SaveOptions`. This method is then free to modify that proto
  instance.

  A `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
  writes out information about its components to the
  `experimental_distributed_variable_components` field of a
  `SavedVariable` (depending on the `SaveOptions` variable policy).

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  if options.experimental_variable_policy._expand_distributed_variables(  # pylint: disable=protected-access
  ):
    for component in var.values:
      var_proto = (
          proto.variable.experimental_distributed_variable_components.add())
      var_proto.name = component.name.split(":")[0]
      var_proto.device = component.device
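

# Hedged illustration (not part of this module): whether `write_object_proto`
# expands component information is controlled by the variable policy passed at
# save time. Assuming the public `tf.saved_model` API, that looks roughly like:
#
#   policy = tf.saved_model.experimental.VariablePolicy
#   options = tf.saved_model.SaveOptions(
#       experimental_variable_policy=policy.EXPAND_DISTRIBUTED_VARIABLES)
#   tf.saved_model.save(model, export_dir, options=options)
#
# `model` and `export_dir` are placeholder names.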


def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""
  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of `None` indicates that the variable is
      # uninitialized.
      return None
    strategy = var.distribute_strategy
    return strategy.extended.read_var(var)

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]


def get_on_write_restore_ops(var, tensor):
  """Return restore ops for AUTO and ON_WRITE variables."""
  packed_var = var._packed_variable  # pylint: disable=protected-access
  if packed_var is not None:
    return control_flow_ops.group(
        tuple(
            assign_on_device(d, packed_var, tensor)
            for d in packed_var.devices))
  return control_flow_ops.group(
      tuple(
          assign_on_device(v.device, v, tensor)
          for v in var.values))


def get_on_read_saveable(var, primary_var, name):
  """Return saveables for ON_READ variable."""

  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    return var._get_cross_replica()  # pylint: disable=protected-access

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]


def get_on_read_restore_ops(var, tensor, aggregation):
  """Return restore ops for ON_READ variables."""
  # To preserve the sum across save and restore, we have to divide the
  # total across all devices when restoring a variable that was summed
  # when saving.
  if aggregation == vs.VariableAggregation.SUM:
    strategy = var.distribute_strategy
    tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
                           var.dtype)
  return control_flow_ops.group(
      tuple(
          assign_on_device(v.device, v, tensor)
          for v in var.values))
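

# Worked example (illustrative only): with 4 replicas and SUM aggregation, an
# ON_READ variable whose per-replica components are each 2.0 is saved as the
# cross-replica total 8.0; on restore, `get_on_read_restore_ops` divides by
# `num_replicas_in_sync`, assigning 8.0 / 4 = 2.0 back to every component, so
# a later cross-replica read again sums to 8.0.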


# Utility function that indicates if you are in an UpdateContext when running
# in a replica fn.
def in_replica_update_context():
  return distribute_lib.get_update_replica_id() is not None


def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
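

# Hedged sketch of a call site (assumed, not part of this module): a mirrored
# variable's `assign*` methods typically route replica-context updates through
# the helpers above, letting `var._update` aggregate the per-replica values
# according to the variable's `aggregation` before each component is written:
#
#   with strategy.scope():
#     v = tf.Variable(0., aggregation=tf.VariableAggregation.SUM)
#   strategy.run(lambda: v.assign_add(1.))  # each replica contributes 1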


def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value."""
  if var._packed_variable is not None:  # pylint: disable=protected-access
    update = control_flow_ops.group(
        tuple(
            assign_func(d, var._packed_variable, value) for d in var._devices))  # pylint: disable=protected-access
  else:
    update = control_flow_ops.group(
        tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
  if not read_value:
    return update
  with ops.control_dependencies([update] if update else []):
    return var.read_value()


def on_read_assign_sub_cross_replica(var, value, read_value=True):
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if ds_context.in_cross_replica_context():
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_sub` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_sub_on_device,
                                   value, read_value)


def on_read_assign_add_cross_replica(var, value, read_value=True):
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if ds_context.in_cross_replica_context():
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_add` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_add_on_device,
                                   value, read_value)


def on_read_assign_cross_replica(var, value, read_value=True):
  """Return the value of the variable in cross replica context."""
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if ds_context.in_cross_replica_context():
      # To preserve the sum across save and restore, we have to divide the
      # total across all devices when restoring a variable that was summed
      # when saving.
      tensor = value
      if var.aggregation == vs.VariableAggregation.SUM:
        strategy = var._distribute_strategy  # pylint: disable=protected-access
        tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
                               var.dtype)
      return assign_on_each_device(var, assign_on_device, tensor,
                                   read_value)


def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_sub_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_add(var, sparse_delta, use_locking=False, name=None):
  scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_add_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_mul_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_div(var, sparse_delta, use_locking=False, name=None):
  scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_div_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_min(var, sparse_delta, use_locking=False, name=None):
  scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_min_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_max(var, sparse_delta, use_locking=False, name=None):
  scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_max_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context._replica_id  # pylint: disable=protected-access
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id


def assign_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign(tensor)


def assign_add_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_add(tensor)


def assign_sub_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_sub(tensor)


def assert_replica_context(strategy):
  replica_context = ds_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context of "
        "the same `tf.distribute.Strategy` that created them.")


def apply_aggregation(strategy, value, aggregation, destinations):
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    return strategy.extended.broadcast_to(
        strategy.experimental_local_results(value)[0],
        destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)
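

# Hedged illustration (assumed call site, not part of this module): when a
# variable is assigned a per-replica value from cross-replica code, the caller
# may first merge the value onto the variable's devices, roughly:
#
#   reduced = apply_aggregation(
#       strategy, per_replica_value, var.aggregation, destinations=var)
#
# `per_replica_value` and `var` are placeholder names; with
# `ONLY_FIRST_REPLICA` aggregation the first replica's value is broadcast
# instead of reduced, mirroring the branch above.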


aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in replica context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..), "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`. "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross-replica "
    "context. You can enter cross-replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")


scatter_error_msg = ("{op_name} is only supported for mirrored "
                     "variables (i.e. variables created within a "
                     "`tf.distribute.Strategy` scope) with `NONE` or "
                     "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")


def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  It returns True iff we are in a save context and are saving a
  non-distributed version of the model, i.e. when
  `SaveOptions.experimental_variable_policy` is not
  `EXPAND_DISTRIBUTED_VARIABLES`.

  Returns:
    A boolean.
  """
  if not save_context.in_save_context():
    return False
  options = save_context.get_save_options()
  return (options.experimental_variable_policy !=
          save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)


def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context."""
  if ops.inside_function() and not save_context.in_save_context():
    ops.get_default_graph().mark_as_unsaveable("""
ConcreteFunctions that use distributed variables in certain ways cannot be
saved. If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)

instead.""")
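

# Hedged sketch (assumed caller, for illustration only): DistributedVariable
# methods commonly use `is_saving_non_distributed()` as a fast path so that a
# non-distributed SavedModel only captures the primary component, roughly:
#
#   def read_value(self):
#     if values_util.is_saving_non_distributed():
#       return self._primary.read_value()
#     ...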