# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Various classes representing distributed values."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import weakref
23
24from tensorflow.python.distribute import device_util
25from tensorflow.python.distribute import tpu_util
26from tensorflow.python.distribute import values_util
27from tensorflow.python.eager import context
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import resource_variable_ops
31from tensorflow.python.ops import variables as variables_lib
32
33
34# pylint: disable=protected-access


class _DummyResourceDeleter(object):
  """No-op deleter: the component variables own and delete their handles."""


class DistributedVariable(resource_variable_ops.BaseResourceVariable):
  """Represents variables that are replicated.

  It behaves exactly like a normal variable, but uses the corresponding
  variable handle based on the context:
  - In each replica, it uses the handle from that replica.
  - In tpu.replicate(), it uses the replicated handle.
  - Otherwise, it uses the handle from the primary replica.

  Note that, unlike the old DistributedVariable in values.py, it does not
  synchronize automatically.
  """

  def __init__(self, variables, *, enable_packed_handle=False):
    if enable_packed_handle and not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "Argument `enable_packed_handle` is true, but packed handle is only "
          "supported in eager mode. Please make sure eager execution is "
          "enabled.")
    self._variables = variables
    if enable_packed_handle:
      self._packed_handle = ops.pack_eager_tensors(
          [v.handle for v in variables])
    else:
      self._packed_handle = None
    for v in variables:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    self._device_to_handle = {v.device: v.handle for v in variables}
    self._primary_handle = variables[0].handle
    with ops.init_scope(), \
         ops.name_scope("DistributedVariable", skip_on_eager=False) as name:
      handle_name = ops.name_from_scope_name(name)
      self._unique_id = "%s_%d" % (handle_name, ops.uid())
      if context.executing_eagerly():
        initial_value = None
        initializer = None
      else:
        initial_value = variables[0].initial_value
        initializer = control_flow_ops.group([v.initializer for v in variables])
      super().__init__(
          trainable=variables[0].trainable,
          shape=variables[0].shape,
          dtype=variables[0].dtype,
          handle=None,
          synchronization=variables[0].synchronization,
          constraint=variables[0].constraint,
          aggregation=variables[0].aggregation,
          distribute_strategy=variables[0]._distribute_strategy,
          name=variables[0].name,
          unique_id=self._unique_id,
          handle_name=handle_name,
          graph_element=variables[0]._graph_element,
          initial_value=initial_value,
          initializer_op=initializer,
          is_initialized_op=None,
          cached_value=None,
          handle_deleter=_DummyResourceDeleter(),
          caching_device=None,
          is_variables=True)

  @property
  def handle(self):
    # When saving a non-distributed checkpoint, expose the primary replica's
    # handle.
    if values_util.is_saving_non_distributed():
      return self._primary_handle
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context and not context.executing_eagerly():
      # Inside a TPU context during graph building, use the replicated handle.
      is_mirrored = (
          self._variables[0].synchronization !=
          variables_lib.VariableSynchronization.ON_READ)
      if self._packed_handle is None:
        handles = [v.handle for v in self._variables]
        is_packed = False
      else:
        handles = [self._packed_handle]
        is_packed = True
      return tpu_context.get_replicated_var_handle(self._unique_id, handles,
                                                   is_mirrored, is_packed)
    # When building a graph (e.g. inside a tf.function), prefer the packed
    # handle, which works on any of the component devices.
    if self._packed_handle is not None and not context.executing_eagerly():
      return self._packed_handle
    # Otherwise pick the handle matching the current device, falling back to
    # the primary replica.
    device = device_util.canonicalize(device_util.current())
    return self._device_to_handle.get(device, self._primary_handle)

  @property
  def name(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].name
    return super().name

  @property
  def initializer(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].initializer
    return super().initializer

  def _lazy_read(self, op):
    # Lazy read is not supported.
    with ops.control_dependencies([op]):
      return self.read_value()

  # Begin overrides of read/write methods. Ops on a packed handle require an
  # explicit device annotation, so each override below enters
  # `_device_scope()` before delegating to the base class.

  def _device_scope(self):
    if (self._packed_handle is None or
        values_util.is_saving_non_distributed() or
        tpu_util.enclosing_tpu_context() is not None):
      return ops.NullContextmanager()
    device = device_util.canonicalize(device_util.current())
    if device in self._device_to_handle:
      return ops.NullContextmanager()
    return ops.device(self._primary_handle.device)
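
  # How `_device_scope` resolves (descriptive note; devices are hypothetical):
  #   * No packed handle, saving a non-distributed checkpoint, or inside a TPU
  #     context: no extra placement (NullContextmanager).
  #   * The current device (e.g. "/device:GPU:1") holds a component variable:
  #     no extra placement; the op uses that component's handle.
  #   * Any other device: the op is pinned to the primary replica's device so
  #     that the packed handle is dispatched unambiguously.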

  def value(self):
    # We always force a read_value() instead of using the cached_value, as
    # value() can be called on different devices.
    return self.read_value()

  def read_value(self):
    with self._device_scope():
      return super().read_value()

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_sub(indices, updates, name)

  def scatter_nd_add(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_add(indices, updates, name)

  def scatter_nd_update(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_update(indices, updates, name)

  def sparse_read(self, indices, name=None):
    with self._device_scope():
      return super().sparse_read(indices, name)

  def gather_nd(self, indices, name=None):
    with self._device_scope():
      return super().gather_nd(indices, name)

  def to_proto(self, export_scope=None):
    del self
    raise TypeError("DistributedVariable doesn't support to_proto")

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    raise TypeError("DistributedVariable doesn't support from_proto")

  def _as_graph_element(self):
    if ops.get_default_graph().finalized:
      return self._variables[0]._graph_element
    return self.read_value()

  def _strided_slice_assign(self, *args, **kwargs):
    with self._device_scope():
      return super()._strided_slice_assign(*args, **kwargs)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)

  def __deepcopy__(self, memo):
    copied_variables = copy.deepcopy(self._variables, memo)
    return DistributedVariable(
        copied_variables, enable_packed_handle=self._packed_handle is not None)


def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  if as_ref:
    raise ValueError(
        "You may be using a variable created under a distribution strategy "
        "in TF 1.x control flows. Try explicitly converting the variable to "
        "a Tensor using variable.read_value(), or switch to TF 2.x.")
  return ops.convert_to_tensor(
      var.read_value(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
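
# Illustrative note (hypothetical usage, not executed here): with the
# conversion function registered above, a DistributedVariable can be passed
# directly to ops that expect tensors, e.g. for a `dv` built from per-device
# component variables:
#
#   y = tf.math.add(dv, 1.0)  # implicitly converts via dv.read_value()
#
# `as_ref=True` conversions (TF 1.x control flow) are rejected by design.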