• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed values."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import weakref
23
24from tensorflow.python.distribute import device_util
25from tensorflow.python.distribute import distribute_lib
26from tensorflow.python.distribute import distribution_strategy_context as ds_context
27from tensorflow.python.distribute import packed_distributed_variable as packed
28from tensorflow.python.distribute import reduce_util
29from tensorflow.python.distribute import values_util
30from tensorflow.python.eager import context
31from tensorflow.python.framework import composite_tensor
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.framework import type_spec
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import resource_variable_ops
39from tensorflow.python.ops import variable_scope as vs
40from tensorflow.python.ops import variables as variables_lib
41from tensorflow.python.saved_model import save_context
42from tensorflow.python.training.saving import saveable_object
43from tensorflow.python.training.tracking import base as trackable
44from tensorflow.python.types import core
45from tensorflow.python.util.tf_export import tf_export
46
47
def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates a variable with ON_WRITE synchronization from a replica context.

  Variables whose aggregation is NONE are updated directly on the component
  returned by `var._get_on_device_or_primary()`. Otherwise the update is
  aggregated first:

  * when the strategy uses merge_call, the update is routed through
    `merge_call` and applied with `var._update_cross_replica`;
  * otherwise `value` is aggregated in replica context with
    `apply_aggregation_replica_context` and applied through the replica
    context's `_update(..., group=True)`.

  Args:
    var: The distributed variable being updated.
    update_fn: Callable invoked as `update_fn(component, value, **kwargs)`.
    value: The update value supplied by this replica.
    **kwargs: Extra keyword arguments forwarded to the update.

  Returns:
    The result of the selected update path.

  Raises:
    ValueError: If `var` has MEAN aggregation, a non-floating dtype, and the
      update value is a tensor (replica-context path) or a `PerReplica`
      (merge_call path).
  """
  if var.aggregation == vs.VariableAggregation.NONE:
    local_var = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return update_fn(local_var, value, **kwargs)

  if ds_context.get_strategy().extended._use_merge_call():  # pylint: disable=protected-access

    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
      # MEAN over a non-float dtype is rejected: Python3 and NumPy upcast
      # integers to float in division, which could silently lose precision,
      # and the variable's type must always be preserved.
      #
      # For backward compatibility a value that is *always* the same on each
      # replica (i.e. not a PerReplica; see regroup()) is still accepted.
      non_float_mean = (var.aggregation == vs.VariableAggregation.MEAN and
                        not var.dtype.is_floating)
      if non_float_mean and isinstance(value, PerReplica):
        raise ValueError(
            "Cannot update non-float variables with "
            "tf.VariableAggregation.MEAN aggregation in replica context. "
            "Either change the variable dtype to float or update it in "
            "cross-replica context.")

      assert strategy == var.distribute_strategy
      aggregated = values_util.apply_aggregation(strategy, value,
                                                 var.aggregation, var)
      return var._update_cross_replica(update_fn, aggregated, **kwargs)  # pylint: disable=protected-access

    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)

  # No merge_call: aggregate directly in replica context. Same dtype rule as
  # above (no MEAN on non-float variables when the value is a tensor).
  non_float_mean = (var.aggregation == vs.VariableAggregation.MEAN and
                    not var.dtype.is_floating)
  if non_float_mean and tensor_util.is_tf_type(value):
    raise ValueError(
        "Cannot update non-float variables with "
        "tf.VariableAggregation.MEAN aggregation in replica context. "
        "Either change the variable dtype to float or update it in "
        "cross-replica context.")

  aggregated_value = apply_aggregation_replica_context(
      value, var.aggregation, var)
  values_util.mark_as_unsaveable()

  replica_ctx = ds_context.get_replica_context()
  return replica_ctx._update(  # pylint: disable=protected-access
      var,
      update_fn,
      args=(aggregated_value,),
      kwargs=kwargs,
      group=True)
101
102
def apply_aggregation_replica_context(value, aggregation, destinations):
  """Aggregate `value` to `destinations` as specified by `aggregation`.

  Args:
    value: The value produced by the current replica.
    aggregation: A `tf.VariableAggregation` selecting how to combine values.
    destinations: Where a broadcast (ONLY_FIRST_REPLICA) result is sent.

  Returns:
    `value` unchanged when it is not a TF tensor type. Otherwise, for
    ONLY_FIRST_REPLICA, the first local result broadcast to `destinations`
    via `merge_call`. For any other aggregation, the replica-context
    all-reduce of `value` with the matching `ReduceOp`.

  Raises:
    TypeError: If `value` is a `DistributedValues`.
  """
  if isinstance(value, DistributedValues):
    raise TypeError(
        "Cannot use DistributedValues to update variables in replica context.")

  # Python literals (anything that is not a TF tensor type) are returned
  # without aggregation.
  if not tensor_util.is_tf_type(value):
    return value

  if aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA:
    reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
    extended = ds_context.get_strategy().extended
    return extended._replica_ctx_all_reduce(reduce_op, value)  # pylint: disable=protected-access

  # ONLY_FIRST_REPLICA: switch to cross-replica context to broadcast the
  # first replica's value.
  def merge_fn(strategy, value):
    broadcast_to = strategy.extended.broadcast_to
    first_value = strategy.experimental_local_results(value)[0]
    return broadcast_to(first_value, destinations=destinations)

  return ds_context.get_replica_context().merge_call(merge_fn, args=(value,))
126
127
@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values.

  A subclass instance of `tf.distribute.DistributedValues` is created when
  creating variables within a distribution strategy, iterating a
  `tf.distribute.DistributedDataset` or through `tf.distribute.Strategy.run`.
  This base class should never be instantiated directly.
  `tf.distribute.DistributedValues` contains a value per replica. Depending on
  the subclass, the values could either be synced on update, synced on demand,
  or never synced.

  `tf.distribute.DistributedValues` can be reduced to obtain a single value
  across replicas, passed as input into `tf.distribute.Strategy.run`, or have
  its per-replica values inspected using
  `tf.distribute.Strategy.experimental_local_results`.

  Example usage:

  1. Created from a `tf.distribute.DistributedDataset`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)

  2. Returned by `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> @tf.function
  ... def run():
  ...   ctx = tf.distribute.get_replica_context()
  ...   return ctx.replica_id_in_sync_group
  >>> distributed_values = strategy.run(run)

  3. As input into `run`:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> @tf.function
  ... def run(input):
  ...   return input + 1.0
  >>> updated_value = strategy.run(run, args=(distributed_values,))

  4. Reduce value:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
  ...                                 distributed_values,
  ...                                 axis=0)

  5. Inspect local replica values:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> per_replica_values = strategy.experimental_local_results(
  ...    distributed_values)
  >>> per_replica_values
  (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)

  """

  def __init__(self, values):
    """Should only be called by subclass __init__.

    Args:
      values: Iterable of per-replica values; stored as a tuple in
        `self._values`, indexed by replica id.
    """
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current replica.

    When there is no current replica id, this defers to
    `_get_cross_replica()`. In this base class that method raises
    `NotImplementedError`.
    """
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._values[replica_id]

  def _get_cross_replica(self):
    """Returns the value to use outside a replica; subclasses override this."""
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:
        if device_util.canonicalize(value.device) == current_device:
          return value
      return self._primary
    else:
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component (the first value)."""
    return self._values[0]

  @property
  def _devices(self):
    """Tuple of the `.device` of every per-replica value, in replica order."""
    return tuple(v.device for v in self._values)

  def __str__(self):
    """Lists each replica's value (via `%s`) under the class name."""
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    """Lists each replica's value (via `%r`) under the class name."""
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
244
245
# NOTE(josh11b,apassos): It would be great if we could inspect the values this
# was initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# Python only uses them when they are defined on the class, not the instance.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values.

  Attribute lookups that are not found on the delegate itself are forwarded
  to the value returned by `_get()`. The operator overloads below are applied
  to the value returned by `_get_as_operand()`.
  """

  def __getattr__(self, name):
    # '_use_resource_variables' and the attributes prefixed with '_self_' are
    # queried while restoring the saved_model proto, and '_attribute_sentinel'
    # is queried by Layer tracking. At that point the variable has not been
    # initialized, so these (and '_distributed_container') must not be looked
    # up on the underlying components.
    reserved = ("_use_resource_variables", "_attribute_sentinel",
                "_distributed_container")
    if name in reserved or name.startswith("_self_"):
      return super(DistributedDelegate, self).__getattr__(name)

    # copy.copy() builds an empty instance without running __init__ and then
    # probes it for attributes such as "__getstate__". Such a probe reaches
    # this method, and resolving `self._values` on the empty instance would
    # re-enter __getattr__ forever. Fail fast instead.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    delegate = self._get()
    return getattr(delegate, name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    operand = self._get()
    return operand

  # Operator overloads: each resolves the operand for the current
  # replica/device and applies the corresponding Python operator to it.

  def __add__(self, o):
    lhs = self._get_as_operand()
    return lhs + o

  def __radd__(self, o):
    rhs = self._get_as_operand()
    return o + rhs

  def __sub__(self, o):
    lhs = self._get_as_operand()
    return lhs - o

  def __rsub__(self, o):
    rhs = self._get_as_operand()
    return o - rhs

  def __mul__(self, o):
    lhs = self._get_as_operand()
    return lhs * o

  def __rmul__(self, o):
    rhs = self._get_as_operand()
    return o * rhs

  def __truediv__(self, o):
    lhs = self._get_as_operand()
    return lhs / o

  def __rtruediv__(self, o):
    rhs = self._get_as_operand()
    return o / rhs

  def __floordiv__(self, o):
    lhs = self._get_as_operand()
    return lhs // o

  def __rfloordiv__(self, o):
    rhs = self._get_as_operand()
    return o // rhs

  def __mod__(self, o):
    lhs = self._get_as_operand()
    return lhs % o

  def __rmod__(self, o):
    rhs = self._get_as_operand()
    return o % rhs

  def __lt__(self, o):
    lhs = self._get_as_operand()
    return lhs < o

  def __le__(self, o):
    lhs = self._get_as_operand()
    return lhs <= o

  def __gt__(self, o):
    lhs = self._get_as_operand()
    return lhs > o

  def __ge__(self, o):
    lhs = self._get_as_operand()
    return lhs >= o

  def __and__(self, o):
    lhs = self._get_as_operand()
    return lhs & o

  def __rand__(self, o):
    rhs = self._get_as_operand()
    return o & rhs

  def __or__(self, o):
    lhs = self._get_as_operand()
    return lhs | o

  def __ror__(self, o):
    rhs = self._get_as_operand()
    return o | rhs

  def __xor__(self, o):
    lhs = self._get_as_operand()
    return lhs ^ o

  def __rxor__(self, o):
    rhs = self._get_as_operand()
    return o ^ rhs

  def __getitem__(self, o):
    container = self._get_as_operand()
    return container[o]

  def __pow__(self, o, modulo=None):
    base = self._get_as_operand()
    return pow(base, o, modulo)

  def __rpow__(self, o):
    exponent = self._get_as_operand()
    return pow(o, exponent)

  def __invert__(self):
    operand = self._get_as_operand()
    return ~operand

  def __neg__(self):
    operand = self._get_as_operand()
    return -operand

  def __abs__(self):
    operand = self._get_as_operand()
    return abs(operand)

  # The operators below fall back to NotImplemented (see
  # https://docs.python.org/3/library/constants.html#NotImplemented) when the
  # operand does not provide the corresponding special method.

  def __div__(self, o):
    try:
      operand = self._get_as_operand()
      result = operand.__div__(o)
    except AttributeError:
      return NotImplemented
    return result

  def __rdiv__(self, o):
    try:
      operand = self._get_as_operand()
      result = operand.__rdiv__(o)
    except AttributeError:
      return NotImplemented
    return result

  def __matmul__(self, o):
    try:
      operand = self._get_as_operand()
      result = operand.__matmul__(o)
    except AttributeError:
      return NotImplemented
    return result

  def __rmatmul__(self, o):
    try:
      operand = self._get_as_operand()
      result = operand.__rmatmul__(o)
    except AttributeError:
      return NotImplemented
    return result

  # TODO(josh11b): Even more operator overloads.
412
413
class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values.

  As a `CompositeTensor`, its type spec is a `PerReplicaSpec` built from the
  type spec of each replica value.
  """

  @property
  def _type_spec(self):
    component_specs = [type_spec.type_spec_from_value(v) for v in self._values]
    return PerReplicaSpec(*component_specs)

  @property
  def values(self):
    """Returns the per replica values."""
    per_replica_values = self._values
    return per_replica_values
