# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib


# pylint: disable=protected-access


class _DummyResourceDeleter(object):
  """No-op deleter; the per-replica variables own and delete their handles."""


class DistributedVariable(resource_variable_ops.BaseResourceVariable):
  """Represents variables that are replicated.

  It behaves exactly like a normal variable, but uses the variable handle that
  corresponds to the current context:
  - In each replica, it uses the handle from that replica.
  - In tpu.replicate(), it uses the replicated handle.
  - Otherwise, it uses the handle from the primary replica.

  Note that it doesn't synchronize automatically the way the old
  DistributedVariable in values.py does.
  """

  def __init__(self, variables, *, enable_packed_handle=False):
    if enable_packed_handle and not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "Argument `enable_packed_handle` is true, but packed handle is only "
          "supported in eager mode. Please make sure eager execution is "
          "enabled.")
    self._variables = variables
    if enable_packed_handle:
      self._packed_handle = ops.pack_eager_tensors(
          [v.handle for v in variables])
    else:
      self._packed_handle = None
    for v in variables:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    self._device_to_handle = {v.device: v.handle for v in variables}
    self._primary_handle = variables[0].handle
    with ops.init_scope(), \
         ops.name_scope("DistributedVariable", skip_on_eager=False) as name:
      handle_name = ops.name_from_scope_name(name)
      self._unique_id = "%s_%d" % (handle_name, ops.uid())
      if context.executing_eagerly():
        initial_value = None
        initializer = None
      else:
        initial_value = variables[0].initial_value
        initializer = control_flow_ops.group(
            [v.initializer for v in variables])
      super().__init__(
          trainable=variables[0].trainable,
          shape=variables[0].shape,
          dtype=variables[0].dtype,
          handle=None,
          synchronization=variables[0].synchronization,
          constraint=variables[0].constraint,
          aggregation=variables[0].aggregation,
          distribute_strategy=variables[0]._distribute_strategy,
          name=variables[0].name,
          unique_id=self._unique_id,
          handle_name=handle_name,
          graph_element=variables[0]._graph_element,
          initial_value=initial_value,
          initializer_op=initializer,
          is_initialized_op=None,
          cached_value=None,
          handle_deleter=_DummyResourceDeleter(),
          caching_device=None,
          is_variables=True)

  @property
  def handle(self):
    if values_util.is_saving_non_distributed():
      return self._primary_handle
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context and not context.executing_eagerly():
      is_mirrored = (
          self._variables[0].synchronization !=
          variables_lib.VariableSynchronization.ON_READ)
      if self._packed_handle is None:
        handles = [v.handle for v in self._variables]
        is_packed = False
      else:
        handles = [self._packed_handle]
        is_packed = True
      return tpu_context.get_replicated_var_handle(self._unique_id, handles,
                                                   is_mirrored, is_packed)
    if self._packed_handle is not None and not context.executing_eagerly():
      return self._packed_handle
    # Outside tpu.replicate(), use the handle of the current device if it
    # holds a replica, and fall back to the primary replica's handle otherwise.
    device = device_util.canonicalize(device_util.current())
    return self._device_to_handle.get(device, self._primary_handle)

  @property
  def name(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].name
    return super().name

  @property
  def initializer(self):
    if values_util.is_saving_non_distributed():
      return self._variables[0].initializer
    return super().initializer

  def _lazy_read(self, op):
    # Lazy read is not supported.
    with ops.control_dependencies([op]):
      return self.read_value()

  # Begin overrides of read/write methods to satisfy the requirement of using
  # a packed handle, i.e., there must be explicit device annotations.

  def _device_scope(self):
    if (self._packed_handle is None or
        values_util.is_saving_non_distributed() or
        tpu_util.enclosing_tpu_context() is not None):
      return ops.NullContextmanager()
    device = device_util.canonicalize(device_util.current())
    if device in self._device_to_handle:
      return ops.NullContextmanager()
    return ops.device(self._primary_handle.device)

  def value(self):
    # We always force a read_value() instead of using the cached_value, as
    # value() can be called on different devices.
    return self.read_value()

  def read_value(self):
    with self._device_scope():
      return super().read_value()

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with self._device_scope():
      return super().assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._device_scope():
      return super().batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_sub(indices, updates, name)

  def scatter_nd_add(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_add(indices, updates, name)

  def scatter_nd_update(self, indices, updates, name=None):
    with self._device_scope():
      return super().scatter_nd_update(indices, updates, name)

  def sparse_read(self, indices, name=None):
    with self._device_scope():
      return super().sparse_read(indices, name)

  def gather_nd(self, indices, name=None):
    with self._device_scope():
      return super().gather_nd(indices, name)

  def to_proto(self, export_scope=None):
    del self
    raise TypeError("DistributedVariable doesn't support to_proto")

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    raise TypeError("DistributedVariable doesn't support from_proto")

  def _as_graph_element(self):
    if ops.get_default_graph().finalized:
      return self._variables[0]._graph_element
    return self.read_value()

  def _strided_slice_assign(self, *args, **kwargs):
    with self._device_scope():
      return super()._strided_slice_assign(*args, **kwargs)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._variables))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)

  def __deepcopy__(self, memo):
    copied_variables = copy.deepcopy(self._variables, memo)
    return DistributedVariable(
        copied_variables, enable_packed_handle=self._packed_handle is not None)


def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  if as_ref:
    raise ValueError(
        "You may be using a variable created under a distribute strategy in "
        "TF 1.x control flows. Try explicitly converting the variable to a "
        "Tensor using variable.read_value(), or switch to TF 2.x.")
  return ops.convert_to_tensor(
      var.read_value(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
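

# The function below is an illustrative sketch only and is not part of this
# module's API. It shows how DistributedVariable wraps one variable per
# device, assuming eager execution and that a "/device:GPU:0" exists alongside
# "/device:CPU:0" (both device names are assumptions for illustration). In
# practice the per-replica variables are created and wrapped by a
# tf.distribute strategy rather than by hand.
def _example_usage():
  """Wraps two per-device variables and exercises a basic write and read."""
  with ops.device("/device:CPU:0"):
    v0 = resource_variable_ops.ResourceVariable(1.0)
  with ops.device("/device:GPU:0"):  # Assumed to be available.
    v1 = resource_variable_ops.ResourceVariable(1.0)
  distributed = DistributedVariable([v0, v1])

  # Reads and writes use the handle of the current device when it holds a
  # replica, and the primary (first) variable's handle otherwise.
  distributed.assign_add(1.0)
  return ops.convert_to_tensor(distributed)  # Goes through _tensor_conversion.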