1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various classes representing distributed values.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import weakref 24 25from tensorflow.python.distribute import device_util 26from tensorflow.python.distribute import distribute_lib 27from tensorflow.python.distribute import distribution_strategy_context 28from tensorflow.python.distribute import reduce_util 29from tensorflow.python.eager import context 30from tensorflow.python.eager import tape 31from tensorflow.python.framework import composite_tensor 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import tensor_util 34from tensorflow.python.framework import type_spec 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import gen_resource_variable_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import variable_scope as vs 40from tensorflow.python.ops import variables as variables_lib 41from tensorflow.python.training import saver 42from tensorflow.python.training.tracking import base as trackable 43from tensorflow.python.util import nest 44 45 46def _get_current_replica_id_as_int(): 47 """Returns the current replica ID as an integer, or `None`.""" 48 replica_context = distribution_strategy_context.get_replica_context() 49 if replica_context: 50 replica_id = replica_context.replica_id_in_sync_group 51 if not isinstance(replica_id, int): 52 replica_id = tensor_util.constant_value(replica_id) 53 else: 54 replica_id = distribute_lib.get_update_replica_id() 55 return replica_id 56 57 58class DistributedValues(object): 59 """Holds a map from replica to values. Either PerReplica or Mirrored.""" 60 61 def __init__(self, values): 62 self._values = tuple(values) 63 64 def get(self): 65 """Returns the value for the current device or raises a ValueError.""" 66 replica_id = _get_current_replica_id_as_int() 67 if replica_id is None: 68 return self._get_cross_replica() 69 else: 70 return self._values[replica_id] 71 72 def _get_cross_replica(self): 73 raise NotImplementedError( 74 "This method should be overridden by sub-classes which support cross-" 75 "replica accesses.") 76 77 def _get_closest(self): 78 """Returns value in same replica or device if possible, else the primary.""" 79 replica_id = _get_current_replica_id_as_int() 80 if replica_id is None: 81 # Try to find a value on the current device. 82 current_device = device_util.canonicalize(device_util.current()) 83 for value in self._values: 84 if device_util.canonicalize(value.device) == current_device: 85 return value 86 return self.primary 87 else: 88 return self._values[replica_id] 89 90 @property 91 def primary(self): 92 """Returns a representative component.""" 93 return self._values[0] 94 95 # TODO(josh11b): Replace experimental_local_results with this? 96 @property 97 def values(self): 98 return self._values 99 100 @property 101 def devices(self): 102 return tuple(v.device for v in self._values) 103 104 @property 105 def is_tensor_like(self): 106 return all(tensor_util.is_tensor(v) for v in self._values) 107 108 def __str__(self): 109 debug_str = ",\n".join( 110 " %d: %s" % (i, v) for i, v in enumerate(self._values)) 111 return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str) 112 113 def __repr__(self): 114 debug_repr = ",\n".join( 115 " %d: %r" % (i, v) for i, v in enumerate(self._values)) 116 return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr) 117 118 119# NOTE(josh11b,apassos): It would be great if we could inspect the values this was 120# initialized with and use that to generate the overloaded operators here. 121# Unfortunately, Python's rules for special methods don't allow this, see 122# https://docs.python.org/3/reference/datamodel.html#special-method-names 123# "if a class defines a method named __getitem__(), and x is an instance of 124# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)." 125# In particular, these special methods don't go through __getattr__, and 126# it will only use those methods if they are defined in the class, not the 127# object. 128class DistributedDelegate(DistributedValues): 129 """A map from device to values; acts as the same type as the values.""" 130 131 def __getattr__(self, name): 132 # The '_use_resource_variables' and the attrs starts with '_self' are used 133 # for restoring the saved_model proto, and '_attribute_sentinel' is used for 134 # Layer tracking. At the point these attrs are queried, the variable has not 135 # been initialized. Thus it should not query those of the underlying 136 # components. 137 if name.startswith("_self_") or name in ("_use_resource_variables", 138 "_attribute_sentinel", 139 "_distributed_container"): 140 return super(DistributedDelegate, self).__getattr__(name) 141 142 # TODO(priyag): This needs to be made robust against pitfalls from mix use 143 # __getattr__ and @property. See b/120402273. 144 return getattr(self.get(), name) 145 146 def _get_as_operand(self): 147 """Returns the value for operations for the current device. 148 149 Some implementations, e.g. `TPUMirroredVariable`, are not able to return the 150 value type within a replica context. They can, however, return a value that 151 can be used by the operations below. 152 """ 153 return self.get() 154 155 # pylint: disable=multiple-statements 156 def __add__(self, o): 157 return self._get_as_operand() + o 158 159 def __radd__(self, o): 160 return o + self._get_as_operand() 161 162 def __sub__(self, o): 163 return self._get_as_operand() - o 164 165 def __rsub__(self, o): 166 return o - self._get_as_operand() 167 168 def __mul__(self, o): 169 return self._get_as_operand() * o 170 171 def __rmul__(self, o): 172 return o * self._get_as_operand() 173 174 def __truediv__(self, o): 175 return self._get_as_operand() / o 176 177 def __rtruediv__(self, o): 178 return o / self._get_as_operand() 179 180 def __floordiv__(self, o): 181 return self._get_as_operand() // o 182 183 def __rfloordiv__(self, o): 184 return o // self._get_as_operand() 185 186 def __mod__(self, o): 187 return self._get_as_operand() % o 188 189 def __rmod__(self, o): 190 return o % self._get_as_operand() 191 192 def __lt__(self, o): 193 return self._get_as_operand() < o 194 195 def __le__(self, o): 196 return self._get_as_operand() <= o 197 198 def __gt__(self, o): 199 return self._get_as_operand() > o 200 201 def __ge__(self, o): 202 return self._get_as_operand() >= o 203 204 def __and__(self, o): 205 return self._get_as_operand() & o 206 207 def __rand__(self, o): 208 return o & self._get_as_operand() 209 210 def __or__(self, o): 211 return self._get_as_operand() | o 212 213 def __ror__(self, o): 214 return o | self._get_as_operand() 215 216 def __xor__(self, o): 217 return self._get_as_operand() ^ o 218 219 def __rxor__(self, o): 220 return o ^ self._get_as_operand() 221 222 def __getitem__(self, o): 223 return self._get_as_operand()[o] 224 225 def __pow__(self, o, modulo=None): 226 return pow(self._get_as_operand(), o, modulo) 227 228 def __rpow__(self, o): 229 return pow(o, self._get_as_operand()) 230 231 def __invert__(self): 232 return ~self._get_as_operand() 233 234 def __neg__(self): 235 return -self._get_as_operand() 236 237 def __abs__(self): 238 return abs(self._get_as_operand()) 239 240 def __div__(self, o): 241 try: 242 return self._get_as_operand().__div__(o) 243 except AttributeError: 244 # See https://docs.python.org/3/library/constants.html#NotImplemented 245 return NotImplemented 246 247 def __rdiv__(self, o): 248 try: 249 return self._get_as_operand().__rdiv__(o) 250 except AttributeError: 251 # See https://docs.python.org/3/library/constants.html#NotImplemented 252 return NotImplemented 253 254 def __matmul__(self, o): 255 try: 256 return self._get_as_operand().__matmul__(o) 257 except AttributeError: 258 # See https://docs.python.org/3/library/constants.html#NotImplemented 259 return NotImplemented 260 261 def __rmatmul__(self, o): 262 try: 263 return self._get_as_operand().__rmatmul__(o) 264 except AttributeError: 265 # See https://docs.python.org/3/library/constants.html#NotImplemented 266 return NotImplemented 267 268 # TODO(josh11b): Even more operator overloads. 269 270 271class PerReplica(DistributedValues, composite_tensor.CompositeTensor): 272 """Holds a map from replica to unsynchronized values.""" 273 274 @property 275 def _type_spec(self): 276 return PerReplicaSpec( 277 *(type_spec.type_spec_from_value(v) for v in self._values)) 278 279 280class PerReplicaSpec(type_spec.TypeSpec): 281 """Type specification for a `PerReplica`.""" 282 283 __slots__ = ["_value_specs"] 284 285 value_type = property(lambda self: PerReplica) 286 287 def __init__(self, *value_specs): 288 self._value_specs = tuple(value_specs) 289 290 def _serialize(self): 291 return self._value_specs 292 293 @property 294 def _component_specs(self): 295 return self._value_specs 296 297 def _to_components(self, value): 298 replica_context = distribution_strategy_context.get_replica_context() 299 if replica_context is not None and replica_context.num_replicas_in_sync > 1: 300 raise ValueError( 301 "Flattening a PerReplica to components is not supported in replica " 302 "context.") 303 return value._values # pylint: disable=protected-access 304 305 def _from_components(self, tensor_list): 306 return PerReplica(tensor_list) 307 308 309# Note that unlike PerReplica, Mirrored values inherit from 310# DistributedDelegate and so can be used directly in cross-replica mode. 311# TODO(tomhennigan) Should this extend CompositeTensor? 312class Mirrored(DistributedDelegate): 313 """Holds a map from replica to values which are kept in sync.""" 314 315 def _get_cross_replica(self): 316 return self._get_closest() 317 318 def _as_graph_element(self): 319 obj = self.get() 320 conv_fn = getattr(obj, "_as_graph_element", None) 321 if conv_fn and callable(conv_fn): 322 return conv_fn() 323 return obj 324 325 326def _assign_on_device(device, variable, tensor): 327 with ops.device(device): 328 return variable.assign(tensor) 329 330 331def _assign_add_on_device(device, variable, tensor): 332 with ops.device(device): 333 return variable.assign_add(tensor) 334 335 336def _assign_sub_on_device(device, variable, tensor): 337 with ops.device(device): 338 return variable.assign_sub(tensor) 339 340 341def _assert_strategy(strategy): 342 if not distribution_strategy_context.has_strategy(): 343 raise RuntimeError('Need to be inside "with strategy.scope()" for %s' % 344 (strategy,)) 345 current_strategy = distribution_strategy_context.get_strategy() 346 if current_strategy is not strategy: 347 raise RuntimeError( 348 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 349 (current_strategy, strategy)) 350 351 352@contextlib.contextmanager 353def _enter_or_assert_strategy(strategy): 354 if not distribution_strategy_context.has_strategy(): 355 with strategy.scope(): 356 yield 357 else: 358 _assert_strategy(strategy) 359 yield 360 361 362DistributedVarOp = collections.namedtuple( 363 "DistributedVarOp", ["name", "graph", "traceback", "type"]) 364 365 366class DistributedVariable(DistributedDelegate, variables_lib.Variable): 367 """Holds a map from replica to variables.""" 368 369 # TODO(josh11b): Support changing the set of variables if e.g. if new 370 # devices are joining or a device is to leave. 371 372 def __init__(self, strategy, values): 373 self._distribute_strategy = strategy 374 super(DistributedVariable, self).__init__(values) 375 self._common_name = self.primary.name.split(":")[0] 376 # Use a weakref to make it easy to map from the contained values 377 # to the container without introducing a reference cycle. 378 for v in values: 379 v._distributed_container = weakref.ref(self) # pylint: disable=protected-access 380 # tf.keras keeps track of variables initialized using this attribute. When 381 # tf.keras gets the default session, it initializes all uninitialized vars. 382 # We need to make _keras_initialized a member of DistributedVariable because 383 # without this it will use `__getattr__` which will delegate to a component 384 # variable. 385 self._keras_initialized = False 386 # Typically, a `DistributedVariable`'s initializer is composed of the 387 # initializers of the components variables. However, in some cases, such as 388 # when restoring from a checkpoint, we may set the _initializer_op 389 # property on the entire `DistributedVariable`. 390 self._initializer_op = None 391 392 def is_initialized(self, name=None): 393 """Identifies if all the component variables are initialized. 394 395 Args: 396 name: Name of the final `logical_and` op. 397 398 Returns: 399 The op that evaluates to True or False depending on if all the 400 component variables are initialized. 401 """ 402 result = self.primary.is_initialized() 403 # We iterate through the list of values except the last one to allow us to 404 # name the final `logical_and` op the same name that is passed by the user 405 # to the `is_initialized` op. For distributed variables, the 406 # `is_initialized` op is a `logical_and` op. 407 for v in self._values[1:-1]: 408 result = math_ops.logical_and(result, v.is_initialized()) 409 result = math_ops.logical_and( 410 result, self._values[-1].is_initialized(), name=name) 411 return result 412 413 @property 414 def initializer(self): 415 if self._initializer_op: 416 init_op = self._initializer_op 417 else: 418 # return grouped ops of all the var initializations of component values of 419 # the mirrored variable 420 init_op = control_flow_ops.group( 421 tuple(v.initializer for v in self._values)) 422 return init_op 423 424 def initialized_value(self): 425 return self._get_closest().initialized_value() 426 427 @property 428 def initial_value(self): 429 return self._get_closest().initial_value 430 431 @property 432 def graph(self): 433 return self.primary.graph 434 435 @property 436 def _shared_name(self): 437 return self._common_name 438 439 @property 440 def _unique_id(self): 441 return self.primary._unique_id # pylint: disable=protected-access 442 443 @property 444 def _graph_key(self): 445 """Lets Optimizers know which graph this variable is from.""" 446 return self.primary._graph_key # pylint: disable=protected-access 447 448 @property 449 def name(self): 450 return self.primary.name 451 452 @property 453 def dtype(self): 454 return self.primary.dtype 455 456 @property 457 def shape(self): 458 return self.primary.shape 459 460 @property 461 def synchronization(self): 462 return self.primary.synchronization 463 464 @property 465 def handle(self): 466 replica_id = _get_current_replica_id_as_int() 467 if replica_id is None: 468 raise ValueError("`handle` is not available outside the replica context" 469 " or a `tf.distribute.Strategy.update()` call.") 470 else: 471 return self._values[replica_id].handle 472 473 def eval(self, session=None): 474 return self._get_closest().eval(session) 475 476 @property 477 def _save_slice_info(self): 478 return self.primary._save_slice_info # pylint: disable=protected-access 479 480 def _get_save_slice_info(self): 481 return self.primary._get_save_slice_info() # pylint: disable=protected-access 482 483 def _set_save_slice_info(self, save_slice_info): 484 for v in self._values: 485 v._set_save_slice_info(save_slice_info) # pylint: disable=protected-access 486 487 @property 488 def device(self): 489 return self._get_closest().device 490 491 @property 492 def trainable(self): 493 return self.primary.trainable 494 495 @property 496 def distribute_strategy(self): 497 return self._distribute_strategy 498 499 def get_shape(self): 500 return self.primary.get_shape() 501 502 def to_proto(self, export_scope=None): 503 return self.primary.to_proto(export_scope=export_scope) 504 505 @property 506 def op(self): 507 # We want cross-replica code that does some var.op.X calls 508 # to work (even if the current device isn't in self.devices), but 509 # other uses of var.op in a cross-replica context to fail. 510 if distribution_strategy_context.in_cross_replica_context(): 511 return DistributedVarOp(self.primary.op.name, self.primary.op.graph, 512 self.primary.op.traceback, self.primary.op.type) 513 return self.get().op 514 515 @property 516 def _in_graph_mode(self): 517 return self.primary._in_graph_mode # pylint: disable=protected-access 518 519 def read_value(self): 520 with _enter_or_assert_strategy(self._distribute_strategy): 521 return array_ops.identity(self.get()) 522 523 def value(self): 524 return self._get_closest().value() 525 526 def _should_act_as_resource_variable(self): 527 """Pass resource_variable_ops.is_resource_variable check.""" 528 pass 529 530 531ops.register_dense_tensor_like_type(DistributedVariable) 532 533 534@contextlib.contextmanager 535def _maybe_enter_graph(tensor): 536 # Note: might have an eager tensor but not be executing eagerly when 537 # building functions. 538 if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or 539 ops.has_default_graph()): 540 yield 541 else: 542 with tensor.graph.as_default(): 543 yield 544 545 546def _make_raw_assign_fn(raw_assign_fn): # pylint: disable=missing-docstring 547 548 def assign_fn(var, value, use_locking=False, name=None, read_value=True): # pylint: disable=missing-docstring 549 del use_locking # Unused. 550 551 with _maybe_enter_graph(var.handle): 552 op = raw_assign_fn( 553 var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name) 554 555 with ops.control_dependencies([op]): 556 return var._read_variable_op() if read_value else op # pylint: disable=protected-access 557 558 return assign_fn 559 560 561class TPUVariableMixin(object): 562 """Mixin for TPU variables.""" 563 564 def __init__(self, *args, **kwargs): 565 super(TPUVariableMixin, self).__init__(*args, **kwargs) 566 567 # Handle ID is needed for `get_replicated_var_handle` to cache the variables 568 # correctly since in eager mode different variables can have the same name. 569 if ops.executing_eagerly_outside_functions(): 570 self._handle_id = self._common_name + "_" + str(id(self.primary)) 571 else: 572 self._handle_id = self._common_name 573 574 def __getattr__(self, name): 575 if _enclosing_tpu_context() is None: 576 return super(TPUVariableMixin, self).__getattr__(name) 577 else: 578 raise AttributeError( 579 "'{}' not accessible within a TPU context.".format(name)) 580 581 def get(self): 582 if _enclosing_tpu_context() is None: 583 return super(TPUVariableMixin, self).get() 584 else: 585 raise NotImplementedError( 586 "`TPUVariableMixin.get()` is not supported within a TPU context.") 587 588 def _get_as_operand(self): 589 return self.read_value() 590 591 def _get_closest(self): 592 if _enclosing_tpu_context() is None: 593 return super(TPUVariableMixin, self)._get_closest() 594 else: 595 return self.primary 596 597 def numpy(self): 598 if context.executing_eagerly(): 599 return self.read_value().numpy() 600 else: 601 raise NotImplementedError( 602 "numpy() is only available when eager execution is enabled.") 603 604 def _is_mirrored(self): 605 raise NotImplementedError( 606 "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.") 607 608 @property 609 def handle(self): 610 # If we're in a tpu.rewrite(), return the replicated handle. 611 tpu_context = _enclosing_tpu_context() 612 if tpu_context is None: 613 return self._get_closest().handle 614 else: 615 return tpu_context.get_replicated_var_handle( 616 self._handle_id, self._values, self._is_mirrored()) 617 618 @property 619 def device(self): 620 return self.handle.device 621 622 def _read_variable_op(self): 623 if self.trainable: 624 tape.variable_accessed(self) 625 return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype) 626 627 def read_value(self): 628 if _enclosing_tpu_context() is None: 629 return super(TPUVariableMixin, self).read_value() 630 else: 631 return self._read_variable_op() 632 633 @property 634 def constraint(self): 635 return self.primary.constraint 636 637 def _as_graph_element(self): 638 if _enclosing_tpu_context() is None: 639 return super(TPUVariableMixin, self)._as_graph_element() # pylint: disable=protected-access 640 else: 641 return None 642 643 @property 644 def op(self): 645 return DistributedVarOp(self.primary.op.name, self.primary.op.graph, 646 self.primary.op.traceback, self.primary.op.type) 647 648 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 649 """Converts a variable to a tensor.""" 650 # pylint: disable=protected-access 651 if _enclosing_tpu_context() is None: 652 return super(TPUVariableMixin, self)._dense_var_to_tensor( 653 dtype=dtype, name=name, as_ref=as_ref) 654 # pylint: enable=protected-access 655 elif dtype is not None and dtype != self.dtype: 656 return math_ops.cast(self.read_value(), dtype) 657 else: 658 return self.handle if as_ref else self.read_value() 659 660 661def _validate_colocate_extended(v, extended): 662 variable_strategy = v._distribute_strategy # pylint: disable=protected-access 663 if variable_strategy.extended is not extended: 664 raise ValueError( 665 "`colocate_vars_with` must only be passed a variable created in this " 666 "tf.distribute.Strategy.scope(), not %s created in scope: %s" % 667 (v, variable_strategy)) 668 669 670def validate_colocate_distributed_variable(v, extended): 671 if not isinstance(v, DistributedVariable): 672 raise ValueError( 673 "`colocate_vars_with` must only be passed a variable created in this " 674 "tf.distribute.Strategy.scope(), not: %r" % (v,)) 675 _validate_colocate_extended(v, extended) 676 677 678def validate_colocate(v, extended): 679 if not hasattr(v, "_distribute_strategy"): 680 raise ValueError( 681 "`colocate_vars_with` must only be passed a variable created in this " 682 "tf.distribute.Strategy.scope(), not: %r" % (v,)) 683 _validate_colocate_extended(v, extended) 684 685 686def _apply_aggregation(strategy, value, aggregation, destinations): 687 if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 688 return strategy.extended.broadcast_to( 689 strategy.experimental_local_results(value)[0], 690 destinations=destinations) 691 reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) 692 return strategy.extended.reduce_to(reduce_op, value, destinations) 693 694 695_aggregation_error_msg = ( 696 "You must specify an aggregation method to update a " 697 "{variable_type} in Replica Context. You can do so by passing " 698 "an explicit value for argument `aggregation` to tf.Variable(..)." 699 "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`" 700 "`tf.VariableAggregation` lists the possible aggregation methods." 701 "This is required because {variable_type} should always be " 702 "kept in sync. When updating them or assigning to them in a " 703 "replica context, we automatically try to aggregate the values " 704 "before updating the variable. For this aggregation, we need to " 705 "know the aggregation method. " 706 "Another alternative is to not try to update such " 707 "{variable_type} in replica context, but in cross replica " 708 "context. You can enter cross replica context by calling " 709 "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`." 710 "Inside `merge_fn`, you can then update the {variable_type} " 711 "using `tf.distribute.StrategyExtended.update()`.") 712 713 714class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): 715 """Class for defining how to restore a MirroredVariable.""" 716 717 def __init__(self, mirrored_variable, primary_variable, name): 718 self._mirrored_variable = mirrored_variable 719 super(_MirroredSaveable, self).__init__(primary_variable, "", name) 720 721 def restore(self, restored_tensors, restored_shapes): 722 """Restore the same value into all variables.""" 723 tensor, = restored_tensors 724 return control_flow_ops.group( 725 tuple( 726 _assign_on_device(v.device, v, tensor) 727 for v in self._mirrored_variable.values)) 728 729 730def create_mirrored_variable( # pylint: disable=missing-docstring 731 strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs): 732 # Figure out what collections this variable should be added to. 733 # We'll add the MirroredVariable to those collections instead. 734 var_collections = kwargs.pop("collections", None) 735 if var_collections is None: 736 var_collections = [ops.GraphKeys.GLOBAL_VARIABLES] 737 kwargs["collections"] = [] 738 739 synchronization = kwargs.get("synchronization", 740 vs.VariableSynchronization.ON_WRITE) 741 742 if synchronization == vs.VariableSynchronization.NONE: 743 raise ValueError( 744 "`NONE` variable synchronization mode is not supported with `Mirrored` " 745 "distribution strategy. Please change the `synchronization` for " 746 "variable: " + str(kwargs["name"])) 747 elif synchronization == vs.VariableSynchronization.ON_READ: 748 is_sync_on_read = True 749 elif synchronization in (vs.VariableSynchronization.ON_WRITE, 750 vs.VariableSynchronization.AUTO): 751 # `AUTO` synchronization defaults to `ON_WRITE`. 752 is_sync_on_read = False 753 else: 754 raise ValueError( 755 "Invalid variable synchronization mode: %s for variable: %s" % 756 (synchronization, kwargs["name"])) 757 758 aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) 759 760 if aggregation not in (vs.VariableAggregation.NONE, 761 vs.VariableAggregation.SUM, 762 vs.VariableAggregation.MEAN, 763 vs.VariableAggregation.ONLY_FIRST_REPLICA): 764 raise ValueError("Invalid variable aggregation mode: %s for variable: %s" % 765 (aggregation, kwargs["name"])) 766 767 # Ignore user-specified caching device, not needed for mirrored variables. 768 kwargs.pop("caching_device", None) 769 770 # TODO(josh11b,apassos): It would be better if variable initialization 771 # was never recorded on the tape instead of having to do this manually 772 # here. 773 with tape.stop_recording(): 774 value_list = real_mirrored_creator(**kwargs) 775 var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls 776 result = var_cls(strategy, value_list, aggregation) 777 778 # Add the wrapped variable to the requested collections. 779 # The handling of eager mode and the global step matches 780 # ResourceVariable._init_from_args(). 781 if not context.executing_eagerly(): 782 g = ops.get_default_graph() 783 # If "trainable" is True, next_creator() will add the member variables 784 # to the TRAINABLE_VARIABLES collection, so we manually remove 785 # them and replace with the MirroredVariable. We can't set 786 # "trainable" to False for next_creator() since that causes functions 787 # like implicit_gradients to skip those variables. 788 if kwargs.get("trainable", True): 789 var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) 790 l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) 791 for value in value_list: 792 for i, trainable_variable in enumerate(l): 793 if value is trainable_variable: 794 del l[i] 795 break 796 797 g.add_to_collections(var_collections, result) 798 elif ops.GraphKeys.GLOBAL_STEP in var_collections: 799 ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) 800 801 return result 802 803 804class MirroredVariable(DistributedVariable, Mirrored): 805 """Holds a map from replica to variables whose values are kept in sync.""" 806 807 def __init__(self, strategy, values, aggregation): 808 super(MirroredVariable, self).__init__(strategy, values) 809 self._aggregation = aggregation 810 811 # The arguments to update() are automatically unwrapped so the update() 812 # function would normally see regular variables, not MirroredVariables. 813 # However, the update function can still operate on wrapped MirroredVariables 814 # through object members, captured arguments, etc. This is more likely in an 815 # update_non_slot() function (like OptimizerV2._finish), which can 816 # update several non-slot variables in one call. 817 def _assign_func(self, *args, **kwargs): 818 with _enter_or_assert_strategy(self._distribute_strategy): 819 f = kwargs.pop("f") 820 if distribution_strategy_context.in_cross_replica_context(): 821 update_replica_id = distribute_lib.get_update_replica_id() 822 if update_replica_id is not None: 823 # We are calling an assign function on the mirrored variable in an 824 # update context. 825 return f(self.values[update_replica_id], *args, **kwargs) 826 827 # We are calling assign on the mirrored variable in cross replica 828 # context, use `strategy.extended.update()` to update the variable. 829 return self._distribute_strategy.extended.update( 830 self, f, args=args, kwargs=kwargs) 831 else: 832 _assert_replica_context(self._distribute_strategy) 833 # We are calling an assign function on the mirrored variable in replica 834 # context. 835 # We reduce the value we want to assign/add/sub. More details about how 836 # we handle the different use cases can be found in the _reduce method. 837 # We call the function on each of the mirrored variables with the 838 # reduced value. 839 if self._aggregation == vs.VariableAggregation.NONE: 840 raise ValueError( 841 _aggregation_error_msg.format(variable_type="MirroredVariable")) 842 843 def merge_fn(strategy, value, *other_args, **other_kwargs): # pylint: disable=missing-docstring 844 # Don't allow MEAN with non float dtype, since it may cause unexpected 845 # precision loss. Python3 and NumPy automatically upcast integers to 846 # float in division, but we should always preserve the type. 847 # 848 # Note that to be backward compatible we allow the case when the value 849 # is *always* the same on each replica. I.E. value is not a 850 # PerReplica. Refer to regroup() to see how values are grouped. 851 if self._aggregation == vs.VariableAggregation.MEAN and ( 852 not self.dtype.is_floating) and isinstance(value, PerReplica): 853 raise ValueError( 854 "Cannot update non-float variables with " 855 "tf.VariableAggregation.MEAN aggregation in replica context. " 856 "Either change the variable dtype to float or update it in " 857 "cross-replica context.") 858 859 v = _apply_aggregation(strategy, value, self._aggregation, self) 860 return strategy.extended.update( 861 self, f, args=(v,) + other_args, kwargs=other_kwargs) 862 863 return distribution_strategy_context.get_replica_context().merge_call( 864 merge_fn, args=args, kwargs=kwargs) 865 866 def assign_sub(self, *args, **kwargs): 867 assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) 868 return self._assign_func(f=assign_sub_fn, *args, **kwargs) 869 870 def assign_add(self, *args, **kwargs): 871 assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) 872 return self._assign_func(f=assign_add_fn, *args, **kwargs) 873 874 def assign(self, *args, **kwargs): 875 assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) 876 return self._assign_func(f=assign_fn, *args, **kwargs) 877 878 @property 879 def aggregation(self): 880 return self._aggregation 881 882 def _get_cross_replica(self): 883 # Return identity, to avoid directly exposing the variable to the user and 884 # allowing it to be modified by mistake. 885 return array_ops.identity(Mirrored._get_cross_replica(self)) 886 887 def _as_graph_element(self): 888 return self._get_closest()._as_graph_element() # pylint: disable=protected-access 889 890 def _gather_saveables_for_checkpoint(self): 891 """Overrides Trackable method. 892 893 This allows both name-based and object-based save and restore of 894 MirroredVariables. 895 896 Returns: 897 A dictionary mapping attribute names to `SaveableObject` factories. 898 """ 899 900 def _saveable_factory(name=self._common_name): 901 return _MirroredSaveable(self, self.primary, name) 902 903 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 904 905 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 906 """Converts a variable to a tensor.""" 907 # Try to avoid assignments to and other mutations of MirroredVariable 908 # state except through a DistributionStrategy.extended.update() call. 909 assert not as_ref 910 return ops.convert_to_tensor( 911 self.get(), dtype=dtype, name=name, as_ref=as_ref) 912 913 914# Register a conversion function which reads the value of the variable, 915# allowing instances of the class to be used as tensors. 916def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): 917 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 918 919 920ops.register_tensor_conversion_function(MirroredVariable, 921 _tensor_conversion_mirrored) 922 923 924def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False): 925 return ops.convert_to_tensor( 926 value.get(), dtype=dtype, name=name, as_ref=as_ref) 927 928 929ops.register_tensor_conversion_function(Mirrored, 930 _tensor_conversion_mirrored_val) 931 932 933def _enclosing_tpu_context(): 934 """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite().""" 935 graph = ops.get_default_graph() 936 while graph is not None: 937 # pylint: disable=protected-access 938 context_ = graph._get_control_flow_context() 939 # pylint: enable=protected-access 940 while context_ is not None: 941 if isinstance(context_, control_flow_ops.XLAControlFlowContext): 942 return context_ 943 context_ = context_.outer_context 944 # This may be a FuncGraph due to defuns or v2 control flow. We need to 945 # find the original graph with the XLAControlFlowContext. 946 graph = getattr(graph, "outer_graph", None) 947 return None 948 949 950def is_distributed_variable(v): 951 """Determine if a variable is ds variable or TPU mirrored variable.""" 952 return isinstance(v, DistributedVariable) 953 954 955class TPUMirroredVariable(TPUVariableMixin, MirroredVariable): 956 """Holds a map from replica to TPU variables whose values are kept in sync.""" 957 958 def _assign_func(self, *args, **kwargs): 959 with _enter_or_assert_strategy(self._distribute_strategy): 960 if (distribution_strategy_context.in_cross_replica_context() and 961 (_enclosing_tpu_context() is not None)): 962 f = kwargs.pop("f") 963 return self._distribute_strategy.extended.update( 964 self, f, args=args, kwargs=kwargs) 965 else: 966 return MirroredVariable._assign_func(self, *args, **kwargs) 967 968 def assign_sub(self, *args, **kwargs): 969 assign_sub_fn = _make_raw_assign_fn( 970 gen_resource_variable_ops.assign_sub_variable_op) 971 return self._assign_func(f=assign_sub_fn, *args, **kwargs) 972 973 def assign_add(self, *args, **kwargs): 974 assign_add_fn = _make_raw_assign_fn( 975 gen_resource_variable_ops.assign_add_variable_op) 976 return self._assign_func(f=assign_add_fn, *args, **kwargs) 977 978 def assign(self, *args, **kwargs): 979 assign_fn = _make_raw_assign_fn( 980 gen_resource_variable_ops.assign_variable_op) 981 return self._assign_func(f=assign_fn, *args, **kwargs) 982 983 def _is_mirrored(self): 984 return True 985 986 987class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject): 988 """Class for defining how to restore a SyncOnReadVariable.""" 989 990 def __init__(self, sync_on_read_variable, name): 991 self._sync_on_read_variable = sync_on_read_variable 992 993 # We use a callable so that we don't have to evaluate this expression 994 # in the case where we are trying to restore instead of save. 995 def tensor(): 996 strategy = sync_on_read_variable._distribute_strategy # pylint: disable=protected-access 997 return strategy.extended.read_var(sync_on_read_variable) 998 999 spec = saver.BaseSaverBuilder.SaveSpec( 1000 tensor=tensor, 1001 slice_spec="", 1002 name=name, 1003 dtype=sync_on_read_variable.dtype, 1004 device=sync_on_read_variable.primary.device) 1005 super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name) 1006 1007 def restore(self, restored_tensors, restored_shapes): 1008 """Restore the same value into all variables.""" 1009 # To preserve the sum across save and restore, we have to divide the 1010 # total across all devices when restoring a variable that was summed 1011 # when saving. 1012 tensor, = restored_tensors 1013 if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM: 1014 tensor = math_ops.cast(tensor / len(self._sync_on_read_variable.devices), 1015 self._sync_on_read_variable.dtype) 1016 return control_flow_ops.group( 1017 tuple( 1018 _assign_on_device(v.device, v, tensor) 1019 for v in self._sync_on_read_variable.values)) 1020 1021 1022def _assert_replica_context(strategy): 1023 replica_context = distribution_strategy_context.get_replica_context() 1024 if not replica_context: 1025 raise RuntimeError( 1026 "Replica-local variables may only be assigned in a replica context.") 1027 if replica_context.strategy is not strategy: 1028 raise RuntimeError( 1029 "Replica-local variables may only be assigned in a replica context.") 1030 1031 1032class SyncOnReadVariable(DistributedVariable): 1033 """Holds a map from replica to variables whose values are reduced on save.""" 1034 1035 def __init__(self, strategy, values, aggregation): 1036 super(SyncOnReadVariable, self).__init__(strategy, values) 1037 self._aggregation = aggregation 1038 1039 def assign_sub(self, *args, **kwargs): 1040 with _enter_or_assert_strategy(self._distribute_strategy): 1041 if distribution_strategy_context.in_cross_replica_context(): 1042 if self._aggregation == vs.VariableAggregation.SUM: 1043 raise ValueError( 1044 "SyncOnReadVariable does not support `assign_sub` in " 1045 "cross-replica context when aggregation is set to " 1046 "`tf.VariableAggregation.SUM`.") 1047 return control_flow_ops.group( 1048 tuple( 1049 _assign_sub_on_device(v.device, v, args[0]) 1050 for v in self._values)) 1051 else: 1052 return self.get().assign_sub(*args, **kwargs) 1053 1054 def assign_add(self, *args, **kwargs): 1055 with _enter_or_assert_strategy(self._distribute_strategy): 1056 if distribution_strategy_context.in_cross_replica_context(): 1057 if self._aggregation == vs.VariableAggregation.SUM: 1058 raise ValueError( 1059 "SyncOnReadVariable does not support `assign_add` in " 1060 "cross-replica context when aggregation is set to " 1061 "`tf.VariableAggregation.SUM`.") 1062 return control_flow_ops.group( 1063 tuple( 1064 _assign_add_on_device(v.device, v, args[0]) 1065 for v in self._values)) 1066 else: 1067 return self.get().assign_add(*args, **kwargs) 1068 1069 def assign(self, *args, **kwargs): 1070 with _enter_or_assert_strategy(self._distribute_strategy): 1071 if distribution_strategy_context.in_cross_replica_context(): 1072 # To preserve the sum across save and restore, we have to divide the 1073 # total across all devices when restoring a variable that was summed 1074 # when saving. 1075 tensor = args[0] 1076 if self._aggregation == vs.VariableAggregation.SUM: 1077 tensor = math_ops.cast(tensor / len(self._values), self.dtype) 1078 return control_flow_ops.group( 1079 tuple(_assign_on_device(v.device, v, tensor) for v in self._values)) 1080 else: 1081 return self.get().assign(*args, **kwargs) 1082 1083 @property 1084 def aggregation(self): 1085 return self._aggregation 1086 1087 def _get_cross_replica(self): 1088 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 1089 return self.primary 1090 1091 with _enter_or_assert_strategy(self._distribute_strategy): 1092 return self._distribute_strategy.reduce( 1093 reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), 1094 self, 1095 axis=None) 1096 1097 def _as_graph_element(self): 1098 # pylint: disable=protected-access 1099 if distribution_strategy_context.in_cross_replica_context(): 1100 return self._get_cross_replica() 1101 return self.get()._as_graph_element() 1102 1103 def _gather_saveables_for_checkpoint(self): 1104 """Overrides Trackable method. 1105 1106 This allows both name-based and object-based save and restore of 1107 `SyncOnReadVariable`s. 1108 1109 Returns: 1110 A dictionary mapping attribute names to `SaveableObject` factories. 1111 """ 1112 1113 def _saveable_factory(name=self._common_name): 1114 return _SyncOnReadSaveable(self, name) 1115 1116 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 1117 1118 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 1119 """Converts a variable to a tensor.""" 1120 return ops.convert_to_tensor( 1121 self.get(), dtype=dtype, name=name, as_ref=as_ref) 1122 1123 1124# Register a conversion function for SyncOnReadVariable which allows as_ref to 1125# be true. 1126def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False): 1127 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 1128 1129 1130ops.register_tensor_conversion_function(SyncOnReadVariable, 1131 _tensor_conversion_sync_on_read) 1132 1133 1134class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable): 1135 """Holds a map from replica to variables whose values are reduced on save.""" 1136 1137 def assign_sub(self, *args, **kwargs): 1138 if _enclosing_tpu_context() is None: 1139 return SyncOnReadVariable.assign_sub(self, *args, **kwargs) 1140 else: 1141 return _make_raw_assign_fn( 1142 gen_resource_variable_ops.assign_sub_variable_op)(self, *args, 1143 **kwargs) 1144 1145 def assign_add(self, *args, **kwargs): 1146 if _enclosing_tpu_context() is None: 1147 return SyncOnReadVariable.assign_add(self, *args, **kwargs) 1148 else: 1149 return _make_raw_assign_fn( 1150 gen_resource_variable_ops.assign_add_variable_op)(self, *args, 1151 **kwargs) 1152 1153 def assign(self, *args, **kwargs): 1154 if _enclosing_tpu_context() is None: 1155 return SyncOnReadVariable.assign(self, *args, **kwargs) 1156 else: 1157 return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)( 1158 self, *args, **kwargs) 1159 1160 def _is_mirrored(self): 1161 return False 1162 1163 1164def regroup(values, wrap_class=PerReplica): 1165 """Makes a nest per-replica into a nest of PerReplica/Mirrored values.""" 1166 v0 = values[0] 1167 1168 if isinstance(v0, list): 1169 for v in values[1:]: 1170 assert isinstance(v, list) 1171 assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" % 1172 (len(v), len(v0), v, v0)) 1173 return [ 1174 regroup(tuple(v[i] for v in values), wrap_class) 1175 for i in range(len(v0)) 1176 ] 1177 1178 if isinstance(v0, tuple): 1179 for v in values[1:]: 1180 assert isinstance(v, tuple) 1181 assert len(v) == len(v0) 1182 regrouped_tuple = tuple( 1183 regroup(tuple(v[i] for v in values), wrap_class) 1184 for i in range(len(v0))) 1185 if hasattr(v0, "_fields"): 1186 # This tuple is in fact a namedtuple! Create a new namedtuple instance 1187 # and initialize it with the regrouped values: 1188 assert hasattr(type(v0), "_make") 1189 return type(v0)._make(regrouped_tuple) 1190 else: 1191 return regrouped_tuple 1192 1193 if isinstance(v0, dict): 1194 v0keys = set(v0.keys()) 1195 for v in values[1:]: 1196 assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v)) 1197 assert set(v.keys()) == v0keys, ("v[0].keys: %s v[i].keys: %s" % 1198 (v0keys, set(v.keys()))) 1199 return type(v0)(**{ 1200 key: regroup(tuple(v[key] for v in values), wrap_class) 1201 for key in v0keys 1202 }) 1203 1204 # If exactly the same object across all devices, return it unwrapped. 1205 same_id = True 1206 for v in values[1:]: 1207 if v is not v0: 1208 same_id = False 1209 break 1210 # Consider three cases where same_id is true: 1211 # * If v0 is a DistributedVariable (a MirroredVariable or 1212 # SyncOnReadVariable, and same_id means it is the same across all 1213 # devices), we want to return it. We check DistributedVariable 1214 # specifically since it can look like it has a 1215 # _distributed_container member since its members do. 1216 # * If v0 is a member of a distributed variable, in which case 1217 # hasattr(v0, "_distributed_container") is true, we want to 1218 # return the DistributedVariable that contains it using the 1219 # _distributed_container logic below. This case can trigger 1220 # same_id when there is only one device. 1221 # * In any other situation, same_id means we return v0. 1222 if same_id and (isinstance(v0, DistributedVariable) or 1223 not hasattr(v0, "_distributed_container")): 1224 return v0 1225 1226 # Detect the case where each device has a parallel component of the 1227 # same MirroredVariable (or SyncOnReadVariable). In this case we 1228 # want to return the containing MirroredVariable, after a bunch of 1229 # sanity checking. In particular, each component should have the 1230 # same container, and the devices of the variables should match the 1231 # keys of the per-replica dictionary. 1232 if hasattr(v0, "_distributed_container"): 1233 # pylint: disable=protected-access 1234 assert not isinstance(v0, MirroredVariable), ( 1235 "ids = %s, values = %s" % ([id(v) for v in values], values)) 1236 distributed_container = v0._distributed_container() 1237 assert distributed_container is not None 1238 for v in values[1:]: 1239 assert distributed_container is v._distributed_container() 1240 return distributed_container 1241 # pylint: enable=protected-access 1242 1243 return wrap_class(values) 1244 1245 1246def select_replica(replica_id, structured): 1247 """Specialize a nest of regular & per-replica values for one replica.""" 1248 1249 def _get(x): 1250 # `DistributedValues` would be sliced according to replica unless it is a 1251 # `DistributedVariable` because `DistributedVariable` can be handled 1252 # directly in the replica context. 1253 if (isinstance(x, DistributedVariable) or 1254 not isinstance(x, DistributedValues)): 1255 return x 1256 else: 1257 return x.values[replica_id] 1258 1259 return nest.map_structure(_get, structured) 1260 1261 1262def select_replica_mirrored(replica_id, structured): 1263 """Specialize a nest of regular & mirrored values for one replica.""" 1264 1265 def _get_mirrored(x): 1266 if isinstance(x, DistributedValues): 1267 if not isinstance(x, Mirrored): 1268 raise TypeError( 1269 "Expected value to be mirrored across replicas: %s in %s." % 1270 (x, structured)) 1271 return x.values[replica_id] 1272 else: 1273 return x 1274 1275 return nest.map_structure(_get_mirrored, structured) 1276 1277 1278def update_regroup(extended, updates, group): 1279 """Regroup for an update, with dependencies to ensure all updates execute.""" 1280 if not group: 1281 regrouped = regroup(updates, Mirrored) 1282 return nest.map_structure(extended._local_results, regrouped) # pylint: disable=protected-access 1283 1284 def _make_grouped_mirrored(values): 1285 """Convert per-replica list `values` into Mirrored type with grouping.""" 1286 if len(values) == 1: 1287 return Mirrored(values) 1288 1289 # Make sure we run all updates. Without this, something like 1290 # session.run(extended.update(...)) may only update one replica. 1291 g = control_flow_ops.group(values) 1292 1293 # If values is just ops, the grouping is enough. Everything in values 1294 # should have the same type, since we expect every replica to be performing 1295 # the same computation. 1296 if not all(tensor_util.is_tensor(v) for v in values): 1297 return g 1298 1299 # Otherwise we need tensors with the same values as `values`, but 1300 # that have a dependency on `g`. 1301 with_dep = [] 1302 for v in values: 1303 with ops.device(v.device), ops.control_dependencies([g]): 1304 with_dep.append(array_ops.identity(v)) 1305 1306 return Mirrored(with_dep) 1307 1308 return regroup(updates, _make_grouped_mirrored) 1309 1310 1311def value_container(val): 1312 """Returns the container that this per-replica `value` belongs to. 1313 1314 Args: 1315 val: A value returned by `call_for_each_replica()` or a variable created in 1316 `scope()`. 1317 1318 Returns: 1319 A container that `value` belongs to. 1320 If value does not belong to any container (including the case of 1321 container having been destroyed), returns the value itself. 1322 """ 1323 if (hasattr(val, "_distributed_container") and 1324 # DistributedVariable has _distributed_container defined 1325 # but we don't want to return it. 1326 not isinstance(val, DistributedVariable)): 1327 container = val._distributed_container() # pylint: disable=protected-access 1328 if container is not None: 1329 return container 1330 return val 1331 1332 1333class AggregatingVariable(variables_lib.Variable): 1334 """A wrapper around a variable that aggregates updates across replicas.""" 1335 1336 def __init__(self, strategy, v, aggregation): 1337 self._distribute_strategy = strategy 1338 self._v = v 1339 # NOTE: We don't use "_distributed_container" here because we don't want 1340 # to trigger that code path in regroup(). 1341 v._aggregating_container = weakref.ref(self) # pylint: disable=protected-access 1342 self._aggregation = aggregation 1343 1344 def get(self): 1345 return self._v 1346 1347 @property 1348 def distribute_strategy(self): 1349 return self._distribute_strategy 1350 1351 def __getattr__(self, name): 1352 return getattr(self._v, name) 1353 1354 def _assign_func(self, *args, **kwargs): 1355 with _enter_or_assert_strategy(self._distribute_strategy): 1356 f = kwargs.pop("f") 1357 if distribution_strategy_context.in_cross_replica_context(): 1358 if distribute_lib.get_update_replica_id() is not None: 1359 # We are calling an assign function in an update context. 1360 return f(self._v, *args, **kwargs) 1361 1362 # We are calling an assign function in cross replica context, wrap it in 1363 # an update call. 1364 return self._distribute_strategy.extended.update( 1365 self, f, args=args, kwargs=kwargs) 1366 else: 1367 replica_context = distribution_strategy_context.get_replica_context() 1368 assert replica_context 1369 # We are calling an assign function in replica context. 1370 # We reduce the value we want to assign/add/sub. More details about how 1371 # we handle the different use cases can be found in the _reduce method. 1372 # We call the function with the reduced value. 1373 if self._aggregation == vs.VariableAggregation.NONE: 1374 raise ValueError( 1375 _aggregation_error_msg.format( 1376 variable_type="AggregatingVariable")) 1377 1378 def merge_fn(strategy, value, *other_args, **other_kwargs): 1379 v = _apply_aggregation(strategy, value, self._aggregation, self) 1380 return strategy.extended.update( 1381 self, f, args=(v,) + other_args, kwargs=other_kwargs) 1382 1383 return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) 1384 1385 def assign_sub(self, *args, **kwargs): 1386 assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) 1387 return self._assign_func(f=assign_sub_fn, *args, **kwargs) 1388 1389 def assign_add(self, *args, **kwargs): 1390 assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) 1391 return self._assign_func(f=assign_add_fn, *args, **kwargs) 1392 1393 def assign(self, *args, **kwargs): 1394 assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) 1395 return self._assign_func(f=assign_fn, *args, **kwargs) 1396 1397 @property 1398 def initializer(self): 1399 return self._v.initializer 1400 1401 def initialized_value(self): 1402 return self._v.initialized_value() 1403 1404 @property 1405 def initial_value(self): 1406 return self._v.initial_value 1407 1408 @property 1409 def op(self): 1410 return self._v.op 1411 1412 def read_value(self): 1413 return self._v.read_value() 1414 1415 def eval(self, session=None): 1416 return self._v.eval(session) 1417 1418 @property 1419 def graph(self): 1420 return self._v.graph 1421 1422 @property 1423 def device(self): 1424 return self._v.device 1425 1426 @property 1427 def shape(self): 1428 return self._v.shape 1429 1430 @property 1431 def aggregation(self): 1432 return self._aggregation 1433 1434 @property 1435 def name(self): 1436 return self._v.name 1437 1438 @property 1439 def trainable(self): 1440 return self._v.trainable 1441 1442 @property 1443 def dtype(self): 1444 return self._v.dtype 1445 1446 # TODO(josh11b): Test saving & restoring. 1447 def _gather_saveables_for_checkpoint(self): 1448 return {trackable.VARIABLE_VALUE_KEY: self._v} 1449 1450 # pylint: disable=multiple-statements 1451 def __add__(self, o): 1452 return self._v + o 1453 1454 def __radd__(self, o): 1455 return o + self._v 1456 1457 def __sub__(self, o): 1458 return self._v - o 1459 1460 def __rsub__(self, o): 1461 return o - self._v 1462 1463 def __mul__(self, o): 1464 return self._v * o 1465 1466 def __rmul__(self, o): 1467 return o * self._v 1468 1469 def __truediv__(self, o): 1470 return self._v / o 1471 1472 def __rtruediv__(self, o): 1473 return o / self._v 1474 1475 def __floordiv__(self, o): 1476 return self._v // o 1477 1478 def __rfloordiv__(self, o): 1479 return o // self._v 1480 1481 def __mod__(self, o): 1482 return self._v % o 1483 1484 def __rmod__(self, o): 1485 return o % self._v 1486 1487 def __lt__(self, o): 1488 return self._v < o 1489 1490 def __le__(self, o): 1491 return self._v <= o 1492 1493 def __gt__(self, o): 1494 return self._v > o 1495 1496 def __ge__(self, o): 1497 return self._v >= o 1498 1499 def __and__(self, o): 1500 return self._v & o 1501 1502 def __rand__(self, o): 1503 return o & self._v 1504 1505 def __or__(self, o): 1506 return self._v | o 1507 1508 def __ror__(self, o): 1509 return o | self._v 1510 1511 def __xor__(self, o): 1512 return self._v ^ o 1513 1514 def __rxor__(self, o): 1515 return o ^ self._v 1516 1517 def __getitem__(self, o): 1518 return self._v[o] 1519 1520 def __pow__(self, o, modulo=None): 1521 return pow(self._v, o, modulo) 1522 1523 def __rpow__(self, o): 1524 return pow(o, self._v) 1525 1526 def __invert__(self): 1527 return ~self._v 1528 1529 def __neg__(self): 1530 return -self._v 1531 1532 def __abs__(self): 1533 return abs(self._v) 1534 1535 def __div__(self, o): 1536 try: 1537 return self._v.__div__(o) 1538 except AttributeError: 1539 # See https://docs.python.org/3/library/constants.html#NotImplemented 1540 return NotImplemented 1541 1542 def __rdiv__(self, o): 1543 try: 1544 return self._v.__rdiv__(o) 1545 except AttributeError: 1546 # See https://docs.python.org/3/library/constants.html#NotImplemented 1547 return NotImplemented 1548 1549 def __matmul__(self, o): 1550 try: 1551 return self._v.__matmul__(o) 1552 except AttributeError: 1553 # See https://docs.python.org/3/library/constants.html#NotImplemented 1554 return NotImplemented 1555 1556 def __rmatmul__(self, o): 1557 try: 1558 return self._v.__rmatmul__(o) 1559 except AttributeError: 1560 # See https://docs.python.org/3/library/constants.html#NotImplemented 1561 return NotImplemented 1562 1563 def __str__(self): 1564 return str(self._v) 1565 1566 def __repr__(self): 1567 return repr(self._v) 1568 1569 def _should_act_as_resource_variable(self): 1570 """Pass resource_variable_ops.is_resource_variable check.""" 1571 pass 1572 1573 1574# Register a conversion function which reads the value of the variable, 1575# allowing instances of the class to be used as tensors. 1576def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False): 1577 return ops.convert_to_tensor(var.get(), dtype=dtype, name=name, as_ref=as_ref) 1578 1579 1580ops.register_tensor_conversion_function(AggregatingVariable, 1581 _tensor_conversion_aggregate) 1582ops.register_dense_tensor_like_type(AggregatingVariable) 1583