426
427
class PerReplicaSpec(type_spec.TypeSpec):
  """Type specification for a `PerReplica`.

  The spec is the tuple of per-replica component specs. That tuple is used
  both as the serialization and as the component specs.
  """

  __slots__ = ["_value_specs"]

  @property
  def value_type(self):
    return PerReplica

  def __init__(self, *value_specs):
    self._value_specs = tuple(value_specs)

  def _serialize(self):
    return self._value_specs

  @property
  def _component_specs(self):
    return self._value_specs

  def _to_components(self, value):
    # Flattening is refused inside a replica context that has more than one
    # replica in sync.
    ctx = ds_context.get_replica_context()
    multi_replica = ctx is not None and ctx.num_replicas_in_sync > 1
    if multi_replica:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "
          "context.")
    return value._values  # pylint: disable=protected-access

  def _from_components(self, tensor_list):
    rebuilt = PerReplica(tensor_list)
    return rebuilt
455
456
# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
# TODO(tomhennigan) Should this extend CompositeTensor?
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    """Outside a replica, returns the local-device value or the primary."""
    return self._get_on_device_or_primary()

  def _as_graph_element(self):
    """Returns the graph element for the current value.

    If the current value has a truthy, callable `_as_graph_element`, that
    method's result is returned. Otherwise the value itself is returned.
    """
    current = self._get()
    to_graph_element = getattr(current, "_as_graph_element", None)
    if not (to_graph_element and callable(to_graph_element)):
      return current
    return to_graph_element()
