# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util.tf_export import tf_export


def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context."""
  if var.aggregation == vs.VariableAggregation.NONE:
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access

  def merge_fn(strategy, value, **kwargs):
    """Aggregate values and update all variables in cross replica context."""
    # Don't allow MEAN with non-float dtype, since it may cause unexpected
    # precision loss. Python3 and NumPy automatically upcast integers to
    # float in division, but we should always preserve the type.
    #
    # Note that to be backward compatible we allow the case when the value
    # is *always* the same on each replica, i.e. value is not a
    # PerReplica. Refer to regroup() to see how values are grouped.
    if var.aggregation == vs.VariableAggregation.MEAN and (
        not var.dtype.is_floating) and isinstance(value, PerReplica):
      raise ValueError(
          "Cannot update non-float variables with "
          "tf.VariableAggregation.MEAN aggregation in replica context. "
          "Either change the variable dtype to float or update it in "
          "cross-replica context.")

    assert strategy == var.distribute_strategy
    v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
    return var._update_cross_replica(update_fn, v, **kwargs)  # pylint: disable=protected-access

  return ds_context.get_replica_context().merge_call(
      merge_fn, args=(value,), kwargs=kwargs)


@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values.

  A subclass instance of `tf.distribute.DistributedValues` is created when
  creating variables within a distribution strategy, iterating a
  `tf.distribute.DistributedDataset` or through `tf.distribute.Strategy.run`.
  This base class should never be instantiated directly.
  `tf.distribute.DistributedValues` contains a value per replica. Depending on
  the subclass, the values could either be synced on update, synced on demand,
  or never synced.

  `tf.distribute.DistributedValues` can be reduced to obtain a single value
  across replicas, used as input into `tf.distribute.Strategy.run`, or have
  the per-replica values inspected using
  `tf.distribute.Strategy.experimental_local_results`.

  Example usage:

  1. Created from a `tf.distribute.DistributedDataset`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)

  2. Returned by `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> @tf.function
  ... def run():
  ...   ctx = tf.distribute.get_replica_context()
  ...   return ctx.replica_id_in_sync_group
  >>> distributed_values = strategy.run(run)

  3. As input into `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> @tf.function
  ... def run(input):
  ...   return input + 1.0
  >>> updated_value = strategy.run(run, args=(distributed_values,))

  4. Reduce value:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
  ...                                 distributed_values,
  ...                                 axis = 0)

  5. Inspect local replica values:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> per_replica_values = strategy.experimental_local_results(
  ...     distributed_values)
  >>> per_replica_values
  (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)

  """

  def __init__(self, values):
    """Should only be called by subclass __init__."""
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._values[replica_id]

  def _get_cross_replica(self):
    raise NotImplementedError(
        "This method should be overridden by sub-classes which support cross-"
        "replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:
        if device_util.canonicalize(value.device) == current_device:
          return value
      return self._primary
    else:
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def _devices(self):
    return tuple(v.device for v in self._values)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)


# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
# initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# it will only use those methods if they are defined in the class, not the
# object.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # The '_use_resource_variables' and the attrs that start with '_self' are
    # used for restoring the saved_model proto, and '_attribute_sentinel' is
    # used for Layer tracking. At the point these attrs are queried, the
    # variable has not been initialized. Thus it should not query those of the
    # underlying components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access the "_values"
    # attribute, but that doesn't exist either because this is an empty object,
    # and again __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mixed
    # use of __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self._get()

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  def __mul__(self, o):
    return self._get_as_operand() * o

  def __rmul__(self, o):
    return o * self._get_as_operand()

  def __truediv__(self, o):
    return self._get_as_operand() / o

  def __rtruediv__(self, o):
    return o / self._get_as_operand()

  def __floordiv__(self, o):
    return self._get_as_operand() // o

  def __rfloordiv__(self, o):
    return o // self._get_as_operand()

  def __mod__(self, o):
    return self._get_as_operand() % o

  def __rmod__(self, o):
    return o % self._get_as_operand()

  def __lt__(self, o):
    return self._get_as_operand() < o

  def __le__(self, o):
    return self._get_as_operand() <= o

  def __gt__(self, o):
    return self._get_as_operand() > o

  def __ge__(self, o):
    return self._get_as_operand() >= o

  def __and__(self, o):
    return self._get_as_operand() & o

  def __rand__(self, o):
    return o & self._get_as_operand()

  def __or__(self, o):
    return self._get_as_operand() | o

  def __ror__(self, o):
    return o | self._get_as_operand()

  def __xor__(self, o):
    return self._get_as_operand() ^ o

  def __rxor__(self, o):
    return o ^ self._get_as_operand()

  def __getitem__(self, o):
    return self._get_as_operand()[o]

  def __pow__(self, o, modulo=None):
    return pow(self._get_as_operand(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self._get_as_operand())

  def __invert__(self):
    return ~self._get_as_operand()

  def __neg__(self):
    return -self._get_as_operand()

  def __abs__(self):
    return abs(self._get_as_operand())

  def __div__(self, o):
    try:
      return self._get_as_operand().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._get_as_operand().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._get_as_operand().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._get_as_operand().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  # TODO(josh11b): Even more operator overloads.


class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values


class PerReplicaSpec(type_spec.TypeSpec):
  """Type specification for a `PerReplica`."""

  __slots__ = ["_value_specs"]

  value_type = property(lambda self: PerReplica)

  def __init__(self, *value_specs):
    self._value_specs = tuple(value_specs)

  def _serialize(self):
    return self._value_specs

  @property
  def _component_specs(self):
    return self._value_specs

  def _to_components(self, value):
    replica_context = ds_context.get_replica_context()
    if replica_context is not None and replica_context.num_replicas_in_sync > 1:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "
          "context.")
    return value._values  # pylint: disable=protected-access

  def _from_components(self, tensor_list):
    return PerReplica(tensor_list)


# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
# TODO(tomhennigan) Should this extend CompositeTensor?
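# Illustrative sketch (comments only, not executed here): because a `Mirrored`
# value delegates to one of its per-device components, ordinary ops work on it
# directly in cross-replica mode, for example:
#
#   result = mirrored_value + 1.0                    # DistributedDelegate.__add__
#   tensor = ops.convert_to_tensor(mirrored_value)   # via the conversion
#                                                    # function registered below
#
# Both resolve to the component on the current device (or the primary one).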
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    return self._get_on_device_or_primary()

  def _as_graph_element(self):
    obj = self._get()
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj


class DistributedVarOp(object):
  """A class that looks like `tf.Operation`."""

  def __init__(self, name, graph, traceback, typ):
    self.name = name
    self.graph = graph
    self.traceback = traceback
    self.type = typ

  def __eq__(self, o):
    if not isinstance(o, self.__class__):
      raise NotImplementedError
    return (self.name == o.name and self.graph == o.graph and
            self.traceback == o.traceback and self.type == o.type)

  def __hash__(self):
    return hash((self.name, self.graph, tuple(self.traceback), self.type))


class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                          core.Tensor):
  """Holds a map from replica to variables."""

  def __init__(self, strategy, values, aggregation, var_policy=None):
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None

    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the component variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `DistributedVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `DistributedVariable`. To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of a `DistributedVariable`
    within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `DistributedVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      new_values = []

      for value in self._values:
        with ops.device(value.device):
          new_values.append(copy.deepcopy(value, memo))

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        values=new_values,
        aggregation=self._aggregation,
        var_policy=copy.deepcopy(self._policy, memo))

    memo[id(self)] = copied_variable

    return copied_variable

  def _use_packed_variable(self):
    # Don't use packed variable when under a SaveContext to avoid explicit
    # device placement on variable consuming ops.
    return self._packed_var is not None and not save_context.in_save_context()

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on if all the
      component variables are initialized.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.is_initialized()
    if self._use_packed_variable():
      return self._packed_var.is_initialized()
    result = self._primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
    # to the `is_initialized` op. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(
        result, self._values[-1].is_initialized(), name=name)
    return result

  @property
  def initializer(self):
    if values_util.is_saving_non_distributed():
      return self._primary.initializer
    if self._initializer_op:
      init_op = self._initializer_op
    else:
      # Return grouped ops of all the var initializations of component values
      # of the mirrored variable.
      init_op = control_flow_ops.group(
          tuple(v.initializer for v in self._values))
    return init_op

  def initialized_value(self):
    return self._get_on_device_or_primary().initialized_value()

  @property
  def initial_value(self):
    return self._get_on_device_or_primary().initial_value

  @property
  def constraint(self):
    return self._primary.constraint

  @property
  def graph(self):
    return self._primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self._primary._unique_id  # pylint: disable=protected-access

  @property
  def _graph_key(self):
    """Lets Optimizers know which graph this variable is from."""
    return self._primary._graph_key  # pylint: disable=protected-access

  @property
  def name(self):
    return self._primary.name

  @property
  def dtype(self):
    return self._primary.dtype

  @property
  def shape(self):
    return self._primary.shape

  @property
  def synchronization(self):
    return self._primary.synchronization

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def _packed_variable(self):
    if self._use_packed_variable():
      return self._packed_var
    return None

  @property
  def handle(self):
    if values_util.is_saving_non_distributed():
      return self._primary.handle
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      raise ValueError("`handle` is not available outside the replica context"
                       " or a `tf.distribute.Strategy.update()` call.")
    else:
      if self._use_packed_variable():
        return self._packed_var.handle
      return self._values[replica_id].handle

  def eval(self, session=None):
    return self._get_on_device_or_primary().eval(session)

  @property
  def _save_slice_info(self):
    return self._primary._save_slice_info  # pylint: disable=protected-access

  def _get_save_slice_info(self):
    return self._primary._get_save_slice_info()  # pylint: disable=protected-access

  def _set_save_slice_info(self, save_slice_info):
    for v in self._values:
      v._set_save_slice_info(save_slice_info)  # pylint: disable=protected-access

  @property
  def device(self):
    return self._get_on_device_or_primary().device

  @property
  def trainable(self):
    return self._primary.trainable

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def get_shape(self):
    return self._primary.get_shape()

  def to_proto(self, export_scope=None):
    return self._primary.to_proto(export_scope=export_scope)

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    # We want cross-replica code that does some var.op.X calls
    # to work (even if the current device isn't in self._devices), but
    # other uses of var.op in a cross-replica context to fail.
    if ds_context.in_cross_replica_context():
      return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
                              self._primary.op.traceback, self._primary.op.type)
    return self._get().op

  @property
  def _in_graph_mode(self):
    return self._primary._in_graph_mode  # pylint: disable=protected-access

  def _get_replica(self, replica_id):
    """Returns the value on a device with the given replica_id."""
    if self._use_packed_variable():
      return self._packed_var.on_device(self._devices[replica_id])
    return self._values[replica_id]

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    if values_util.is_saving_non_distributed():
      return self._primary
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._get_replica(replica_id)

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    if values_util.is_saving_non_distributed():
      return self._primary
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for i, value in enumerate(self._values):
        if device_util.canonicalize(value.device) == current_device:
          return self._get_replica(i)
      return self._get_replica(0)
    else:
      return self._get_replica(replica_id)

  def read_value(self):
    if values_util.is_saving_non_distributed():
      return self._primary.read_value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return array_ops.identity(self._get())

  def value(self):
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    if self._policy:
      return self._policy.value(self)
    return self._get_on_device_or_primary().value()

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign_sub(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign_add(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_sub(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_add(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_mul(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_div(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_min(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_max(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_update(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    DistributedVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _DistributedVariableSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _as_graph_element(self):
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()  # pylint: disable=protected-access
    if self._policy:
      return self._policy._as_graph_element(self)  # pylint: disable=protected-access

    raise NotImplementedError("No policy set for calling _as_graph_element.")

  def _get_cross_replica(self):
    if values_util.is_saving_non_distributed():
      return self._primary
    if self._policy:
      return self._policy._get_cross_replica(self)  # pylint: disable=protected-access

    raise NotImplementedError(
        "This method should be overridden by sub-classes which support cross-"
        "replica accesses.")

  def _update_cross_replica(self, update_fn, value, **kwargs):
    """Applies updates across replicas.

    Args:
      update_fn: A callable to pass to `strategy.extended.update` to update the
        variable. It should have the same signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: remaining arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.
    """
    values_util.mark_as_unsaveable()
    return self.distribute_strategy.extended.update(
        self, update_fn, args=(value,), kwargs=kwargs, group=True)

  def _update_replica(self, update_fn, value, **kwargs):
    """Applies updates in one replica.

    Args:
      update_fn: A callable to update the variable. It should have the same
        signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: remaining arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.
    """
    if self._policy:
      return self._policy._update_replica(self, update_fn, value, **kwargs)  # pylint: disable=protected-access
    raise NotImplementedError("Should be implemented by subclasses.")

  def _update(self, update_fn, value, **kwargs):
    """Applies updates depending on the context.

    The method calls `_update_replica` in replica context,
    `_update_cross_replica` in cross replica context, and `update_fn` in update
    context.

    If `read_value` is True, the method returns the updated Variable. If
    `read_value` is False, the method returns the update `tf.Operation`.

    Args:
      update_fn: A callable to pass to `strategy.extended.update` to update the
        variable. It should have the same signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: keyword arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.

    """
    if values_util.is_saving_non_distributed():
      return update_fn(self._primary, value, **kwargs)
    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
      if ds_context.in_cross_replica_context():
        update_replica_id = distribute_lib.get_update_replica_id()
        if update_replica_id is not None:
          replica_value = self._get_replica(update_replica_id)
          return update_fn(replica_value, value, **kwargs)
        return self._update_cross_replica(update_fn, value, **kwargs)
      else:
        values_util.assert_replica_context(self.distribute_strategy)
        return self._update_replica(update_fn, value, **kwargs)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    if values_util.is_saving_non_distributed():
      return ops.convert_to_tensor(
          self._primary, dtype=dtype, name=name, as_ref=as_ref)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return ops.convert_to_tensor(
          self._get(), dtype=dtype, name=name, as_ref=as_ref)

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # Initialize for self._primary first, so that obj_map[self._primary] and
    # resource_map[self._primary.handle] contain mapped values.
    obj_map, resource_map = self._primary._map_resources(save_options)  # pylint:disable=protected-access
    for v in [v for v in self._values if v != self._primary]:

      if (save_options.experimental_variable_policy  # pylint:disable=protected-access
          ._expand_distributed_variables()):
        v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
        obj_map.update(v_obj_map)
        resource_map.update(v_resource_map)
      else:
        obj_map[v] = obj_map[self._primary]
        resource_map[v.handle] = resource_map[self._primary.handle]
    obj_map[self] = obj_map[self._primary]
    resource_map[self] = resource_map[self._primary.handle]
    if self._packed_var is not None:
      resource_map[self._packed_var.packed_handle] = resource_map[
          self._primary.handle]
    return obj_map, resource_map

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    if self._policy:
      if self._policy._is_mirrored():  # pylint: disable=protected-access
        self._policy._write_object_proto(self, proto, options)  # pylint: disable=protected-access
    else:
      self._write_object_proto(proto, options)


# We extend from `saveable_object.SaveableObject` instead of
# `saveable_object_util.ResourceVariableSaveable` since we need to read the
# value of ON_READ variables when saving. `SaveableObject` provides a way to
# specify the function to run to get the value of the variable or tensor at
# saving time. We can use this for both ON_READ and ON_WRITE variables.
# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
# if possible.
class _DistributedVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a DistributedVariable."""

  def __init__(self, distributed_variable, primary_variable, name):
    self._distributed_variable = distributed_variable
    if not self._distributed_variable._policy:
      raise ValueError("VariablePolicy has not been set for the distributed "
                       "variable.")
    tensor, spec = distributed_variable._policy.get_saveable(
        distributed_variable, primary_variable, name)
    super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return self._distributed_variable._policy.get_restore_ops(  # pylint: disable=protected-access
        self._distributed_variable, tensor)


class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
                                                     primary_variable,
                                                     name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_write_restore_ops(self._mirrored_variable,
                                                tensor)


class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(values_util.scatter_error_msg.format(
          op_name="scatter_min", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(values_util.scatter_error_msg.format(
          op_name="scatter_max", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(values_util.scatter_error_msg.format(
          op_name="scatter_update", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

  def _as_graph_element(self):
    return self._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    values_util.write_object_proto(self, proto, options)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
    # and ON_WRITE.
    # Try to avoid assignments to and other mutations of MirroredVariable
    # state except through a DistributionStrategy.extended.update() or any of
    # the `assign*` and `scatter*` calls.
    if as_ref:
      # A TF 1.x case where the variable is a boolean variable and used like:
      # tf.cond(v, true_fn, false_fn).
      raise ValueError(
          "You may be using variable created under distribute strategy in TF "
          "1.x control flows. Try explicitly converting the variable to Tensor "
          "using variable.read_value(), or switch to TF 2.x.")
    return ops.convert_to_tensor(
        self._get(), dtype=dtype, name=name, as_ref=as_ref)


class _SyncOnReadSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable."""

  def __init__(self, sync_on_read_variable, name):
    self._sync_on_read_variable = sync_on_read_variable
    tensor, spec = values_util.get_on_read_saveable(
        sync_on_read_variable, sync_on_read_variable._primary, name)

    super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_read_restore_ops(
        self._sync_on_read_variable, tensor,
        self._sync_on_read_variable.aggregation)


class SyncOnReadVariable(DistributedVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def _update_replica(self, update_fn, value, **kwargs):
    return update_fn(self._get_on_device_or_primary(), value, **kwargs)

  # TODO(b/154017756): Make assign behavior in cross replica context consistent
  # with MirroredVariable.
  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable,
                     self).assign_sub(value, use_locking, name, read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable,
                     self).assign_add(value, use_locking, name, read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable,
                     self).assign(value, use_locking, name, read_value)

  def _scatter_not_implemented(self, method):
    raise NotImplementedError(
        "Variables with `synchronization=ON_READ` don't support `%s`" %
        method)

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    self._scatter_not_implemented("scatter_update")

  def value(self):
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return self._get_replica(0).value()
        return self._get_cross_replica()
      else:
        # _get_on_device_or_primary() returns a Variable.
        return self._get_on_device_or_primary().value()

  def _get_cross_replica(self):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      # Consider returning a tensor value here to make the return value of
      # _get_cross_replica consistent.
      return self._get_replica(0)
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          self,
          axis=None)

  def _as_graph_element(self):
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()  # pylint: disable=protected-access
    # pylint: disable=protected-access
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(self._get_cross_replica())
    return self._get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _SyncOnReadSaveable(self, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    pass


# Register conversion functions which read the value of the variable,
# allowing instances of these classes to be used as tensors.
# DistributedVariable
def _tensor_conversion_distributed_var(var, dtype=None, name=None,
                                       as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(DistributedVariable,
                                        _tensor_conversion_distributed_var)


# MirroredVariables
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(MirroredVariable,
                                        _tensor_conversion_mirrored)


# Mirrored Values
def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
  return ops.convert_to_tensor(
      value._get(), dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(Mirrored,
                                        _tensor_conversion_mirrored_val)


# SyncOnReadVariables
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(SyncOnReadVariable,
                                        _tensor_conversion_sync_on_read)


class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable.

  Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
  during variable creation within `tf.distribute` scope, `tf.distribute` creates
  an appropriate policy object and assigns it to the distributed variable. All
  variable operations are delegated to the respective policy object.
  """

  def __init__(self, aggregation):
    self._aggregation = aggregation

  def value(self):
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _is_mirrored(self):
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _as_graph_element(self, _):
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _get_cross_replica(self, var):
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")

  def _update_replica(self, var, update_fn, value, **kwargs):
    raise NotImplementedError(
        "This method should be overridden by sub-classes.")


class OnReadPolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in
  `tf.distribute` scope.
  """

  def _is_mirrored(self):
    return False

  def value(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return var._get_replica(0).value()  # pylint: disable=protected-access
        return var._get_cross_replica()  # pylint: disable=protected-access
      else:
        return var._get_on_device_or_primary().value()  # pylint: disable=protected-access

  def _as_graph_element(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(var._get_cross_replica())  # pylint: disable=protected-access
    return var._get()._as_graph_element()  # pylint: disable=protected-access

  def _get_cross_replica(self, var):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # pylint: disable=protected-access
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          var,
          axis=None)

  def _update_replica(self, var, update_fn, value, **kwargs):
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access

  def _scatter_not_implemented(self, method):
    raise NotImplementedError(
        "ON_READ variables don't support `%s` in cross replica context" %
        method)

  def assign_sub(self, var, value, use_locking=False, name=None,
                 read_value=True):
    """Subtracts a value from this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_sub(
            var, value, use_locking=use_locking, name=name,
            read_value=read_value)

  def assign_add(self, var, value, use_locking=False, name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_add(
            var, value, use_locking=use_locking, name=name,
            read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(var, value,
                                                        read_value=read_value)
      else:
        return values_util.on_write_assign(var, value,
                                           use_locking=use_locking,
                                           name=name,
                                           read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_update")

  def get_saveable(self, var, primary_var, name):
    """Create a saveable object for the given variable."""
    return values_util.get_on_read_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    """Restore the same value into all variables."""
    return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)


class OnWritePolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_WRITE` or `tf.VariableSynchronization.AUTO`
  when creating a `tf.Variable` in `tf.distribute` scope.
  """

  def _is_mirrored(self):
    return True

  def value(self, var):
    return var._get_on_device_or_primary().value()  # pylint: disable=protected-access

  def _as_graph_element(self, var):
    return var._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access

  def _get_cross_replica(self, var):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(var._get_on_device_or_primary())  # pylint: disable=protected-access

  def _update_replica(self, var, update_fn, value, **kwargs):
    if var.aggregation == variables_lib.VariableAggregation.NONE:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
    return _on_write_update_replica(var, update_fn, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(var, value, use_locking=use_locking,
                                       name=name, read_value=read_value)

  def assign_add(self, var, value, use_locking=False, name=None,
                 read_value=True):
    return values_util.on_write_assign_add(var, value, use_locking=use_locking,
                                           name=name, read_value=read_value)

  def assign_sub(self, var, value, use_locking=False, name=None,
                 read_value=True):
    return values_util.on_write_assign_sub(var, value, use_locking=use_locking,
                                           name=name, read_value=read_value)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_sub(var, sparse_delta, use_locking=use_locking,
                                   name=name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_add(var, sparse_delta, use_locking=use_locking,
                                   name=name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_mul(var, sparse_delta, use_locking=use_locking,
                                   name=name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_div(var, sparse_delta, use_locking=use_locking,
                                   name=name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(values_util.scatter_error_msg.format(
          op_name="scatter_min", aggregation=self._aggregation))
    return values_util.scatter_min(var, sparse_delta, use_locking=use_locking,
                                   name=name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(values_util.scatter_error_msg.format(
          op_name="scatter_max", aggregation=self._aggregation))
    return values_util.scatter_max(var, sparse_delta, use_locking=use_locking,
                                   name=name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(values_util.scatter_error_msg.format(
          op_name="scatter_update", aggregation=self._aggregation))
    return values_util.scatter_update(var, sparse_delta,
                                      use_locking=use_locking,
                                      name=name)

  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    return values_util.get_on_write_restore_ops(var, tensor)

  def _write_object_proto(self, var, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      var: A DistributedVariable object.
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    values_util.write_object_proto(var, proto, options)
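
# Illustrative sketch (comments only, not part of this module's API): the
# classes above are normally obtained indirectly by creating variables under a
# strategy scope rather than being instantiated directly. Assuming two logical
# GPUs are available, something like the following yields an ON_WRITE/AUTO
# (mirrored) variable and an ON_READ (sync-on-read) variable, whose update and
# read semantics follow the policies defined above:
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     mirrored = tf.Variable(1.0)  # kept in sync on write across replicas
#     on_read = tf.Variable(
#         0.0,
#         synchronization=tf.VariableSynchronization.ON_READ,
#         aggregation=tf.VariableAggregation.SUM)  # aggregated when read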