# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util.tf_export import tf_export


def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context."""
  if var.aggregation == vs.VariableAggregation.NONE:
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access

  if not ds_context.get_strategy().extended._use_merge_call():  # pylint: disable=protected-access
    # Don't allow MEAN with non float dtype, since it may cause unexpected
    # precision loss. Python3 and NumPy automatically upcast integers to
    # float in division, but we should always preserve the type.
    if var.aggregation == vs.VariableAggregation.MEAN and (
        not var.dtype.is_floating) and tensor_util.is_tf_type(value):
      raise ValueError(
          "Cannot update non-float variables with "
          "tf.VariableAggregation.MEAN aggregation in replica context. "
" 62 "Either change the variable dtype to float or update it in " 63 "cross-replica context.") 64 65 aggregated_value = apply_aggregation_replica_context( 66 value, var.aggregation, var) 67 values_util.mark_as_unsaveable() 68 69 return ds_context.get_replica_context()._update( # pylint: disable=protected-access 70 var, 71 update_fn, 72 args=(aggregated_value,), 73 kwargs=kwargs, 74 group=True) 75 76 else: 77 78 def merge_fn(strategy, value, **kwargs): 79 """Aggregate values and update all variables in cross replica context.""" 80 # Don't allow MEAN with non float dtype, since it may cause unexpected 81 # precision loss. Python3 and NumPy automatically upcast integers to 82 # float in division, but we should always preserve the type. 83 # 84 # Note that to be backward compatible we allow the case when the value 85 # is *always* the same on each replica. I.E. value is not a 86 # PerReplica. Refer to regroup() to see how values are grouped. 87 if var.aggregation == vs.VariableAggregation.MEAN and ( 88 not var.dtype.is_floating) and isinstance(value, PerReplica): 89 raise ValueError( 90 "Cannot update non-float variables with " 91 "tf.VariableAggregation.MEAN aggregation in replica context. " 92 "Either change the variable dtype to float or update it in " 93 "cross-replica context.") 94 95 assert strategy == var.distribute_strategy 96 v = values_util.apply_aggregation(strategy, value, var.aggregation, var) 97 return var._update_cross_replica(update_fn, v, **kwargs) # pylint: disable=protected-access 98 99 return ds_context.get_replica_context().merge_call( 100 merge_fn, args=(value,), kwargs=kwargs) 101 102 103def apply_aggregation_replica_context(value, aggregation, destinations): 104 """Aggregate `value` to `destinations` as specified by `aggregation`.""" 105 # if it is a python literal, return without aggregation 106 if isinstance(value, DistributedValues): 107 raise TypeError( 108 "Cannot use DistributedValues to update variables in replica context.") 109 if not tensor_util.is_tf_type(value): 110 return value 111 112 if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 113 # Switch to cross-replica context to broadcast 114 def merge_fn(strategy, value): 115 return strategy.extended.broadcast_to( 116 strategy.experimental_local_results(value)[0], 117 destinations=destinations) 118 119 return ds_context.get_replica_context().merge_call(merge_fn, args=(value,)) 120 121 else: 122 reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) 123 aggregated_value = ds_context.get_strategy( # pylint: disable=protected-access 124 ).extended._replica_ctx_all_reduce(reduce_op, value) 125 return aggregated_value 126 127 128@tf_export("distribute.DistributedValues", v1=[]) 129class DistributedValues(object): 130 """Base class for representing distributed values. 131 132 A subclass instance of `tf.distribute.DistributedValues` is created when 133 creating variables within a distribution strategy, iterating a 134 `tf.distribute.DistributedDataset` or through `tf.distribute.Strategy.run`. 135 This base class should never be instantiated directly. 136 `tf.distribute.DistributedValues` contains a value per replica. Depending on 137 the subclass, the values could either be synced on update, synced on demand, 138 or never synced. 139 140 `tf.distribute.DistributedValues` can be reduced to obtain single value across 141 replicas, as input into `tf.distribute.Strategy.run` or the per-replica values 142 inspected using `tf.distribute.Strategy.experimental_local_results`. 

  Example usage:

  1. Created from a `tf.distribute.DistributedDataset`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)

  2. Returned by `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> @tf.function
  ... def run():
  ...   ctx = tf.distribute.get_replica_context()
  ...   return ctx.replica_id_in_sync_group
  >>> distributed_values = strategy.run(run)

  3. As input into `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> @tf.function
  ... def run(input):
  ...   return input + 1.0
  >>> updated_value = strategy.run(run, args=(distributed_values,))

  4. Reduce value:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
  ...                                 distributed_values,
  ...                                 axis = 0)

  5. Inspect local replica values:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> per_replica_values = strategy.experimental_local_results(
  ...    distributed_values)
  >>> per_replica_values
  (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)

  """

  def __init__(self, values):
    """Should only be called by subclass __init__."""
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._values[replica_id]

  def _get_cross_replica(self):
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:
        if device_util.canonicalize(value.device) == current_device:
          return value
      return self._primary
    else:
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def _devices(self):
    return tuple(v.device for v in self._values)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)


# NOTE(josh11b,apassos): It would be great if we could inspect the values this
# was initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# it will only use those methods if they are defined in the class, not the
# object.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # The '_use_resource_variables' attribute and the attrs that start with
    # '_self' are used for restoring the saved_model proto, and
    # '_attribute_sentinel' is used for Layer tracking. At the point these
    # attrs are queried, the variable has not been initialized. Thus it should
    # not query those of the underlying components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access the "_values"
    # attribute, but that doesn't exist either because this is an empty object,
    # and again __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mixed
    # use of __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return
    the value type within a replica context. They can, however, return a value
    that can be used by the operations below.
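
    For example, the `__add__` overload below returns
    `self._get_as_operand() + o` rather than `self._get() + o`, so such
    implementations can still participate in arithmetic inside a replica.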
295 """ 296 return self._get() 297 298 # pylint: disable=multiple-statements 299 def __add__(self, o): 300 return self._get_as_operand() + o 301 302 def __radd__(self, o): 303 return o + self._get_as_operand() 304 305 def __sub__(self, o): 306 return self._get_as_operand() - o 307 308 def __rsub__(self, o): 309 return o - self._get_as_operand() 310 311 def __mul__(self, o): 312 return self._get_as_operand() * o 313 314 def __rmul__(self, o): 315 return o * self._get_as_operand() 316 317 def __truediv__(self, o): 318 return self._get_as_operand() / o 319 320 def __rtruediv__(self, o): 321 return o / self._get_as_operand() 322 323 def __floordiv__(self, o): 324 return self._get_as_operand() // o 325 326 def __rfloordiv__(self, o): 327 return o // self._get_as_operand() 328 329 def __mod__(self, o): 330 return self._get_as_operand() % o 331 332 def __rmod__(self, o): 333 return o % self._get_as_operand() 334 335 def __lt__(self, o): 336 return self._get_as_operand() < o 337 338 def __le__(self, o): 339 return self._get_as_operand() <= o 340 341 def __gt__(self, o): 342 return self._get_as_operand() > o 343 344 def __ge__(self, o): 345 return self._get_as_operand() >= o 346 347 def __and__(self, o): 348 return self._get_as_operand() & o 349 350 def __rand__(self, o): 351 return o & self._get_as_operand() 352 353 def __or__(self, o): 354 return self._get_as_operand() | o 355 356 def __ror__(self, o): 357 return o | self._get_as_operand() 358 359 def __xor__(self, o): 360 return self._get_as_operand() ^ o 361 362 def __rxor__(self, o): 363 return o ^ self._get_as_operand() 364 365 def __getitem__(self, o): 366 return self._get_as_operand()[o] 367 368 def __pow__(self, o, modulo=None): 369 return pow(self._get_as_operand(), o, modulo) 370 371 def __rpow__(self, o): 372 return pow(o, self._get_as_operand()) 373 374 def __invert__(self): 375 return ~self._get_as_operand() 376 377 def __neg__(self): 378 return -self._get_as_operand() 379 380 def __abs__(self): 381 return abs(self._get_as_operand()) 382 383 def __div__(self, o): 384 try: 385 return self._get_as_operand().__div__(o) 386 except AttributeError: 387 # See https://docs.python.org/3/library/constants.html#NotImplemented 388 return NotImplemented 389 390 def __rdiv__(self, o): 391 try: 392 return self._get_as_operand().__rdiv__(o) 393 except AttributeError: 394 # See https://docs.python.org/3/library/constants.html#NotImplemented 395 return NotImplemented 396 397 def __matmul__(self, o): 398 try: 399 return self._get_as_operand().__matmul__(o) 400 except AttributeError: 401 # See https://docs.python.org/3/library/constants.html#NotImplemented 402 return NotImplemented 403 404 def __rmatmul__(self, o): 405 try: 406 return self._get_as_operand().__rmatmul__(o) 407 except AttributeError: 408 # See https://docs.python.org/3/library/constants.html#NotImplemented 409 return NotImplemented 410 411 # TODO(josh11b): Even more operator overloads. 


class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values


class PerReplicaSpec(type_spec.TypeSpec):
  """Type specification for a `PerReplica`."""

  __slots__ = ["_value_specs"]

  value_type = property(lambda self: PerReplica)

  def __init__(self, *value_specs):
    self._value_specs = tuple(value_specs)

  def _serialize(self):
    return self._value_specs

  @property
  def _component_specs(self):
    return self._value_specs

  def _to_components(self, value):
    replica_context = ds_context.get_replica_context()
    if replica_context is not None and replica_context.num_replicas_in_sync > 1:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "
          "context.")
    return value._values  # pylint: disable=protected-access

  def _from_components(self, tensor_list):
    return PerReplica(tensor_list)


# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
# TODO(tomhennigan) Should this extend CompositeTensor?
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    return self._get_on_device_or_primary()

  def _as_graph_element(self):
    obj = self._get()
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj


class DistributedVarOp(object):
  """A class that looks like `tf.Operation`."""

  def __init__(self, name, graph, traceback, typ):
    self.name = name
    self.graph = graph
    self.traceback = traceback
    self.type = typ

  def __eq__(self, o):
    if not isinstance(o, self.__class__):
      raise NotImplementedError
    return (self.name == o.name and self.graph == o.graph and
            self.traceback == o.traceback and self.type == o.type)

  def __hash__(self):
    return hash((self.name, self.graph, tuple(self.traceback), self.type))


class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                          core.Tensor):
  """Holds a map from replica to variables."""

  def __init__(self, strategy, values, aggregation, var_policy=None):
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None

    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the component variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `DistributedVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `DistributedVariable`. To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of a `DistributedVariable`
    within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `DistributedVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      new_values = []

      for value in self._values:
        with ops.device(value.device):
          new_values.append(copy.deepcopy(value, memo))

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        values=new_values,
        aggregation=self._aggregation,
        var_policy=copy.deepcopy(self._policy, memo))

    memo[id(self)] = copied_variable

    return copied_variable

  def _use_packed_variable(self):
    # Don't use packed variable when under a SaveContext to avoid explicit
    # device placement on variable consuming ops.
    return self._packed_var is not None and not save_context.in_save_context()

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on if all the
      component variables are initialized.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.is_initialized()
    if self._use_packed_variable():
      return self._packed_var.is_initialized()
    result = self._primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
    # to the `is_initialized` op. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(
        result, self._values[-1].is_initialized(), name=name)
    return result

  @property
  def initializer(self):
    if values_util.is_saving_non_distributed():
      return self._primary.initializer
    if self._initializer_op:
      init_op = self._initializer_op
    else:
      # Return grouped ops of all the var initializations of component values
      # of the mirrored variable.
      init_op = control_flow_ops.group(
          tuple(v.initializer for v in self._values))
    return init_op

  def initialized_value(self):
    return self._get_on_device_or_primary().initialized_value()

  @property
  def initial_value(self):
    return self._get_on_device_or_primary().initial_value

  @property
  def constraint(self):
    return self._primary.constraint

  @property
  def graph(self):
    return self._primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self._primary._unique_id  # pylint: disable=protected-access

  @property
  def _graph_key(self):
    """Lets Optimizers know which graph this variable is from."""
    return self._primary._graph_key  # pylint: disable=protected-access

  @property
  def name(self):
    return self._primary.name

  @property
  def dtype(self):
    return self._primary.dtype

  @property
  def shape(self):
    return self._primary.shape

  @property
  def synchronization(self):
    return self._primary.synchronization

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def _packed_variable(self):
    if self._use_packed_variable():
      return self._packed_var
    return None

  @property
  def handle(self):
    if values_util.is_saving_non_distributed():
      return self._primary.handle
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      raise ValueError(
          "DistributedVariable.handle is not available outside the replica "
          "context or a `tf.distribute.Strategy.update()` call.")
    else:
      if self._use_packed_variable():
        return self._packed_var.handle
      return self._values[replica_id].handle

  def eval(self, session=None):
    return self._get_on_device_or_primary().eval(session)

  @property
  def _save_slice_info(self):
    return self._primary._save_slice_info  # pylint: disable=protected-access

  def _get_save_slice_info(self):
    return self._primary._get_save_slice_info()  # pylint: disable=protected-access

  def _set_save_slice_info(self, save_slice_info):
    for v in self._values:
      v._set_save_slice_info(save_slice_info)  # pylint: disable=protected-access

  @property
  def device(self):
    return self._get_on_device_or_primary().device

  @property
  def trainable(self):
    return self._primary.trainable

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def get_shape(self):
    return self._primary.get_shape()

  def to_proto(self, export_scope=None):
    return self._primary.to_proto(export_scope=export_scope)

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    # We want cross-replica code that does some var.op.X calls
    # to work (even if the current device isn't in self._devices), but
    # other uses of var.op in a cross-replica context to fail.
    if ds_context.in_cross_replica_context():
      return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
                              self._primary.op.traceback, self._primary.op.type)
    return self._get().op

  @property
  def _in_graph_mode(self):
    return self._primary._in_graph_mode  # pylint: disable=protected-access

  def _get_replica(self, replica_id):
    """Returns the value on a device with the given replica_id."""
    if self._use_packed_variable():
      return self._packed_var.on_device(self._devices[replica_id])
    return self._values[replica_id]

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    if values_util.is_saving_non_distributed():
      return self._primary
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._get_replica(replica_id)

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    if values_util.is_saving_non_distributed():
      return self._primary
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for i, value in enumerate(self._values):
        if device_util.canonicalize(value.device) == current_device:
          return self._get_replica(i)
      return self._get_replica(0)
    else:
      return self._get_replica(replica_id)

  def read_value(self):
    if values_util.is_saving_non_distributed():
      return self._primary.read_value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return array_ops.identity(self._get())

  def value(self):
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    if self._policy:
      return self._policy.value(self)
    return self._get_on_device_or_primary().value()

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "DistributedVariable.numpy() is only available "
          "when eager execution is enabled.")

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign_sub(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign_add(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_sub(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_add(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_mul(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_div(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_min(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_max(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_update(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    DistributedVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
893 """ 894 895 def _saveable_factory(name=self._common_name): 896 return _DistributedVariableSaveable(self, self._primary, name) 897 898 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 899 900 def _as_graph_element(self): 901 if values_util.is_saving_non_distributed(): 902 return self._primary._as_graph_element() # pylint: disable=protected-access 903 if self._policy: 904 return self._policy._as_graph_element(self) # pylint: disable=protected-access 905 906 raise NotImplementedError( 907 "DistributedVariable._as_graph_element requires a valid " 908 "VariablePolicy. Please set the policy via the `var_policy` argument " 909 "in the constructor, or override this method in sub-classes which " 910 "support cross-replica accesses.") 911 912 def _get_cross_replica(self): 913 if values_util.is_saving_non_distributed(): 914 return self._primary 915 if self._policy: 916 return self._policy._get_cross_replica(self) # pylint: disable=protected-access 917 918 raise NotImplementedError( 919 "DistributedVariable._get_cross_replica requires a valid " 920 "VariablePolicy. Please set the policy via the `var_policy` argument " 921 "in the constructor, or override this method in sub-classes which " 922 "support cross-replica accesses.") 923 924 def _update_cross_replica(self, update_fn, value, **kwargs): 925 """Applies updates across replicas. 926 927 Args: 928 update_fn: A callable to pass to `strategy.extended.update` to update the 929 variable. It should has the same signature as `Variable.assign()`. 930 value: value to be passed to `update_fn`. 931 **kwargs: remaining arguments to `update_fn`. 932 933 Returns: 934 Updated variable or `tf.Operation`. 935 """ 936 values_util.mark_as_unsaveable() 937 return self.distribute_strategy.extended.update( 938 self, update_fn, args=(value,), kwargs=kwargs, group=True) 939 940 def _update_replica(self, update_fn, value, **kwargs): 941 """Applies updates in one replica. 942 943 Args: 944 update_fn: A callable to update the variable. It should has the same 945 signature as `Variable.assign()`. 946 value: value to be passed to `update_fn`. 947 **kwargs: remaining arguments to `update_fn`. 948 949 Returns: 950 Updated variable or `tf.Operation`. 951 """ 952 if self._policy: 953 return self._policy._update_replica(self, update_fn, value, **kwargs) # pylint: disable=protected-access 954 raise NotImplementedError( 955 "DistributedVariable._update_replica requires a valid VariablePolicy. " 956 "Please set the policy via the `var_policy` argument in the " 957 "constructor, or override this method in sub-classes which support " 958 "cross-replica accesses.") 959 960 def _update(self, update_fn, value, **kwargs): 961 """Applies updates depending on the context. 962 963 The method calls `_update_replica` in replica context, 964 `_update_cross_replica` in cross replica context, and `update_fn` in update 965 context. 966 967 If `read_value` is True, the method returns the updated Variable. If 968 `read_value` is False, the method returns the update `tf.Operation`. 969 970 Args: 971 update_fn: A callable to pass to `strategy.extended.update` to update the 972 variable. It should have the same signature as `Variable.assign()`. 973 value: value to be passed to `update_fn`. 974 **kwargs: keyword arguments to `update_fn`. 975 976 Returns: 977 Updated variable or `tf.Operation`. 
    """
    if values_util.is_saving_non_distributed():
      return update_fn(self._primary, value, **kwargs)
    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
      if ds_context.in_cross_replica_context():
        update_replica_id = distribute_lib.get_update_replica_id()
        if update_replica_id is not None:
          replica_value = self._get_replica(update_replica_id)
          return update_fn(replica_value, value, **kwargs)
        return self._update_cross_replica(update_fn, value, **kwargs)
      else:
        values_util.assert_replica_context(self.distribute_strategy)
        return self._update_replica(update_fn, value, **kwargs)

  def _should_act_as_resource_variable(self):
    """Pass the resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    if values_util.is_saving_non_distributed():
      return ops.convert_to_tensor(
          self._primary, dtype=dtype, name=name, as_ref=as_ref)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return ops.convert_to_tensor(
          self._get(), dtype=dtype, name=name, as_ref=as_ref)

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # Initialize for self._primary first, so that obj_map[self._primary] and
    # resource_map[self._primary.handle] contain mapped values.
    obj_map, resource_map = self._primary._map_resources(save_options)  # pylint:disable=protected-access
    for v in [v for v in self._values if v != self._primary]:
      if (save_options.experimental_variable_policy  # pylint:disable=protected-access
          ._expand_distributed_variables()):
        v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
        obj_map.update(v_obj_map)
        resource_map.update(v_resource_map)
      else:
        obj_map[v] = obj_map[self._primary]
        resource_map[v.handle] = resource_map[self._primary.handle]
    obj_map[self] = obj_map[self._primary]
    resource_map[self] = resource_map[self._primary.handle]
    if self._packed_var is not None:
      resource_map[self._packed_var.packed_handle] = resource_map[
          self._primary.handle]
    return obj_map, resource_map

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called
    when saving with a pre-built `SavedObject` proto representing the object,
    plus an instance of `SaveOptions`. This method is then free to modify that
    proto instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed
        this will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
1045 """ 1046 resource_variable_ops.write_object_proto_for_resource_variable( 1047 self, proto, options) 1048 if self._policy: 1049 if self._policy._is_mirrored(): # pylint: disable=protected-access 1050 self._policy._write_object_proto(self, proto, options) # pylint: disable=protected-access 1051 1052 1053# We extend from `saveable_object.SaveableObject` instead of 1054# `saveable_object_util.ResourceVariableSaveable` since we need to read the 1055# value of ONREAD variables when saving. `SaveableObject` provides a way to 1056# specify the function to run to get the value of the variable or tensor at 1057# saving time. We can use this for both ON_READ and ON_WRITE variables. 1058# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic 1059# if possible. 1060class _DistributedVariableSaveable(saveable_object.SaveableObject): 1061 """Class for defining how to restore a DistributedVariable.""" 1062 1063 def __init__(self, distributed_variable, primary_variable, name): 1064 self._distributed_variable = distributed_variable 1065 if not self._distributed_variable._policy: 1066 raise ValueError( 1067 "The VariablePolicy of the argument `distributed_variable` must be " 1068 "set to create a _DistributedVariableSaveable. Please set it via " 1069 "the `var_policy` argument in the constructor of DistributedVariable." 1070 ) 1071 tensor, spec = distributed_variable._policy.get_saveable( 1072 distributed_variable, primary_variable, name) 1073 super(_DistributedVariableSaveable, self).__init__(tensor, spec, name) 1074 1075 def restore(self, restored_tensors, restored_shapes): 1076 """Restore the same value into all variables.""" 1077 tensor, = restored_tensors 1078 return self._distributed_variable._policy.get_restore_ops( # pylint: disable=protected-access 1079 self._distributed_variable, tensor) 1080 1081 1082class _MirroredSaveable(saveable_object.SaveableObject): 1083 """Class for defining how to restore a MirroredVariable.""" 1084 1085 def __init__(self, mirrored_variable, primary_variable, name): 1086 self._mirrored_variable = mirrored_variable 1087 tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable, 1088 primary_variable, name) 1089 super(_MirroredSaveable, self).__init__(tensor, spec, name) 1090 1091 def restore(self, restored_tensors, restored_shapes): 1092 """Restore the same value into all variables.""" 1093 tensor, = restored_tensors 1094 return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor) 1095 1096 1097class MirroredVariable(DistributedVariable, Mirrored): 1098 """Holds a map from replica to variables whose values are kept in sync.""" 1099 1100 def _update_replica(self, update_fn, value, **kwargs): 1101 return _on_write_update_replica(self, update_fn, value, **kwargs) 1102 1103 def scatter_min(self, *args, **kwargs): 1104 if values_util.is_saving_non_distributed(): 1105 return self._primary.scatter_min(*args, **kwargs) 1106 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 1107 self._aggregation != vs.VariableAggregation.NONE): 1108 raise NotImplementedError( 1109 values_util.scatter_error_msg.format( 1110 op_name="scatter_min", aggregation=self._aggregation)) 1111 return super(MirroredVariable, self).scatter_min(*args, **kwargs) 1112 1113 def scatter_max(self, *args, **kwargs): 1114 if values_util.is_saving_non_distributed(): 1115 return self._primary.scatter_max(*args, **kwargs) 1116 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 1117 self._aggregation != 
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_max", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_update", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

  def _as_graph_element(self):
    return self._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called
    when saving with a pre-built `SavedObject` proto representing the object,
    plus an instance of `SaveOptions`. This method is then free to modify that
    proto instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed
        this will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    super(MirroredVariable, self)._write_object_proto(proto, options)
    values_util.write_object_proto(self, proto, options)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
    # and ON_WRITE.
    # Try to avoid assignments to and other mutations of MirroredVariable
    # state except through a DistributionStrategy.extended.update() or any of
    # the `assign*` and `scatter*` calls.
    if as_ref:
      # A TF 1.x case where the variable is a boolean variable and used like:
      # tf.cond(v, true_fn, false_fn).
      raise ValueError(
          "You may be using a variable created under distribute strategy in "
          "TF 1.x control flows. Try explicitly converting the variable to "
          "Tensor using variable.read_value(), or switch to TF 2.x.")
    return ops.convert_to_tensor(
        self._get(), dtype=dtype, name=name, as_ref=as_ref)


class _SyncOnReadSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable."""

  def __init__(self, sync_on_read_variable, name):
    self._sync_on_read_variable = sync_on_read_variable
    tensor, spec = values_util.get_on_read_saveable(
        sync_on_read_variable, sync_on_read_variable._primary, name)

    super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_read_restore_ops(
        self._sync_on_read_variable, tensor,
        self._sync_on_read_variable.aggregation)


class SyncOnReadVariable(DistributedVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def _update_replica(self, update_fn, value, **kwargs):
    return update_fn(self._get_on_device_or_primary(), value, **kwargs)

  def _get(self):
    """Returns the value of SyncOnReadVariable based on surrounding context.

    If called under a non-default replica-context, returns the corresponding
    variable on that replica.
    If called under default replica-context or cross-replica context, returns
    the synced value.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return super(SyncOnReadVariable, self)._get()

  # TODO(b/154017756): Make assign behavior in cross replica context consistent
  # with MirroredVariable.
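  #
  # Hedged sketch of the current semantics (see `values_util` for details):
  # assigning to an ON_READ variable from cross-replica context routes through
  # `values_util.on_read_assign*_cross_replica` and calls
  # `values_util.mark_as_unsaveable()`. For a hypothetical `v` created under
  # `strategy.scope()` with `synchronization=ON_READ`:
  #
  #   v.assign(3.)  # cross-replica: spreads the value across the components,
  #                 # how it is spread depends on `aggregation`.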
  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable,
                     self).assign_sub(value, use_locking, name, read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable,
                     self).assign_add(value, use_locking, name, read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable, self).assign(value, use_locking, name,
                                                      read_value)

  def _scatter_not_implemented(self, method):
    raise NotImplementedError(
        f"Variables with `synchronization=ON_READ` don't support `{method}`")

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    self._scatter_not_implemented("scatter_update")

  def value(self):
    if ds_context.in_variable_sync_on_read_context():
      raise NotImplementedError(
          "Calling `variable.value()` inside `variable_sync_on_read_context` "
          "is not supported.")
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return self._get_replica(0).value()
        return self._get_cross_replica()
      else:
        # _get_on_device_or_primary() returns a Variable.
        return self._get_on_device_or_primary().value()

  def read_value(self):
    if ds_context.in_variable_sync_on_read_context():
      raise NotImplementedError(
          "Calling `variable.read_value()` inside "
          "`variable_sync_on_read_context` is not supported.")
    return super().read_value()

  def _get_cross_replica(self):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      # Consider returning a tensor value here to make the return value of
      # _get_cross_replica consistent.
      return self._get_replica(0)
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          self,
          axis=None)

  def _as_graph_element(self):
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()  # pylint: disable=protected-access
    # pylint: disable=protected-access
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(self._get_cross_replica())
    return self._get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
1364 """ 1365 1366 def _saveable_factory(name=self._common_name): 1367 return _SyncOnReadSaveable(self, name) 1368 1369 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 1370 1371 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 1372 """Converts a SyncOnReadVariable to a tensor.""" 1373 if values_util.is_saving_non_distributed(): 1374 return ops.convert_to_tensor( 1375 self._primary, dtype=dtype, name=name, as_ref=as_ref) 1376 with ds_context.enter_or_assert_strategy(self._distribute_strategy): 1377 replica_context = ds_context.get_replica_context() 1378 if (replica_context is not None and 1379 ds_context.in_variable_sync_on_read_context()): 1380 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 1381 return ops.convert_to_tensor( 1382 self._get_replica(0), dtype=dtype, name=name, as_ref=as_ref) 1383 if self._aggregation == vs.VariableAggregation.SUM: 1384 values_util.mark_as_unsaveable() 1385 # pylint: disable=protected-access 1386 reduced = ( 1387 replica_context.strategy.extended._replica_ctx_all_reduce( 1388 reduce_util.ReduceOp.from_variable_aggregation( 1389 self._aggregation), 1390 self._get().read_value())) 1391 return ops.convert_to_tensor( 1392 reduced, dtype=dtype, name=name, as_ref=as_ref) 1393 1394 return ops.convert_to_tensor( 1395 self._get(), dtype=dtype, name=name, as_ref=as_ref) 1396 1397 1398# Register a conversion functions which reads the value of the variable, 1399# allowing instances of the class to be used as tensors. 1400# DistributedVariable 1401def _tensor_conversion_distributed_var(var, 1402 dtype=None, 1403 name=None, 1404 as_ref=False): 1405 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 1406 1407 1408ops.register_tensor_conversion_function(DistributedVariable, 1409 _tensor_conversion_distributed_var) 1410 1411 1412# MirroredVariables 1413def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): 1414 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 1415 1416 1417ops.register_tensor_conversion_function(MirroredVariable, 1418 _tensor_conversion_mirrored) 1419 1420 1421# Mirrored Values 1422def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False): 1423 return ops.convert_to_tensor( 1424 value._get(), dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 1425 1426 1427ops.register_tensor_conversion_function(Mirrored, 1428 _tensor_conversion_mirrored_val) 1429 1430 1431# SyncOnReadVariables 1432def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False): 1433 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 1434 1435 1436ops.register_tensor_conversion_function(SyncOnReadVariable, 1437 _tensor_conversion_sync_on_read) 1438 1439 1440class VariablePolicy(object): 1441 """Policy defining synchronization and aggregation of a distributed variable. 1442 1443 Given `synchronization` and `aggregation` parameters set on a `tf.Variable` 1444 during variable creation within `tf.distribute` scope, `tf.distribute` creates 1445 an appropriate policy object and assigns it to the distributed variable. All 1446 variable operations are delegated to the respective policy object. 
1447 """ 1448 1449 def __init__(self, aggregation): 1450 self._aggregation = aggregation 1451 1452 def value(self): 1453 raise NotImplementedError( 1454 "VariablePolicy.value should be overriden by sub-classes.") 1455 1456 def _is_mirrored(self): 1457 raise NotImplementedError( 1458 "VariablePolicy._is_mirrored should be overriden by sub-classes.") 1459 1460 def _as_graph_element(self, _): 1461 raise NotImplementedError( 1462 "VariablePolicy._as_graph_element should be overriden by sub-classes.") 1463 1464 def _get_cross_replica(self, var): 1465 raise NotImplementedError( 1466 "VariablePolicy._get_cross_replica should be overriden by sub-classes.") 1467 1468 def _update_replica(self, var, update_fn, value, **kwargs): 1469 raise NotImplementedError( 1470 "VariablePolicy._update_replica should be overriden by sub-classes.") 1471 1472 1473class OnReadPolicy(VariablePolicy): 1474 """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization. 1475 1476 This policy is created when `synchronization` is set to 1477 `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the 1478 values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`, 1479 `MEAN` or `ONLY_FIRST_REPLICA`when creating a `tf.Variable` in `tf.distribute` 1480 scope. 1481 """ 1482 1483 def _is_mirrored(self): 1484 return False 1485 1486 def value(self, var): 1487 with ds_context.enter_or_assert_strategy(var.distribute_strategy): 1488 if (ds_context.in_cross_replica_context() and 1489 not values_util.in_replica_update_context()): 1490 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 1491 return var._get_replica(0).value() # pylint: disable=protected-access 1492 return var._get_cross_replica() # pylint: disable=protected-access 1493 else: 1494 return var._get_on_device_or_primary().value() # pylint: disable=protected-access 1495 1496 def _as_graph_element(self, var): 1497 with ds_context.enter_or_assert_strategy(var.distribute_strategy): 1498 if ds_context.in_cross_replica_context(): 1499 return ops.convert_to_tensor(var._get_cross_replica()) # pylint: disable=protected-access 1500 return var._get()._as_graph_element() # pylint: disable=protected-access 1501 1502 def _get_cross_replica(self, var): 1503 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 1504 return var._get_replica(0) # pylint: disable=protected-access 1505 if self._aggregation == vs.VariableAggregation.SUM: 1506 values_util.mark_as_unsaveable() 1507 with ds_context.enter_or_assert_strategy(var.distribute_strategy): 1508 return var.distribute_strategy.reduce( 1509 reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), 1510 var, 1511 axis=None) 1512 1513 def _update_replica(self, var, update_fn, value, **kwargs): 1514 return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access 1515 1516 def _scatter_not_implemented(self, method): 1517 raise NotImplementedError(f"ON_READ variables doesn't support `{method}` " 1518 "in cross replica context") 1519 1520 def assign_sub(self, 1521 var, 1522 value, 1523 use_locking=False, 1524 name=None, 1525 read_value=True): 1526 """Subtracts a value from this variable.""" 1527 with ds_context.enter_or_assert_strategy(var.distribute_strategy): 1528 if (ds_context.in_cross_replica_context() and 1529 not values_util.in_replica_update_context()): 1530 values_util.mark_as_unsaveable() 1531 return values_util.on_read_assign_sub_cross_replica( 1532 var, value, read_value=read_value) 1533 else: 1534 return 

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Subtracts a value from this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_sub(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_add(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    """Assigns a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_update")

  def get_saveable(self, var, primary_var, name):
    """Creates a saveable object for the given variable."""
    return values_util.get_on_read_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    """Restores the same value into all variables."""
    return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)


class OnWritePolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when a `tf.Variable` is created in `tf.distribute`
  scope with `synchronization` set to `tf.VariableSynchronization.ON_WRITE` or
  `tf.VariableSynchronization.AUTO`.
  """

  def _is_mirrored(self):
    return True

  def value(self, var):
    return var._get_on_device_or_primary().value()  # pylint: disable=protected-access

  def _as_graph_element(self, var):
    return var._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access
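
  # Illustrative sketch of the ON_WRITE update semantics handled by this
  # policy (not executed here; the two-replica strategy and MEAN aggregation
  # are assumptions for this example):
  #
  #   with strategy.scope():
  #     w = tf.Variable(1., aggregation=tf.VariableAggregation.MEAN)
  #
  #   def replica_fn():
  #     ctx = tf.distribute.get_replica_context()
  #     delta = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
  #     w.assign_add(delta)  # deltas are MEAN-aggregated, then every copy
  #                          # receives the same update
  #
  #   strategy.run(replica_fn)  # with two replicas: w == 1. + (0. + 1.) / 2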

  def _get_cross_replica(self, var):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(var._get_on_device_or_primary())  # pylint: disable=protected-access

  def _update_replica(self, var, update_fn, value, **kwargs):
    if var.aggregation == variables_lib.VariableAggregation.NONE:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
    return _on_write_update_replica(var, update_fn, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    return values_util.on_write_assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    return values_util.on_write_assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_sub(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_add(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_mul(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_div(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_min", aggregation=self._aggregation))
    return values_util.scatter_min(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_max", aggregation=self._aggregation))
    return values_util.scatter_max(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_update", aggregation=self._aggregation))
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)

  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    return values_util.get_on_write_restore_ops(var, tensor)
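
  # Illustrative sketch of how the saveable/restore hooks above are exercised
  # through a checkpoint (not executed here; the checkpoint path is an
  # assumption for this example):
  #
  #   with strategy.scope():
  #     w = tf.Variable(1.)  # ON_WRITE/AUTO mirrored variable
  #   ckpt = tf.train.Checkpoint(w=w)
  #   path = ckpt.save("/tmp/ckpt")  # saves the primary component's value
  #   w.assign(5.)
  #   ckpt.restore(path)  # the restored value is written to every replica copy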

  def _write_object_proto(self, var, proto, options):
    """Updates a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called
    when saving with a pre-built `SavedObject` proto representing the object,
    plus an instance of `SaveOptions`. This method is then free to modify that
    proto instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a `SavedVariable`
    (depending on the `SaveOptions` variable policy).

    Args:
      var: A DistributedVariable object.
      proto: A pre-built `SavedObject` proto for this object. It is assumed
        this will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    values_util.write_object_proto(var, proto, options)


class PerWorkerResource():
  """A per-worker CachableResource class for strategies other than ParameterServerStrategy.

  Resources that populate `host_to_resources` should be instances of classes
  subclassing CachableResource, although currently this is only used and
  tested for StaticHashTable with TPUStrategy.
  """

  def __init__(self, strategy, host_to_resources):
    self._strategy = strategy
    self._host_to_resources = host_to_resources

  def __getattribute__(self, name):
    if name not in ("__init__", "__getattribute__", "_host_to_resources",
                    "_strategy", "local_resource"):
      return getattr(self.local_resource(), name)
    return super(PerWorkerResource, self).__getattribute__(name)

  def local_resource(self):
    """Returns the resource on the local worker."""
    current_device = device_util.canonicalize(device_util.current())
    host_device = device_util.canonicalize(
        device_util.get_host_for_device(current_device))
    return self._host_to_resources.get(
        host_device,
        self._host_to_resources[next(iter(self._host_to_resources))])
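

# Illustrative sketch of how PerWorkerResource forwards attribute access to
# the per-host resource (not executed here; the strategy, device string, and
# table below are assumptions for this example):
#
#   table = tf.lookup.StaticHashTable(
#       tf.lookup.KeyValueTensorInitializer([1, 2], [10, 20]), default_value=0)
#   host = "/job:localhost/replica:0/task:0/device:CPU:0"
#   per_worker_table = PerWorkerResource(strategy, {host: table})
#   per_worker_table.lookup(tf.constant([1]))  # resolved via local_resource()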