# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values for PS (parameter server) strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import weakref

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core


# Variable used in PSStrategy TF 1 and CentralStorageStrategy.
class AggregatingVariable(variables_lib.Variable, core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `AggregatingVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `AggregatingVariable`. To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of an
    `AggregatingVariable` within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `AggregatingVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it
        # in an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about
        # how we handle the different use cases can be found in the _reduce
        # method. We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)
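
  # A hedged illustration of how `_assign_func` above dispatches (the
  # strategy, replica count, and values here are assumptions for the
  # example, not taken from a test): with two replicas and MEAN
  # aggregation, per-replica assignments are reduced once in merge_fn and
  # then applied to the single wrapped variable via one update() call.
  #
  #   with strategy.scope():
  #     v = tf.Variable(0., aggregation=tf.VariableAggregation.MEAN)
  #
  #   def step():
  #     ctx = tf.distribute.get_replica_context()
  #     v.assign(tf.cast(ctx.replica_id_in_sync_group, tf.float32))
  #
  #   strategy.run(step)  # Replica values 0. and 1. reduce to 0.5, and a
  #                       # single update writes 0.5 to the wrapped self._v.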

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def read_value(self):
    return self._v.read_value()

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, a SavedModel with
    # an AggregatingVariable is identical to a SavedModel with a normal
    # variable.
    obj_map, resource_map = self._v._map_resources(save_options)  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._v]
    return obj_map, resource_map
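
  # A minimal sketch of what the two Trackable hooks above enable (the
  # model and export path are hypothetical): because both hooks hand back
  # the wrapped `self._v`, saving a model whose variables are
  # AggregatingVariable wrappers, e.g.
  #
  #   tf.saved_model.save(model, "/tmp/exported_model")
  #
  # yields the same artifacts as saving one built from plain variables.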

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._v + o

  def __radd__(self, o):
    return o + self._v

  def __sub__(self, o):
    return self._v - o

  def __rsub__(self, o):
    return o - self._v

  def __mul__(self, o):
    return self._v * o

  def __rmul__(self, o):
    return o * self._v

  def __truediv__(self, o):
    return self._v / o

  def __rtruediv__(self, o):
    return o / self._v

  def __floordiv__(self, o):
    return self._v // o

  def __rfloordiv__(self, o):
    return o // self._v

  def __mod__(self, o):
    return self._v % o

  def __rmod__(self, o):
    return o % self._v

  def __lt__(self, o):
    return self._v < o

  def __le__(self, o):
    return self._v <= o

  def __gt__(self, o):
    return self._v > o

  def __ge__(self, o):
    return self._v >= o

  def __and__(self, o):
    return self._v & o

  def __rand__(self, o):
    return o & self._v

  def __or__(self, o):
    return self._v | o

  def __ror__(self, o):
    return o | self._v

  def __xor__(self, o):
    return self._v ^ o

  def __rxor__(self, o):
    return o ^ self._v

  def __getitem__(self, o):
    return self._v[o]

  def __pow__(self, o, modulo=None):
    return pow(self._v, o, modulo)

  def __rpow__(self, o):
    return pow(o, self._v)

  def __invert__(self):
    return ~self._v

  def __neg__(self):
    return -self._v

  def __abs__(self):
    return abs(self._v)

  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    return ops.convert_to_tensor(self.get(), dtype=dtype, name=name,
                                 as_ref=as_ref)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(AggregatingVariable,
                                        _tensor_conversion_aggregate)
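
# A hedged end-to-end sketch (the strategy choice and values are
# assumptions; this module is normally exercised indirectly through a
# distribution strategy): under CentralStorageStrategy, variables created
# in scope are wrapped in AggregatingVariable, and the conversion function
# registered above lets the wrapper flow into ordinary ops as a tensor.
#
#   import tensorflow as tf
#
#   strategy = tf.distribute.experimental.CentralStorageStrategy()
#   with strategy.scope():
#     v = tf.Variable(1., aggregation=tf.VariableAggregation.SUM)
#
#   # _dense_var_to_tensor is invoked implicitly here, so the wrapped
#   # value participates in the op like any other tensor.
#   y = tf.square(v + 1.)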