472
473
class DistributedVarOp(object):
  """A class that looks like `tf.Operation`.

  Holds `name`, `graph`, `traceback` and `type` so that code reading those
  attributes off a distributed variable's op keeps working. Instances compare
  equal when all four fields are equal.
  """

  def __init__(self, name, graph, traceback, typ):
    self.name = name
    self.graph = graph
    self.traceback = traceback
    self.type = typ

  def __eq__(self, o):
    # Return NotImplemented (not raise) for foreign types. Python then tries
    # the reflected comparison and finally falls back to identity, so
    # `op == None`, `op != real_op` or `op in list_of_ops` give False/True
    # instead of crashing.
    if not isinstance(o, self.__class__):
      return NotImplemented
    return (self.name == o.name and self.graph == o.graph and
            self.traceback == o.traceback and self.type == o.type)

  def __hash__(self):
    # `traceback` may be a list, so it is converted to a tuple to be hashable;
    # equal instances (per __eq__) hash equally.
    return hash((self.name, self.graph, tuple(self.traceback), self.type))
491
492
493class DistributedVariable(DistributedDelegate, variables_lib.Variable,
494                          core.Tensor):
495  """Holds a map from replica to variables."""
496
  def __init__(self, strategy, values, aggregation, var_policy=None):
    """Wraps per-replica component variables into one distributed variable.

    Args:
      strategy: The distribution strategy that owns this variable; stored as
        `self._distribute_strategy`.
      values: Sequence of component variables, one per replica. `values[0]`
        becomes the primary component.
      aggregation: A `tf.VariableAggregation`, stored as `self._aggregation`.
      var_policy: Optional policy object. When set, `value()`, the `assign*`
        methods and the `scatter_*` methods delegate to it.

    Raises:
      ValueError: If `aggregation` is MEAN and `values[0].dtype` is not a
        floating-point dtype.
    """
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    # NOTE: order matters here. The base __init__ populates `self._values`,
    # which `self._primary` (used on the next line) reads.
    super(DistributedVariable, self).__init__(values)
    # Name of the primary component without its ":<output index>" suffix.
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode, and only when the
    # strategy opts in via `_enable_packed_variable_in_eager_mode`.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None

    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the components variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy
