# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import saver
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest


def _get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = distribution_strategy_context.get_replica_context()
  if replica_context:
    replica_id = replica_context.replica_id_in_sync_group
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id

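# Descriptive note (illustrative, not exhaustive): inside a replica function
# the helper above resolves to that replica's integer id; inside a
# `tf.distribute.Strategy.update()` call it falls back to
# `distribute_lib.get_update_replica_id()`; anywhere else it returns `None`,
# which the classes below treat as "cross-replica context".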

class DistributedValues(object):
  """Holds a map from replica to values. Either PerReplica or Mirrored."""

  def __init__(self, values):
    self._values = tuple(values)

  def get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = _get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._values[replica_id]

  def _get_cross_replica(self):
    raise NotImplementedError(
        "This method should be overridden by sub-classes which support cross-"
        "replica accesses.")

  def _get_closest(self):
    """Returns value in same replica or device if possible, else the primary."""
    replica_id = _get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:
        if device_util.canonicalize(value.device) == current_device:
          return value
      return self.primary
    else:
      return self._values[replica_id]

  @property
  def primary(self):
    """Returns a representative component."""
    return self._values[0]

  # TODO(josh11b): Replace experimental_local_results with this?
  @property
  def values(self):
    return self._values

  @property
  def devices(self):
    return tuple(v.device for v in self._values)

  @property
  def is_tensor_like(self):
    return all(tensor_util.is_tensor(v) for v in self._values)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)

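# Example (illustrative sketch, assuming two replica values): a
# `DistributedValues` subclass such as `PerReplica` (defined further below)
# indexes its components by replica id inside a replica context and exposes
# them via `values`/`devices` elsewhere.
#
#   per_replica = PerReplica([tf.constant(1.), tf.constant(2.)])
#   per_replica.values          # (tf.Tensor(1.0), tf.Tensor(2.0))
#   per_replica.is_tensor_like  # True
#   per_replica.get()           # inside replica 1: tf.Tensor(2.0)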

# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
# initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# it will only use those methods if they are defined in the class, not the
# object.
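# For example (illustrative sketch of the Python behavior described above):
#
#   class Box(object):
#     def __getattr__(self, name):
#       return getattr(self._wrapped, name)   # forwards __add__ lookups...
#
#   Box() + 1   # ...but this still raises TypeError: the interpreter looks
#               # up __add__ on type(Box()) directly, bypassing __getattr__,
#               # which is why the operators below are spelled out explicitly.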
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # '_use_resource_variables' and the attrs that start with '_self_' are
    # used for restoring the saved_model proto, and '_attribute_sentinel' is
    # used for Layer tracking. At the point these attrs are queried, the
    # variable has not been initialized. Thus it should not query those of
    # the underlying components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # TODO(priyag): This needs to be made robust against pitfalls from mixed
    # use of __getattr__ and @property. See b/120402273.
    return getattr(self.get(), name)

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self.get()

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  def __mul__(self, o):
    return self._get_as_operand() * o

  def __rmul__(self, o):
    return o * self._get_as_operand()

  def __truediv__(self, o):
    return self._get_as_operand() / o

  def __rtruediv__(self, o):
    return o / self._get_as_operand()

  def __floordiv__(self, o):
    return self._get_as_operand() // o

  def __rfloordiv__(self, o):
    return o // self._get_as_operand()

  def __mod__(self, o):
    return self._get_as_operand() % o

  def __rmod__(self, o):
    return o % self._get_as_operand()

  def __lt__(self, o):
    return self._get_as_operand() < o

  def __le__(self, o):
    return self._get_as_operand() <= o

  def __gt__(self, o):
    return self._get_as_operand() > o

  def __ge__(self, o):
    return self._get_as_operand() >= o

  def __and__(self, o):
    return self._get_as_operand() & o

  def __rand__(self, o):
    return o & self._get_as_operand()

  def __or__(self, o):
    return self._get_as_operand() | o

  def __ror__(self, o):
    return o | self._get_as_operand()

  def __xor__(self, o):
    return self._get_as_operand() ^ o

  def __rxor__(self, o):
    return o ^ self._get_as_operand()

  def __getitem__(self, o):
    return self._get_as_operand()[o]

  def __pow__(self, o, modulo=None):
    return pow(self._get_as_operand(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self._get_as_operand())

  def __invert__(self):
    return ~self._get_as_operand()

  def __neg__(self):
    return -self._get_as_operand()

  def __abs__(self):
    return abs(self._get_as_operand())

  def __div__(self, o):
    try:
      return self._get_as_operand().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._get_as_operand().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._get_as_operand().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._get_as_operand().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  # TODO(josh11b): Even more operator overloads.


class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))


class PerReplicaSpec(type_spec.TypeSpec):
  """Type specification for a `PerReplica`."""

  __slots__ = ["_value_specs"]

  value_type = property(lambda self: PerReplica)

  def __init__(self, *value_specs):
    self._value_specs = tuple(value_specs)

  def _serialize(self):
    return self._value_specs

  @property
  def _component_specs(self):
    return self._value_specs

  def _to_components(self, value):
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is not None and replica_context.num_replicas_in_sync > 1:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "
          "context.")
    return value._values  # pylint: disable=protected-access

  def _from_components(self, tensor_list):
    return PerReplica(tensor_list)

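# Example (illustrative sketch): because `PerReplica` is a `CompositeTensor`,
# `tf.nest` and `tf.function` can decompose it into its per-replica
# components via `PerReplicaSpec` outside a replica context.
#
#   per_replica = PerReplica([tf.constant(1.), tf.constant(2.)])
#   spec = per_replica._type_spec                   # PerReplicaSpec(TensorSpec, TensorSpec)
#   components = spec._to_components(per_replica)   # (tf.Tensor(1.0), tf.Tensor(2.0))
#   spec._from_components(components)               # a new PerReplica with the same values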

# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
# TODO(tomhennigan) Should this extend CompositeTensor?
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    return self._get_closest()

  def _as_graph_element(self):
    obj = self.get()
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj


def _assign_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign(tensor)


def _assign_add_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_add(tensor)


def _assign_sub_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_sub(tensor)


def _assert_strategy(strategy):
  if not distribution_strategy_context.has_strategy():
    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
                       (strategy,))
  current_strategy = distribution_strategy_context.get_strategy()
  if current_strategy is not strategy:
    raise RuntimeError(
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (current_strategy, strategy))


@contextlib.contextmanager
def _enter_or_assert_strategy(strategy):
  if not distribution_strategy_context.has_strategy():
    with strategy.scope():
      yield
  else:
    _assert_strategy(strategy)
    yield

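# Example (illustrative sketch): the variable classes below wrap their method
# bodies in this context manager so that calls made outside any strategy
# scope still run under the variable's own strategy, while calls made under a
# *different* strategy fail fast.
#
#   with _enter_or_assert_strategy(self._distribute_strategy):
#     ...  # ops created here are guaranteed to see the right strategy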

DistributedVarOp = collections.namedtuple(
    "DistributedVarOp", ["name", "graph", "traceback", "type"])


class DistributedVariable(DistributedDelegate, variables_lib.Variable):
  """Holds a map from replica to variables."""

  # TODO(josh11b): Support changing the set of variables, e.g. if new
  # devices are joining or a device is to leave.

  def __init__(self, strategy, values):
    self._distribute_strategy = strategy
    super(DistributedVariable, self).__init__(values)
    self._common_name = self.primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the component variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on if all the
      component variables are initialized.
    """
    result = self.primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
    # to the `is_initialized` op. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(
        result, self._values[-1].is_initialized(), name=name)
    return result

  @property
  def initializer(self):
    if self._initializer_op:
      init_op = self._initializer_op
    else:
      # return grouped ops of all the var initializations of component values of
      # the mirrored variable
      init_op = control_flow_ops.group(
          tuple(v.initializer for v in self._values))
    return init_op

  def initialized_value(self):
    return self._get_closest().initialized_value()

  @property
  def initial_value(self):
    return self._get_closest().initial_value

  @property
  def graph(self):
    return self.primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self.primary._unique_id  # pylint: disable=protected-access

  @property
  def _graph_key(self):
    """Lets Optimizers know which graph this variable is from."""
    return self.primary._graph_key  # pylint: disable=protected-access

  @property
  def name(self):
    return self.primary.name

  @property
  def dtype(self):
    return self.primary.dtype

  @property
  def shape(self):
    return self.primary.shape

  @property
  def synchronization(self):
    return self.primary.synchronization

  @property
  def handle(self):
    replica_id = _get_current_replica_id_as_int()
    if replica_id is None:
      raise ValueError("`handle` is not available outside the replica context"
                       " or a `tf.distribute.Strategy.update()` call.")
    else:
      return self._values[replica_id].handle

  def eval(self, session=None):
    return self._get_closest().eval(session)

  @property
  def _save_slice_info(self):
    return self.primary._save_slice_info  # pylint: disable=protected-access

  def _get_save_slice_info(self):
    return self.primary._get_save_slice_info()  # pylint: disable=protected-access

  def _set_save_slice_info(self, save_slice_info):
    for v in self._values:
      v._set_save_slice_info(save_slice_info)  # pylint: disable=protected-access

  @property
  def device(self):
    return self._get_closest().device

  @property
  def trainable(self):
    return self.primary.trainable

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def get_shape(self):
    return self.primary.get_shape()

  def to_proto(self, export_scope=None):
    return self.primary.to_proto(export_scope=export_scope)

  @property
  def op(self):
    # We want cross-replica code that does some var.op.X calls
    # to work (even if the current device isn't in self.devices), but
    # other uses of var.op in a cross-replica context to fail.
    if distribution_strategy_context.in_cross_replica_context():
      return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
                              self.primary.op.traceback, self.primary.op.type)
    return self.get().op

  @property
  def _in_graph_mode(self):
    return self.primary._in_graph_mode  # pylint: disable=protected-access

  def read_value(self):
    with _enter_or_assert_strategy(self._distribute_strategy):
      return array_ops.identity(self.get())

  def value(self):
    return self._get_closest().value()

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass


ops.register_dense_tensor_like_type(DistributedVariable)


@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph()):
    yield
  else:
    with tensor.graph.as_default():
      yield


def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    with _maybe_enter_graph(var.handle):
      op = raw_assign_fn(
          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)

      with ops.control_dependencies([op]):
        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access

  return assign_fn

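# Example (illustrative sketch): the raw-op wrapper above is how the TPU
# variable classes below build their assign methods, e.g.
#
#   assign_add_fn = _make_raw_assign_fn(
#       gen_resource_variable_ops.assign_add_variable_op)
#   assign_add_fn(var, 1.0)                     # returns the updated value
#   assign_add_fn(var, 1.0, read_value=False)   # returns just the assign op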

class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
    # correctly since in eager mode different variables can have the same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self.primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if _enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          "'{}' not accessible within a TPU context.".format(name))

  def get(self):
    if _enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  def _get_closest(self):
    if _enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._get_closest()
    else:
      return self.primary

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")

  def _is_mirrored(self):
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = _enclosing_tpu_context()
    if tpu_context is None:
      return self._get_closest().handle
    else:
      return tpu_context.get_replicated_var_handle(
          self._handle_id, self._values, self._is_mirrored())

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    if self.trainable:
      tape.variable_accessed(self)
    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)

  def read_value(self):
    if _enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  @property
  def constraint(self):
    return self.primary.constraint

  def _as_graph_element(self):
    if _enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    return DistributedVarOp(self.primary.op.name, self.primary.op.graph,
                            self.primary.op.traceback, self.primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def _apply_aggregation(strategy, value, aggregation, destinations):
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    return strategy.extended.broadcast_to(
        strategy.experimental_local_results(value)[0],
        destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)

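# Example (illustrative sketch): for a variable created with
# `aggregation=tf.VariableAggregation.MEAN`, the helper above turns a
# per-replica update value into a single reduced value before it is applied,
# roughly
#
#   reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
#       vs.VariableAggregation.MEAN)        # ReduceOp.MEAN
#   reduced = strategy.extended.reduce_to(reduce_op, per_replica_value, var)
#
# while ONLY_FIRST_REPLICA simply broadcasts replica 0's value instead.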

_aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..), "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`. "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")

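# Example (illustrative sketch of the pattern the message above describes,
# assuming `mirrored_var` was created with aggregation=NONE):
#
#   def merge_fn(strategy, per_replica_value):
#     return strategy.extended.update(
#         mirrored_var, lambda var, value: var.assign(value),
#         args=(per_replica_value,))
#
#   # Inside a replica fn:
#   tf.distribute.get_replica_context().merge_call(merge_fn, args=(value,))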

class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    super(_MirroredSaveable, self).__init__(primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return control_flow_ops.group(
        tuple(
            _assign_on_device(v.device, v, tensor)
            for v in self._mirrored_variable.values))


def create_mirrored_variable(  # pylint: disable=missing-docstring
    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.ON_WRITE)

  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with `Mirrored` "
        "distribution strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  elif synchronization == vs.VariableSynchronization.ON_READ:
    is_sync_on_read = True
  elif synchronization in (vs.VariableSynchronization.ON_WRITE,
                           vs.VariableSynchronization.AUTO):
    # `AUTO` synchronization defaults to `ON_WRITE`.
    is_sync_on_read = False
  else:
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))

  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
    result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result


class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def __init__(self, strategy, values, aggregation):
    super(MirroredVariable, self).__init__(strategy, values)
    self._aggregation = aggregation

  # The arguments to update() are automatically unwrapped so the update()
  # function would normally see regular variables, not MirroredVariables.
  # However, the update function can still operate on wrapped MirroredVariables
  # through object members, captured arguments, etc. This is more likely in an
  # update_non_slot() function (like OptimizerV2._finish), which can
  # update several non-slot variables in one call.
  def _assign_func(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        update_replica_id = distribute_lib.get_update_replica_id()
        if update_replica_id is not None:
          # We are calling an assign function on the mirrored variable in an
          # update context.
          return f(self.values[update_replica_id], *args, **kwargs)

        # We are calling assign on the mirrored variable in cross replica
        # context, use `strategy.extended.update()` to update the variable.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        _assert_replica_context(self._distribute_strategy)
        # We are calling an assign function on the mirrored variable in replica
        # context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function on each of the mirrored variables with the
        # reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              _aggregation_error_msg.format(variable_type="MirroredVariable"))

        def merge_fn(strategy, value, *other_args, **other_kwargs):  # pylint: disable=missing-docstring
          # Don't allow MEAN with non float dtype, since it may cause unexpected
          # precision loss. Python3 and NumPy automatically upcast integers to
          # float in division, but we should always preserve the type.
          #
          # Note that to be backward compatible we allow the case when the value
          # is *always* the same on each replica. I.E. value is not a
          # PerReplica. Refer to regroup() to see how values are grouped.
          if self._aggregation == vs.VariableAggregation.MEAN and (
              not self.dtype.is_floating) and isinstance(value, PerReplica):
            raise ValueError(
                "Cannot update non-float variables with "
                "tf.VariableAggregation.MEAN aggregation in replica context. "
                "Either change the variable dtype to float or update it in "
                "cross-replica context.")

          v = _apply_aggregation(strategy, value, self._aggregation, self)
          return strategy.extended.update(
              self, f, args=(v,) + other_args, kwargs=other_kwargs)

        return distribution_strategy_context.get_replica_context().merge_call(
            merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

  def _as_graph_element(self):
    return self._get_closest()._as_graph_element()  # pylint: disable=protected-access

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self.primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # Try to avoid assignments to and other mutations of MirroredVariable
    # state except through a DistributionStrategy.extended.update() call.
    assert not as_ref
    return ops.convert_to_tensor(
        self.get(), dtype=dtype, name=name, as_ref=as_ref)

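# Example (illustrative sketch, assuming a two-replica mirrored strategy):
# assigning a MirroredVariable in cross-replica context goes through
# `strategy.extended.update()` and updates every component, while assigning
# it inside a replica function requires an `aggregation` so the per-replica
# values can be reduced first.
#
#   with strategy.scope():
#     v = tf.Variable(0., aggregation=tf.VariableAggregation.MEAN)
#   v.assign(1.)   # cross-replica: every component becomes 1.0
#   # In a replica fn run by the strategy:
#   #   v.assign_add(replica_local_delta)  # deltas are MEAN-reduced, then applied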

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(MirroredVariable,
                                        _tensor_conversion_mirrored)


def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
  return ops.convert_to_tensor(
      value.get(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(Mirrored,
                                        _tensor_conversion_mirrored_val)


def _enclosing_tpu_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
        return context_
      context_ = context_.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None


def is_distributed_variable(v):
  """Determines whether `v` is a distributed (e.g. TPU mirrored) variable."""
  return isinstance(v, DistributedVariable)


class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def _assign_func(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      if (distribution_strategy_context.in_cross_replica_context() and
          (_enclosing_tpu_context() is not None)):
        f = kwargs.pop("f")
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        return MirroredVariable._assign_func(self, *args, **kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = _make_raw_assign_fn(
        gen_resource_variable_ops.assign_sub_variable_op)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = _make_raw_assign_fn(
        gen_resource_variable_ops.assign_add_variable_op)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = _make_raw_assign_fn(
        gen_resource_variable_ops.assign_variable_op)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  def _is_mirrored(self):
    return True


class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable."""

  def __init__(self, sync_on_read_variable, name):
    self._sync_on_read_variable = sync_on_read_variable

    # We use a callable so that we don't have to evaluate this expression
    # in the case where we are trying to restore instead of save.
    def tensor():
      strategy = sync_on_read_variable._distribute_strategy  # pylint: disable=protected-access
      return strategy.extended.read_var(sync_on_read_variable)

    spec = saver.BaseSaverBuilder.SaveSpec(
        tensor=tensor,
        slice_spec="",
        name=name,
        dtype=sync_on_read_variable.dtype,
        device=sync_on_read_variable.primary.device)
    super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    # To preserve the sum across save and restore, we have to divide the
    # total across all devices when restoring a variable that was summed
    # when saving.
    tensor, = restored_tensors
    if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM:
      tensor = math_ops.cast(tensor / len(self._sync_on_read_variable.devices),
                             self._sync_on_read_variable.dtype)
    return control_flow_ops.group(
        tuple(
            _assign_on_device(v.device, v, tensor)
            for v in self._sync_on_read_variable.values))


def _assert_replica_context(strategy):
  replica_context = distribution_strategy_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")


class SyncOnReadVariable(DistributedVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def __init__(self, strategy, values, aggregation):
    super(SyncOnReadVariable, self).__init__(strategy, values)
    self._aggregation = aggregation

  def assign_sub(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      if distribution_strategy_context.in_cross_replica_context():
        if self._aggregation == vs.VariableAggregation.SUM:
          raise ValueError(
              "SyncOnReadVariable does not support `assign_sub` in "
              "cross-replica context when aggregation is set to "
              "`tf.VariableAggregation.SUM`.")
        return control_flow_ops.group(
            tuple(
                _assign_sub_on_device(v.device, v, args[0])
                for v in self._values))
      else:
        return self.get().assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      if distribution_strategy_context.in_cross_replica_context():
        if self._aggregation == vs.VariableAggregation.SUM:
          raise ValueError(
              "SyncOnReadVariable does not support `assign_add` in "
              "cross-replica context when aggregation is set to "
              "`tf.VariableAggregation.SUM`.")
        return control_flow_ops.group(
            tuple(
                _assign_add_on_device(v.device, v, args[0])
                for v in self._values))
      else:
        return self.get().assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      if distribution_strategy_context.in_cross_replica_context():
        # To preserve the sum across save and restore, we have to divide the
        # total across all devices when restoring a variable that was summed
        # when saving.
        tensor = args[0]
        if self._aggregation == vs.VariableAggregation.SUM:
          tensor = math_ops.cast(tensor / len(self._values), self.dtype)
        return control_flow_ops.group(
            tuple(_assign_on_device(v.device, v, tensor) for v in self._values))
      else:
        return self.get().assign(*args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  def _get_cross_replica(self):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return self.primary

    with _enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self.aggregation),
          self,
          axis=None)

  def _as_graph_element(self):
    # pylint: disable=protected-access
    if distribution_strategy_context.in_cross_replica_context():
      return self._get_cross_replica()
    return self.get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _SyncOnReadSaveable(self, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    return ops.convert_to_tensor(
        self.get(), dtype=dtype, name=name, as_ref=as_ref)


# Register a conversion function for SyncOnReadVariable which allows as_ref to
# be true.
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(SyncOnReadVariable,
                                        _tensor_conversion_sync_on_read)


class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if _enclosing_tpu_context() is None:
      return SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if _enclosing_tpu_context() is None:
      return SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if _enclosing_tpu_context() is None:
      return SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self, *args, **kwargs)

  def _is_mirrored(self):
    return False


def regroup(values, wrap_class=PerReplica):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(type(v0), "_make")
      return type(v0)._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, dict):
    v0keys = set(v0.keys())
    for v in values[1:]:
      assert isinstance(v, dict), ("v[0]: %r  v[i]: %r" % (v0, v))
      assert set(v.keys()) == v0keys, ("v[0].keys: %s  v[i].keys: %s" %
                                       (v0keys, set(v.keys())))
    return type(v0)(**{
        key: regroup(tuple(v[key] for v in values), wrap_class)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0.
  if same_id and (isinstance(v0, DistributedVariable) or
                  not hasattr(v0, "_distributed_container")):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container()
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(values)

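# Example (illustrative sketch): `regroup` turns a per-replica list of
# structures into a structure of wrapped per-replica values, and
# `select_replica` below inverts it for a single replica.
#
#   per_replica_results = [{"loss": 1.0}, {"loss": 2.0}]  # one dict per replica
#   grouped = regroup(per_replica_results)                # {"loss": PerReplica((1.0, 2.0))}
#   select_replica(1, grouped)                            # {"loss": 2.0}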

def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""

  def _get(x):
    # `DistributedValues` would be sliced according to replica unless it is a
    # `DistributedVariable` because `DistributedVariable` can be handled
    # directly in the replica context.
    if (isinstance(x, DistributedVariable) or
        not isinstance(x, DistributedValues)):
      return x
    else:
      return x.values[replica_id]

  return nest.map_structure(_get, structured)


def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica."""

  def _get_mirrored(x):
    if isinstance(x, DistributedValues):
      if not isinstance(x, Mirrored):
        raise TypeError(
            "Expected value to be mirrored across replicas: %s in %s." %
            (x, structured))
      return x.values[replica_id]
    else:
      return x

  return nest.map_structure(_get_mirrored, structured)


def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if not all(tensor_util.is_tensor(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)


def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, DistributedVariable)):
    container = val._distributed_container()  # pylint: disable=protected-access
    if container is not None:
      return container
  return val


class AggregatingVariable(variables_lib.Variable):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = distribution_strategy_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              _aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy, value, *other_args, **other_kwargs):
          v = _apply_aggregation(strategy, value, self._aggregation, self)
          return strategy.extended.update(
              self, f, args=(v,) + other_args, kwargs=other_kwargs)

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def read_value(self):
    return self._v.read_value()

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._v + o

  def __radd__(self, o):
    return o + self._v

  def __sub__(self, o):
    return self._v - o

  def __rsub__(self, o):
    return o - self._v

  def __mul__(self, o):
    return self._v * o

  def __rmul__(self, o):
    return o * self._v

  def __truediv__(self, o):
    return self._v / o

  def __rtruediv__(self, o):
    return o / self._v

  def __floordiv__(self, o):
    return self._v // o

  def __rfloordiv__(self, o):
    return o // self._v

  def __mod__(self, o):
    return self._v % o

  def __rmod__(self, o):
    return o % self._v

  def __lt__(self, o):
    return self._v < o

  def __le__(self, o):
    return self._v <= o

  def __gt__(self, o):
    return self._v > o

  def __ge__(self, o):
    return self._v >= o

  def __and__(self, o):
    return self._v & o

  def __rand__(self, o):
    return o & self._v

  def __or__(self, o):
    return self._v | o

  def __ror__(self, o):
    return o | self._v

  def __xor__(self, o):
    return self._v ^ o

  def __rxor__(self, o):
    return o ^ self._v

  def __getitem__(self, o):
    return self._v[o]

  def __pow__(self, o, modulo=None):
    return pow(self._v, o, modulo)

  def __rpow__(self, o):
    return pow(o, self._v)

  def __invert__(self):
    return ~self._v

  def __neg__(self):
    return -self._v

  def __abs__(self):
    return abs(self._v)

  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  return ops.convert_to_tensor(var.get(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(AggregatingVariable,
                                        _tensor_conversion_aggregate)
ops.register_dense_tensor_like_type(AggregatingVariable)

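# Example (illustrative sketch, with hypothetical `strategy` and
# `component_var` arguments): with the conversion functions registered above,
# wrapped variables can be passed directly to ops that expect tensors.
#
#   wrapped = AggregatingVariable(strategy, component_var,
#                                 vs.VariableAggregation.MEAN)
#   tf.add(wrapped, 1.0)   # convert_to_tensor reads wrapped.get() internally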