531    # when restoring from a checkpoint, we may set the _initializer_op
532    # property on the entire `DistributedVariable`.
533    self._initializer_op = None
534    # Set a VariablePolicy which decides how we replicate/aggregate the given
535    # variable.
536    self._policy = var_policy
537
538  def __deepcopy__(self, memo):
539    """Perform a deepcopy of the `DistributedVariable`.
540
541    Unlike the deepcopy of a regular tf.Variable, this keeps the original
542    strategy and devices of the `DistributedVariable`.  To avoid confusion
543    with the behavior of deepcopy on a regular `Variable` (which does
544    copy into new devices), we only allow a deepcopy of a `DistributedVariable`
545    within its originating strategy scope.
546
547    Args:
548      memo: The memoization object for `deepcopy`.
549
550    Returns:
551      A deep copy of the current `DistributedVariable`.
552
553    Raises:
554      RuntimeError: If trying to deepcopy into a different strategy.
555    """
556    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
557      new_values = []
558
559      for value in self._values:
560        with ops.device(value.device):
561          new_values.append(copy.deepcopy(value, memo))
562
563    copied_variable = type(self)(
564        strategy=self._distribute_strategy,
565        values=new_values,
566        aggregation=self._aggregation,
567        var_policy=copy.deepcopy(self._policy, memo))
568
569    memo[id(self)] = copied_variable
570
571    return copied_variable
572
573  def _use_packed_variable(self):
574    # Don't use packed variable when under a SaveContext to avoid explicit
575    # device placement on variable consuming ops.
576    return self._packed_var is not None and not save_context.in_save_context()
577
578  def is_initialized(self, name=None):
579    """Identifies if all the component variables are initialized.
580
581    Args:
582      name: Name of the final `logical_and` op.
583
584    Returns:
585      The op that evaluates to True or False depending on if all the
586      component variables are initialized.
587    """
588    if values_util.is_saving_non_distributed():
589      return self._primary.is_initialized()
590    if self._use_packed_variable():
591      return self._packed_var.is_initialized()
592    result = self._primary.is_initialized()
593    # We iterate through the list of values except the last one to allow us to
594    # name the final `logical_and` op the same name that is passed by the user
595    # to the `is_initialized` op. For distributed variables, the
596    # `is_initialized` op is a `logical_and` op.
597    for v in self._values[1:-1]:
598      result = math_ops.logical_and(result, v.is_initialized())
599    result = math_ops.logical_and(
600        result, self._values[-1].is_initialized(), name=name)
601    return result
602
603  @property
604  def initializer(self):
605    if values_util.is_saving_non_distributed():
606      return self._primary.initializer
607    if self._initializer_op:
608      init_op = self._initializer_op
609    else:
610      # return grouped ops of all the var initializations of component values of
611      # the mirrored variable
612      init_op = control_flow_ops.group(
613          tuple(v.initializer for v in self._values))
614    return init_op
615
616  def initialized_value(self):
617    return self._get_on_device_or_primary().initialized_value()
618
619  @property
620  def initial_value(self):
621    return self._get_on_device_or_primary().initial_value
622
623  @property
624  def constraint(self):
625    return self._primary.constraint
626
627  @property
628  def graph(self):
629    return self._primary.graph
630
631  @property
632  def _shared_name(self):
633    return self._common_name
634
635  @property
636  def _unique_id(self):
637    return self._primary._unique_id  # pylint: disable=protected-access
638
639  @property
640  def _graph_key(self):
641    """Lets Optimizers know which graph this variable is from."""
642    return self._primary._graph_key  # pylint: disable=protected-access
643
644  @property
645  def name(self):
646    return self._primary.name
647
648  @property
649  def dtype(self):
650    return self._primary.dtype
651
652  @property
653  def shape(self):
654    return self._primary.shape
655
656  @property
657  def synchronization(self):
658    return self._primary.synchronization
659
660  @property
661  def aggregation(self):
662    return self._aggregation
663
664  @property
665  def _packed_variable(self):
666    if self._use_packed_variable():
667      return self._packed_var
668    return None
669
670  @property
671  def handle(self):
672    if values_util.is_saving_non_distributed():
673      return self._primary.handle
674    replica_id = values_util.get_current_replica_id_as_int()
675    if replica_id is None:
676      raise ValueError(
677          "DistributedVariable.handle is not available outside the replica "
678          "context or a `tf.distribute.Strategy.update()` call.")
679    else:
680      if self._use_packed_variable():
681        return self._packed_var.handle
682      return self._values[replica_id].handle
683
684  def eval(self, session=None):
685    return self._get_on_device_or_primary().eval(session)
686
687  @property
688  def _save_slice_info(self):
689    return self._primary._save_slice_info  # pylint: disable=protected-access
690
691  def _get_save_slice_info(self):
692    return self._primary._get_save_slice_info()  # pylint: disable=protected-access
693
694  def _set_save_slice_info(self, save_slice_info):
695    for v in self._values:
696      v._set_save_slice_info(save_slice_info)  # pylint: disable=protected-access
697
698  @property
699  def device(self):
700    return self._get_on_device_or_primary().device
701
702  @property
703  def trainable(self):
704    return self._primary.trainable
705
706  @property
707  def distribute_strategy(self):
708    return self._distribute_strategy
709
710  def get_shape(self):
711    return self._primary.get_shape()
712
713  def to_proto(self, export_scope=None):
714    return self._primary.to_proto(export_scope=export_scope)
715
716  @property
717  def op(self):
718    if values_util.is_saving_non_distributed():
719      return self._primary.op
720    # We want cross-replica code that does some var.op.X calls
721    # to work (even if the current device isn't in self._devices), but
722    # other uses of var.op in a cross-replica context to fail.
723    if ds_context.in_cross_replica_context():
724      return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
725                              self._primary.op.traceback, self._primary.op.type)
726    return self._get().op
727
728  @property
729  def _in_graph_mode(self):
730    return self._primary._in_graph_mode  # pylint: disable=protected-access
731
732  def _get_replica(self, replica_id):
733    """Returns the value on a device with the given replica_id."""
734    if self._use_packed_variable():
735      return self._packed_var.on_device(self._devices[replica_id])
736    return self._values[replica_id]
737
738  def _get(self):
739    """Returns the value for the current device or raises a ValueError."""
740    if values_util.is_saving_non_distributed():
741      return self._primary
742    replica_id = values_util.get_current_replica_id_as_int()
743    if replica_id is None:
744      return self._get_cross_replica()
745    else:
746      return self._get_replica(replica_id)
747
748  def _get_on_device_or_primary(self):
749    """Returns value in same replica or device if possible, else the _primary."""
750    if values_util.is_saving_non_distributed():
751      return self._primary
752    replica_id = values_util.get_current_replica_id_as_int()
753    if replica_id is None:
754      # Try to find a value on the current device.
755      current_device = device_util.canonicalize(device_util.current())
756      for i, value in enumerate(self._values):
757        if device_util.canonicalize(value.device) == current_device:
758          return self._get_replica(i)
759      return self._get_replica(0)
760    else:
761      return self._get_replica(replica_id)
762
763  def read_value(self):
764    if values_util.is_saving_non_distributed():
765      return self._primary.read_value()
766    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
767      return array_ops.identity(self._get())
768
769  def value(self):
770    if values_util.is_saving_non_distributed():
771      return self._primary.value()
772    if self._policy:
773      return self._policy.value(self)
774    return self._get_on_device_or_primary().value()
775
776  def numpy(self):
777    if context.executing_eagerly():
778      return self.read_value().numpy()
779    else:
780      raise NotImplementedError("DistributedVariable.numpy() is only available "
781                                "when eager execution is enabled.")
782
783  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
784    if values_util.is_saving_non_distributed():
785      return self._primary.assign_sub(value, use_locking, name, read_value)
786    if self._policy:
787      return self._policy.assign_sub(
788          self,
789          value,
790          use_locking=use_locking,
791          name=name,
792          read_value=read_value)
793    return values_util.on_write_assign_sub(
794        self, value, use_locking=use_locking, name=name, read_value=read_value)
795
796  def assign_add(self, value, use_locking=False, name=None, read_value=True):
797    if values_util.is_saving_non_distributed():
798      return self._primary.assign_add(value, use_locking, name, read_value)
799    if self._policy:
800      return self._policy.assign_add(
801          self,
802          value,
803          use_locking=use_locking,
804          name=name,
805          read_value=read_value)
806    return values_util.on_write_assign_add(
807        self, value, use_locking=use_locking, name=name, read_value=read_value)
808
809  def assign(self, value, use_locking=False, name=None, read_value=True):
810    if values_util.is_saving_non_distributed():
811      return self._primary.assign(value, use_locking, name, read_value)
812    if self._policy:
813      return self._policy.assign(
814          self,
815          value,
816          use_locking=use_locking,
817          name=name,
818          read_value=read_value)
819    return values_util.on_write_assign(
820        self, value, use_locking=use_locking, name=name, read_value=read_value)
821
822  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
823    if values_util.is_saving_non_distributed():
824      return self._primary.scatter_sub(sparse_delta, use_locking, name)
825    if self._policy:
826      return self._policy.scatter_sub(
827          self, sparse_delta, use_locking=use_locking, name=name)
828    return values_util.scatter_sub(
829        self, sparse_delta, use_locking=use_locking, name=name)
830
831  def scatter_add(self, sparse_delta, use_locking=False, name=None):
832    if values_util.is_saving_non_distributed():
833      return self._primary.scatter_add(sparse_delta, use_locking, name)
834    if self._policy:
835      return self._policy.scatter_add(
836          self, sparse_delta, use_locking=use_locking, name=name)
837    return values_util.scatter_add(
838        self, sparse_delta, use_locking=use_locking, name=name)
839
840  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
841    if values_util.is_saving_non_distributed():
842      return self._primary.scatter_mul(sparse_delta, use_locking, name)
843    if self._policy:
844      return self._policy.scatter_mul(
845          self, sparse_delta, use_locking=use_locking, name=name)
846    return values_util.scatter_mul(
847        self, sparse_delta, use_locking=use_locking, name=name)
848
849  def scatter_div(self, sparse_delta, use_locking=False, name=None):
850    if values_util.is_saving_non_distributed():
851      return self._primary.scatter_div(sparse_delta, use_locking, name)
852    if self._policy:
853      return self._policy.scatter_div(
854          self, sparse_delta, use_locking=use_locking, name=name)
855    return values_util.scatter_div(
856        self, sparse_delta, use_locking=use_locking, name=name)
857
858  def scatter_min(self, sparse_delta, use_locking=False, name=None):
859    if values_util.is_saving_non_distributed():
860      return self._primary.scatter_min(sparse_delta, use_locking, name)
861    if self._policy:
862      return self._policy.scatter_min(
863          self, sparse_delta, use_locking=use_locking, name=name)
864    return values_util.scatter_min(
865        self, sparse_delta, use_locking=use_locking, name=name)
866
867  def scatter_max(self, sparse_delta, use_locking=False, name=None):
868    if values_util.is_saving_non_distributed():
869      return self._primary.scatter_max(sparse_delta, use_locking, name)
870    if self._policy:
871      return self._policy.scatter_max(
872          self, sparse_delta, use_locking=use_locking, name=name)
873    return values_util.scatter_max(
874        self, sparse_delta, use_locking=use_locking, name=name)
875
876  def scatter_update(self, sparse_delta, use_locking=False, name=None):
877    if values_util.is_saving_non_distributed():
878      return self._primary.scatter_update(sparse_delta, use_locking, name)
879    if self._policy:
880      return self._policy.scatter_update(
881          self, sparse_delta, use_locking=use_locking, name=name)
882    return values_util.scatter_update(
883        self, sparse_delta, use_locking=use_locking, name=name)
884
885  def _gather_saveables_for_checkpoint(self):
886    """Overrides Trackable method.
887
888    This allows both name-based and object-based save and restore of
889    DistributedVariables.
890
891    Returns:
892      A dictionary mapping attribute names to `SaveableObject` factories.
893    """
894
895    def _saveable_factory(name=self._common_name):
896      return _DistributedVariableSaveable(self, self._primary, name)
897
898    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
899
900  def _as_graph_element(self):
901    if values_util.is_saving_non_distributed():
902      return self._primary._as_graph_element()  # pylint: disable=protected-access
903    if self._policy:
904      return self._policy._as_graph_element(self)  # pylint: disable=protected-access
905
906    raise NotImplementedError(
907        "DistributedVariable._as_graph_element requires a valid "
908        "VariablePolicy. Please set the policy via the `var_policy` argument "
909        "in the constructor, or override this method in sub-classes which "
910        "support cross-replica accesses.")
911
912  def _get_cross_replica(self):
913    if values_util.is_saving_non_distributed():
914      return self._primary
915    if self._policy:
916      return self._policy._get_cross_replica(self)  # pylint: disable=protected-access
917
918    raise NotImplementedError(
919        "DistributedVariable._get_cross_replica requires a valid "
920        "VariablePolicy. Please set the policy via the `var_policy` argument "
921        "in the constructor, or override this method in sub-classes which "
922        "support cross-replica accesses.")
923
924  def _update_cross_replica(self, update_fn, value, **kwargs):
925    """Applies updates across replicas.
926
927    Args:
928      update_fn: A callable to pass to `strategy.extended.update` to update the
929        variable. It should has the same signature as `Variable.assign()`.
930      value: value to be passed to `update_fn`.
931      **kwargs: remaining arguments to `update_fn`.
932
933    Returns:
934      Updated variable or `tf.Operation`.
935    """
936    values_util.mark_as_unsaveable()
937    return self.distribute_strategy.extended.update(
938        self, update_fn, args=(value,), kwargs=kwargs, group=True)
939
940  def _update_replica(self, update_fn, value, **kwargs):
941    """Applies updates in one replica.
942
943    Args:
944      update_fn: A callable to update the variable. It should has the same
945        signature as `Variable.assign()`.
946      value: value to be passed to `update_fn`.
947      **kwargs: remaining arguments to `update_fn`.
948
949    Returns:
950      Updated variable or `tf.Operation`.
951    """
952    if self._policy:
953      return self._policy._update_replica(self, update_fn, value, **kwargs)  # pylint: disable=protected-access
954    raise NotImplementedError(
955        "DistributedVariable._update_replica requires a valid VariablePolicy. "
956        "Please set the policy via the `var_policy` argument in the "
957        "constructor, or override this method in sub-classes which support "
958        "cross-replica accesses.")
959
960  def _update(self, update_fn, value, **kwargs):
961    """Applies updates depending on the context.
962
963    The method calls `_update_replica` in replica context,
964    `_update_cross_replica` in cross replica context, and `update_fn` in update
965    context.
966
967    If `read_value` is True, the method returns the updated Variable. If
968    `read_value` is False, the method returns the update `tf.Operation`.
969
970    Args:
971      update_fn: A callable to pass to `strategy.extended.update` to update the
972        variable. It should have the same signature as `Variable.assign()`.
973      value: value to be passed to `update_fn`.
974      **kwargs: keyword arguments to `update_fn`.
975
976    Returns:
977      Updated variable or `tf.Operation`.
978
979    """
980    if values_util.is_saving_non_distributed():
981      return update_fn(self._primary, value, **kwargs)
982    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
983      if ds_context.in_cross_replica_context():
984        update_replica_id = distribute_lib.get_update_replica_id()
985        if update_replica_id is not None:
986          replica_value = self._get_replica(update_replica_id)
987          return update_fn(replica_value, value, **kwargs)
988        return self._update_cross_replica(update_fn, value, **kwargs)
989      else:
990        values_util.assert_replica_context(self.distribute_strategy)
991        return self._update_replica(update_fn, value, **kwargs)
992
993  def _should_act_as_resource_variable(self):
994    """Pass resource_variable_ops.is_resource_variable check."""
995    pass
996
997  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
998    """Converts a variable to a tensor."""
999    if values_util.is_saving_non_distributed():
1000      return ops.convert_to_tensor(
1001          self._primary, dtype=dtype, name=name, as_ref=as_ref)
1002    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1003      return ops.convert_to_tensor(
1004          self._get(), dtype=dtype, name=name, as_ref=as_ref)
1005
1006  def _map_resources(self, save_options):
1007    """For implementing `Trackable`."""
1008    # Initialize for self._primary first, so that obj_map[self._primary] and
1009    # resource_map[self._primary.handle] contain mapped values.
1010    obj_map, resource_map = self._primary._map_resources(save_options)  # pylint:disable=protected-access
1011    for v in [v for v in self._values if v != self._primary]:
1012
1013      if (save_options.experimental_variable_policy  # pylint:disable=protected-access
1014          ._expand_distributed_variables()):
1015        v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
1016        obj_map.update(v_obj_map)
1017        resource_map.update(v_resource_map)
1018      else:
1019        obj_map[v] = obj_map[self._primary]
1020        resource_map[v.handle] = resource_map[self._primary.handle]
1021    obj_map[self] = obj_map[self._primary]
1022    resource_map[self] = resource_map[self._primary.handle]
1023    if self._packed_var is not None:
1024      resource_map[self._packed_var.packed_handle] = resource_map[
1025          self._primary.handle]
1026    return obj_map, resource_map
1027
1028  def _write_object_proto(self, proto, options):
1029    """Update a SavedObject proto for the caller.
1030
1031    If a DistributedVariable object supports this method, it will be called when
1032    saving with a pre-built `SavedObject` proto representing the object, plus an
1033    instance of `SaveOptions`. This method is then free to modify that proto
1034    instance.
1035
1036    `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
1037    write out information about their components to the
1038    `experimental_distributed_variable_components` field of a
1039    `SavedVariable` (depending on the `SaveOptions` variable policy).
1040
1041    Args:
1042      proto: A pre-built `SavedObject` proto for this object. It is assumed this
1043        will be a `SavedVariable` instance.
1044      options: A `SaveOptions` instance.
1045    """
1046    resource_variable_ops.write_object_proto_for_resource_variable(
1047        self, proto, options)
1048    if self._policy:
1049      if self._policy._is_mirrored():  # pylint: disable=protected-access
1050        self._policy._write_object_proto(self, proto, options)  # pylint: disable=protected-access
1051
1052
# We extend from `saveable_object.SaveableObject` instead of
# `saveable_object_util.ResourceVariableSaveable` since we need to read the
# value of ON_READ variables when saving. `SaveableObject` provides a way to
# specify the function to run to get the value of the variable or tensor at
# saving time. We can use this for both ON_READ and ON_WRITE variables.
# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
# if possible.
class _DistributedVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a DistributedVariable.

  Both the saved tensor/spec and the restore ops come from the variable's
  `VariablePolicy`.
  """

  def __init__(self, distributed_variable, primary_variable, name):
    self._distributed_variable = distributed_variable
    policy = distributed_variable._policy  # pylint: disable=protected-access
    if not policy:
      raise ValueError(
          "The VariablePolicy of the argument `distributed_variable` must be "
          "set to create a _DistributedVariableSaveable. Please set it via "
          "the `var_policy` argument in the constructor of DistributedVariable."
      )
    tensor, spec = policy.get_saveable(
        distributed_variable, primary_variable, name)
    super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    del restored_shapes  # Unused.
    (tensor,) = restored_tensors
    variable = self._distributed_variable
    return variable._policy.get_restore_ops(variable, tensor)  # pylint: disable=protected-access
1080
1081
class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable.

  The saved tensor/spec come from `values_util.get_on_write_saveable`, and
  restoring goes through `values_util.get_on_write_restore_ops`.
  """

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    tensor, spec = values_util.get_on_write_saveable(
        mirrored_variable, primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    del restored_shapes  # Unused.
    (tensor,) = restored_tensors
    return values_util.get_on_write_restore_ops(
        self._mirrored_variable, tensor)
1095
1096
class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    """Applies `update_fn` in replica context via `_on_write_update_replica`."""
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def _check_scatter_aggregation(self, op_name):
    """Raises if the scatter op `op_name` is unsupported by this aggregation.

    The scatter ops that call this helper (`scatter_min`, `scatter_max`,
    `scatter_update`) are only allowed with `ONLY_FIRST_REPLICA` or `NONE`
    aggregation.

    Args:
      op_name: Name of the scatter op, used in the error message.

    Raises:
      NotImplementedError: If the aggregation is neither `ONLY_FIRST_REPLICA`
        nor `NONE`.
    """
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name=op_name, aggregation=self._aggregation))

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    self._check_scatter_aggregation("scatter_min")
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    self._check_scatter_aggregation("scatter_max")
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    self._check_scatter_aggregation("scatter_update")
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

  def _as_graph_element(self):
    # Like DistributedVariable._as_graph_element and this class's scatter
    # methods: while saving non-distributed, only the primary is used.
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()  # pylint: disable=protected-access
    return self._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
    writes out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    super(MirroredVariable, self)._write_object_proto(proto, options)
    values_util.write_object_proto(self, proto, options)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor.

    Raises:
      ValueError: If `as_ref` is True (outside of non-distributed saving).
    """
    # Like DistributedVariable._dense_var_to_tensor and
    # SyncOnReadVariable._dense_var_to_tensor: convert the primary component
    # while saving in non-distributed mode.
    if values_util.is_saving_non_distributed():
      return ops.convert_to_tensor(
          self._primary, dtype=dtype, name=name, as_ref=as_ref)
    # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
    # and ON_WRITE.
    # Try to avoid assignments to and other mutations of MirroredVariable
    # state except through a DistributionStrategy.extended.update() or any of
    # the `assign*` and `scatter*` calls.
    if as_ref:
      # A TF 1.x case where the variable is a boolean variable and used like:
      # tf.cond(v, true_fn, false_fn).
      raise ValueError(
          "You may be using variable created under distribute strategy in TF "
          "1.x control flows. Try explicitly converting the variable to Tensor "
          "using variable.read_value(), or switch to TF 2.x.")
    return ops.convert_to_tensor(
        self._get(), dtype=dtype, name=name, as_ref=as_ref)
1193
1194
class _SyncOnReadSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable.

  The saved tensor/spec come from `values_util.get_on_read_saveable`, and
  restoring goes through `values_util.get_on_read_restore_ops` using the
  variable's aggregation.
  """

  def __init__(self, sync_on_read_variable, name):
    self._sync_on_read_variable = sync_on_read_variable
    primary = sync_on_read_variable._primary  # pylint: disable=protected-access
    tensor, spec = values_util.get_on_read_saveable(
        sync_on_read_variable, primary, name)
    super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    del restored_shapes  # Unused.
    (tensor,) = restored_tensors
    variable = self._sync_on_read_variable
    return values_util.get_on_read_restore_ops(
        variable, tensor, variable.aggregation)
1211
1212
class SyncOnReadVariable(DistributedVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def _update_replica(self, update_fn, value, **kwargs):
    """Applies `update_fn` to the component from `_get_on_device_or_primary`."""
    local_component = self._get_on_device_or_primary()
    return update_fn(local_component, value, **kwargs)

  def _get(self):
    """Returns the value of SyncOnReadVariable based on surrounding context.

    Under a non-default replica context this is the variable of that replica.
    Under the default replica context or in cross-replica context it is the
    synced value. The lookup runs inside this variable's strategy scope.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return super(SyncOnReadVariable, self)._get()

  # TODO(b/154017756): Make assign behavior in cross replica context consistent
  # with MirroredVariable.
  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    """Subtracts `value` from this variable.

    In cross-replica context outside a replica update, this calls
    `values_util.mark_as_unsaveable()` and then
    `values_util.on_read_assign_sub_cross_replica`; otherwise it defers to
    `DistributedVariable.assign_sub`.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      cross_replica_write = (
          ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context())
      if not cross_replica_write:
        return super(SyncOnReadVariable, self).assign_sub(
            value, use_locking, name, read_value)
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_sub_cross_replica(
          self, value, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    """Adds `value` to this variable (same dispatch as `assign_sub`)."""
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      cross_replica_write = (
          ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context())
      if not cross_replica_write:
        return super(SyncOnReadVariable, self).assign_add(
            value, use_locking, name, read_value)
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_add_cross_replica(
          self, value, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns `value` to this variable (same dispatch as `assign_sub`)."""
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      cross_replica_write = (
          ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context())
      if not cross_replica_write:
        return super(SyncOnReadVariable, self).assign(
            value, use_locking, name, read_value)
      values_util.mark_as_unsaveable()
      return values_util.on_read_assign_cross_replica(
          self, value, read_value=read_value)

  def _scatter_not_implemented(self, method):
    """Raises `NotImplementedError` naming the unsupported scatter `method`."""
    raise NotImplementedError(
        f"Variables with `synchronization=ON_READ` doesn't support `{method}`")

  # Scatter ops are only forwarded (to the primary) while saving
  # non-distributed; every other call raises via _scatter_not_implemented.
  def scatter_sub(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      self._scatter_not_implemented("scatter_sub")
      return None
    return self._primary.scatter_sub(*args, **kwargs)

  def scatter_add(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      self._scatter_not_implemented("scatter_add")
      return None
    return self._primary.scatter_add(*args, **kwargs)

  def scatter_mul(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      self._scatter_not_implemented("scatter_mul")
      return None
    return self._primary.scatter_mul(*args, **kwargs)

  def scatter_div(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      self._scatter_not_implemented("scatter_div")
      return None
    return self._primary.scatter_div(*args, **kwargs)

  def scatter_min(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      self._scatter_not_implemented("scatter_min")
      return None
    return self._primary.scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      self._scatter_not_implemented("scatter_max")
      return None
    return self._primary.scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if not values_util.is_saving_non_distributed():
      self._scatter_not_implemented("scatter_update")
      return None
    return self._primary.scatter_update(*args, **kwargs)

  def value(self):
    """Returns the value for the current context.

    Raises:
      NotImplementedError: If called inside a variable_sync_on_read_context.
    """
    if ds_context.in_variable_sync_on_read_context():
      raise NotImplementedError(
          "call `variable.value()` inside variable_sync_on_read_context is not "
          "supported")
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      cross_replica_read = (
          ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context())
      if not cross_replica_read:
        # _get_on_device_or_primary() returns a Variable.
        return self._get_on_device_or_primary().value()
      if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
        return self._get_replica(0).value()
      return self._get_cross_replica()

  def read_value(self):
    """Reads the value; unsupported inside variable_sync_on_read_context."""
    if ds_context.in_variable_sync_on_read_context():
      raise NotImplementedError(
          "call `variable.read_value()` inside variable_sync_on_read_context is"
          " not supported")
    return super(SyncOnReadVariable, self).read_value()

  def _get_cross_replica(self):
    """Returns replica 0 for ONLY_FIRST_REPLICA, else the reduced value."""
    aggregation = self._aggregation
    if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      # Consider returning a tensor value here to make the return value of
      # _get_cross_replica consistent.
      return self._get_replica(0)
    if aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    strategy = self._distribute_strategy
    with ds_context.enter_or_assert_strategy(strategy):
      reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
      return strategy.reduce(reduce_op, self, axis=None)

  def _as_graph_element(self):
    # pylint: disable=protected-access
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if ds_context.in_cross_replica_context():
        cross_replica_value = self._get_cross_replica()
        return ops.convert_to_tensor(cross_replica_value)
    return self._get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _SyncOnReadSaveable(self, name)

    factories = {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
    return factories

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a SyncOnReadVariable to a tensor.

    Inside a replica context that is also a variable_sync_on_read_context,
    the value is all-reduced across replicas (or taken from replica 0 for
    ONLY_FIRST_REPLICA). Otherwise `self._get()` is converted.
    """
    if values_util.is_saving_non_distributed():
      return ops.convert_to_tensor(
          self._primary, dtype=dtype, name=name, as_ref=as_ref)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      replica_context = ds_context.get_replica_context()
      if (replica_context is None or
          not ds_context.in_variable_sync_on_read_context()):
        return ops.convert_to_tensor(
            self._get(), dtype=dtype, name=name, as_ref=as_ref)
      if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
        return ops.convert_to_tensor(
            self._get_replica(0), dtype=dtype, name=name, as_ref=as_ref)
      if self._aggregation == vs.VariableAggregation.SUM:
        values_util.mark_as_unsaveable()
      # pylint: disable=protected-access
      all_reduce = replica_context.strategy.extended._replica_ctx_all_reduce
      reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
          self._aggregation)
      reduced = all_reduce(reduce_op, self._get().read_value())
      return ops.convert_to_tensor(
          reduced, dtype=dtype, name=name, as_ref=as_ref)
1396
1397
# Register conversion functions that read the value of the variable, allowing
# instances of the classes below to be used as tensors.
# DistributedVariable
def _tensor_conversion_distributed_var(var, dtype=None, name=None,
                                       as_ref=False):
  """Delegates tensor conversion to the variable's `_dense_var_to_tensor`."""
  convert = var._dense_var_to_tensor  # pylint: disable=protected-access
  return convert(dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    DistributedVariable, _tensor_conversion_distributed_var)
1410
1411
# MirroredVariables
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
  """Delegates tensor conversion to the variable's `_dense_var_to_tensor`."""
  convert = var._dense_var_to_tensor  # pylint: disable=protected-access
  return convert(dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    MirroredVariable, _tensor_conversion_mirrored)
1419
1420
# Mirrored Values
def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
  """Converts the component selected by `value._get()` to a tensor."""
  selected = value._get()  # pylint: disable=protected-access
  return ops.convert_to_tensor(
      selected, dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    Mirrored, _tensor_conversion_mirrored_val)
1429
1430
# SyncOnReadVariables
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
  """Delegates tensor conversion to the variable's `_dense_var_to_tensor`."""
  convert = var._dense_var_to_tensor  # pylint: disable=protected-access
  return convert(dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    SyncOnReadVariable, _tensor_conversion_sync_on_read)
1438
1439
class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable.

  Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
  during variable creation within `tf.distribute` scope, `tf.distribute` creates
  an appropriate policy object and assigns it to the distributed variable. All
  variable operations are delegated to the respective policy object.

  Every method on this base class raises `NotImplementedError`; concrete
  policies must override them. Methods acting on a variable receive the
  distributed variable as the `var` argument, matching how
  `DistributedVariable` invokes them (e.g. `self._policy.value(self)`).
  """

  def __init__(self, aggregation):
    # Stored for sub-classes, which branch on the aggregation value.
    self._aggregation = aggregation

  def value(self, var):
    """Returns the value of `var`. Must be overridden.

    The signature takes `var` so that a call such as `policy.value(variable)`
    reaches this `NotImplementedError` instead of failing with a `TypeError`.
    """
    raise NotImplementedError(
        "VariablePolicy.value should be overriden by sub-classes.")

  def _is_mirrored(self):
    """Returns whether this policy mirrors values. Must be overridden."""
    raise NotImplementedError(
        "VariablePolicy._is_mirrored should be overriden by sub-classes.")

  def _as_graph_element(self, var):
    """Returns the graph element for `var`. Must be overridden."""
    raise NotImplementedError(
        "VariablePolicy._as_graph_element should be overriden by sub-classes.")

  def _get_cross_replica(self, var):
    """Returns `var`'s cross-replica value. Must be overridden."""
    raise NotImplementedError(
        "VariablePolicy._get_cross_replica should be overriden by sub-classes.")

  def _update_replica(self, var, update_fn, value, **kwargs):
    """Applies `update_fn` to `var` in replica context. Must be overridden."""
    raise NotImplementedError(
        "VariablePolicy._update_replica should be overriden by sub-classes.")
1471
1472
class OnReadPolicy(VariablePolicy):
  """Policy for `tf.VariableSynchronization.ON_READ` synchronization.

  A `tf.Variable` created inside a `tf.distribute` scope receives this policy
  when its `synchronization` is `tf.VariableSynchronization.ON_READ`. Its
  `aggregation` may be any value of the `tf.VariableAggregation` enum, such as
  `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
  """

  def _is_mirrored(self):
    # ON_READ variables are never reported as mirrored.
    return False

  def value(self, var):
    """Returns the value of `var` appropriate for the current context.

    In cross-replica context (and outside a replica update), the value is
    `var._get_replica(0).value()` for `ONLY_FIRST_REPLICA` aggregation and
    `var._get_cross_replica()` for any other aggregation. In every other
    context it is the value of the on-device (or primary) component.
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      reading_cross_replica = (
          ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context())
      if not reading_cross_replica:
        return var._get_on_device_or_primary().value()  # pylint: disable=protected-access
      if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
        return var._get_replica(0).value()  # pylint: disable=protected-access
      return var._get_cross_replica()  # pylint: disable=protected-access

  def _as_graph_element(self, var):
    """Returns a graph element representing `var`.

    In cross-replica context the cross-replica value is converted to a tensor;
    otherwise the graph element of `var._get()` is returned.
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      cross_replica = ds_context.in_cross_replica_context()
      if cross_replica:
        return ops.convert_to_tensor(var._get_cross_replica())  # pylint: disable=protected-access
    # This fallback runs after the strategy scope above has been exited.
    return var._get()._as_graph_element()  # pylint: disable=protected-access

  def _get_cross_replica(self, var):
    """Returns the cross-replica value of `var` according to its aggregation.

    `ONLY_FIRST_REPLICA` returns `var._get_replica(0)` without any reduction.
    Every other aggregation reduces `var` through the variable's strategy
    using the `ReduceOp` derived from that aggregation. A `SUM` aggregation
    additionally calls `values_util.mark_as_unsaveable()` first.
    """
    aggregation = self._aggregation
    if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # pylint: disable=protected-access
    if aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(aggregation),
          var,
          axis=None)

  def _update_replica(self, var, update_fn, value, **kwargs):
    # In replica context `update_fn` is applied to the on-device (or primary)
    # component only.
    target = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return update_fn(target, value, **kwargs)

  def _scatter_not_implemented(self, method):
    """Raises `NotImplementedError` naming the unsupported scatter `method`."""
    message = (f"ON_READ variables doesn't support `{method}` "
               "in cross replica context")
    raise NotImplementedError(message)

  def _on_read_assign(self, var, value, cross_replica_fn, replica_fn,
                      use_locking, name, read_value):
    """Dispatches an assign-style update of `var` based on the context.

    Outside cross-replica context, or inside a replica update, `replica_fn`
    is called with the full set of update arguments. Otherwise
    `values_util.mark_as_unsaveable()` is called and the update is delegated
    to `cross_replica_fn(var, value, read_value=read_value)`.
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (not ds_context.in_cross_replica_context() or
          values_util.in_replica_update_context()):
        return replica_fn(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
      values_util.mark_as_unsaveable()
      return cross_replica_fn(var, value, read_value=read_value)

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Subtracts a value from this variable."""
    return self._on_read_assign(
        var,
        value,
        cross_replica_fn=values_util.on_read_assign_sub_cross_replica,
        replica_fn=values_util.on_write_assign_sub,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Adds a value to this variable."""
    return self._on_read_assign(
        var,
        value,
        cross_replica_fn=values_util.on_read_assign_add_cross_replica,
        replica_fn=values_util.on_write_assign_add,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    """Assigns a new value to this variable."""
    return self._on_read_assign(
        var,
        value,
        cross_replica_fn=values_util.on_read_assign_cross_replica,
        replica_fn=values_util.on_write_assign,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  # Scatter operations are not supported for ON_READ variables; each of the
  # following always defers to `_scatter_not_implemented`.

  def scatter_sub(self, *args, **kwargs):
    del args, kwargs  # Unused.
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    del args, kwargs  # Unused.
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    del args, kwargs  # Unused.
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    del args, kwargs  # Unused.
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    del args, kwargs  # Unused.
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    del args, kwargs  # Unused.
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    del args, kwargs  # Unused.
    self._scatter_not_implemented("scatter_update")

  def get_saveable(self, var, primary_var, name):
    """Returns the ON_READ saveable object for `var`."""
    return values_util.get_on_read_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    """Returns ops restoring `tensor` into every component of `var`."""
    return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)
1612
1613
class OnWritePolicy(VariablePolicy):
  """Policy for `tf.VariableSynchronization.ON_WRITE` synchronization.

  A `tf.Variable` created inside a `tf.distribute` scope receives this policy
  when its `synchronization` is `tf.VariableSynchronization.ON_WRITE` or
  `tf.VariableSynchronization.AUTO`, together with the requested
  `aggregation`.
  """

  def _is_mirrored(self):
    # ON_WRITE variables are always reported as mirrored.
    return True

  def value(self, var):
    """Returns the value of the on-device (or primary) component of `var`."""
    component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return component.value()

  def _as_graph_element(self, var):
    """Returns the graph element of the on-device (or primary) component."""
    component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return component._as_graph_element()  # pylint: disable=protected-access

  def _get_cross_replica(self, var):
    # Hand back an identity of the component rather than the variable itself,
    # so callers cannot accidentally modify the variable through the result.
    component = var._get_on_device_or_primary()  # pylint: disable=protected-access
    return array_ops.identity(component)

  def _update_replica(self, var, update_fn, value, **kwargs):
    """Applies `update_fn` in replica context.

    With `NONE` aggregation the update is applied directly to the on-device
    (or primary) component; otherwise it is routed through
    `_on_write_update_replica`.
    """
    no_aggregation = (
        var.aggregation == variables_lib.VariableAggregation.NONE)
    if no_aggregation:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
    return _on_write_update_replica(var, update_fn, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(
        var,
        value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    return values_util.on_write_assign_add(
        var,
        value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    return values_util.on_write_assign_sub(
        var,
        value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_sub(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_add(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_mul(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_div(
        var, sparse_delta, use_locking=use_locking, name=name)

  def _raise_if_scatter_unsupported(self, op_name):
    """Raises unless the aggregation is `ONLY_FIRST_REPLICA` or `NONE`.

    Args:
      op_name: Name of the scatter op, used in the error message.

    Raises:
      NotImplementedError: For any other aggregation.
    """
    aggregation = self._aggregation
    if (aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name=op_name, aggregation=aggregation))

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    self._raise_if_scatter_unsupported("scatter_min")
    return values_util.scatter_min(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    self._raise_if_scatter_unsupported("scatter_max")
    return values_util.scatter_max(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    self._raise_if_scatter_unsupported("scatter_update")
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)

  def get_saveable(self, var, primary_var, name):
    """Returns the saveable object for AUTO / ON_WRITE variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    """Returns ops restoring `tensor` into `var`."""
    return values_util.get_on_write_restore_ops(var, tensor)

  def _write_object_proto(self, var, proto, options):
    """Updates a `SavedObject` proto on behalf of `var`.

    When a DistributedVariable supports this method, the saver calls it with
    a pre-built `SavedObject` proto for the object and the active
    `SaveOptions`, and the method may modify the proto in place.

    For `AUTO` / `ON_WRITE` variables this may record information about the
    components in the `experimental_distributed_variable_components` field of
    the `SavedVariable`, depending on the `SaveOptions` variable policy.

    Args:
      var: A DistributedVariable object.
      proto: A pre-built `SavedObject` proto for this object, assumed to be a
        `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    values_util.write_object_proto(var, proto, options)
1734
1735
class PerWorkerResource(object):
  """A per-worker CachableResource class for non-ParameterServer strategy.

  Resources that populate `host_to_resources` should be instances of classes
  subclassing CachableResource, although currently it's only used and tested for
  StaticHashTable with TPUStrategy.

  Every attribute access on an instance, other than the few names the wrapper
  itself needs, is forwarded to the resource returned by `local_resource()`.
  """
  # NOTE(review): "CachableResource" above may be meant to read
  # "CapturableResource"; confirm before renaming.

  def __init__(self, strategy, host_to_resources):
    """Creates the wrapper.

    Args:
      strategy: The distribution strategy; stored as `_strategy`.
      host_to_resources: Mapping from canonicalized host device strings to
        the resource living on that host.
    """
    self._strategy = strategy
    self._host_to_resources = host_to_resources

  def __getattribute__(self, name):
    # Only these names resolve on the wrapper; everything else (including
    # dunder lookups through the instance) is delegated to the local resource.
    if name not in ("__init__", "__getattribute__", "_host_to_resources",
                    "_strategy", "local_resource"):
      return getattr(self.local_resource(), name)
    return super(PerWorkerResource, self).__getattribute__(name)

  def local_resource(self):
    """Returns the resource on the local worker.

    The lookup key is the canonicalized host device of the current device. If
    no resource is registered for that host, the first resource in
    `host_to_resources` (in its iteration order) is returned instead.

    Returns:
      The resource for the current host, or the fallback resource.

    Raises:
      ValueError: If `host_to_resources` is empty.
    """
    current_device = device_util.canonicalize(device_util.current())
    host_device = device_util.canonicalize(
        device_util.get_host_for_device(current_device))
    resources = self._host_to_resources
    # A private sentinel keeps exact `.get` semantics (no `__missing__` hook)
    # while letting the fallback be computed only when it is actually needed.
    missing = object()
    resource = resources.get(host_device, missing)
    if resource is not missing:
      return resource
    if not resources:
      # Evaluating the fallback eagerly would surface as a bare StopIteration.
      raise ValueError(
          "PerWorkerResource has no resources: `host_to_resources` is empty.")
    return resources[next(iter(resources))]
1762