1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various classes representing distributed values for PS.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import weakref 23 24import numpy as np 25 26from tensorflow.python.distribute import distribute_lib 27from tensorflow.python.distribute import distribute_utils 28from tensorflow.python.distribute import distribution_strategy_context as ds_context 29from tensorflow.python.distribute import values 30from tensorflow.python.distribute import values_util 31from tensorflow.python.eager import context 32from tensorflow.python.framework import ops 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import resource_variable_ops 35from tensorflow.python.ops import variable_scope as vs 36from tensorflow.python.training.tracking import base as trackable 37from tensorflow.python.types import core 38 39 40# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy. 41class AggregatingVariable(resource_variable_ops.BaseResourceVariable, 42 core.Tensor): 43 """A wrapper around a variable that aggregates updates across replicas.""" 44 45 def __init__(self, strategy, v, aggregation): 46 self._distribute_strategy = strategy 47 self._v = v 48 # NOTE: We don't use "_distributed_container" here because we don't want 49 # to trigger that code path in regroup(). 50 v._aggregating_container = weakref.ref(self) # pylint: disable=protected-access 51 self._aggregation = aggregation 52 53 def __deepcopy__(self, memo): 54 """Perform a deepcopy of the `AggregatingVariable`. 55 56 Unlike the deepcopy of a regular tf.Variable, this keeps the original 57 strategy and devices of the `AggregatingVariable`. To avoid confusion 58 with the behavior of deepcopy on a regular `Variable` (which does 59 copy into new devices), we only allow a deepcopy of a `AggregatingVariable` 60 within its originating strategy scope. 61 62 Args: 63 memo: The memoization object for `deepcopy`. 64 65 Returns: 66 A deep copy of the current `AggregatingVariable`. 67 68 Raises: 69 RuntimeError: If trying to deepcopy into a different strategy. 70 """ 71 with ds_context.enter_or_assert_strategy(self._distribute_strategy): 72 v = copy.deepcopy(self._v, memo) 73 74 copied_variable = type(self)( 75 strategy=self._distribute_strategy, 76 v=v, 77 aggregation=self._aggregation) 78 79 memo[id(self)] = copied_variable 80 81 return copied_variable 82 83 def get(self): 84 return self._v 85 86 @property 87 def distribute_strategy(self): 88 return self._distribute_strategy 89 90 def __getattr__(self, name): 91 return getattr(self._v, name) 92 93 def _assign_func(self, *args, **kwargs): 94 with ds_context.enter_or_assert_strategy(self._distribute_strategy): 95 f = kwargs.pop("f") 96 if ds_context.in_cross_replica_context(): 97 if distribute_lib.get_update_replica_id() is not None: 98 # We are calling an assign function in an update context. 99 return f(self._v, *args, **kwargs) 100 101 # We are calling an assign function in cross replica context, wrap it in 102 # an update call. 103 return self._distribute_strategy.extended.update( 104 self, f, args=args, kwargs=kwargs) 105 else: 106 replica_context = ds_context.get_replica_context() 107 assert replica_context 108 # We are calling an assign function in replica context. 109 # We reduce the value we want to assign/add/sub. More details about how 110 # we handle the different use cases can be found in the _reduce method. 111 # We call the function with the reduced value. 112 if self._aggregation == vs.VariableAggregation.NONE: 113 raise ValueError( 114 values_util.aggregation_error_msg.format( 115 variable_type="AggregatingVariable")) 116 117 def merge_fn(strategy, 118 value, 119 use_locking=False, 120 name=None, 121 read_value=True): 122 v = values_util.apply_aggregation(strategy, value, self._aggregation, 123 self) 124 if name and isinstance(name, values.PerReplica): 125 name = name.values[0] 126 return strategy.extended.update( 127 self, 128 f, 129 args=(v,), 130 kwargs={ 131 "use_locking": use_locking, 132 "name": name, 133 "read_value": read_value 134 }) 135 return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) 136 137 def assign_sub(self, *args, **kwargs): 138 assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) 139 return self._assign_func(f=assign_sub_fn, *args, **kwargs) 140 141 def assign_add(self, *args, **kwargs): 142 assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) 143 return self._assign_func(f=assign_add_fn, *args, **kwargs) 144 145 def assign(self, *args, **kwargs): 146 assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) 147 return self._assign_func(f=assign_fn, *args, **kwargs) 148 149 @property 150 def initializer(self): 151 return self._v.initializer 152 153 def initialized_value(self): 154 return self._v.initialized_value() 155 156 @property 157 def initial_value(self): 158 return self._v.initial_value 159 160 @property 161 def op(self): 162 return self._v.op 163 164 def value(self): 165 return self._v.value() 166 167 def read_value(self): 168 return self._v.read_value() 169 170 def sparse_read(self, indices, name=None): 171 return self._v.sparse_read(indices, name=name) 172 173 def eval(self, session=None): 174 return self._v.eval(session) 175 176 @property 177 def graph(self): 178 return self._v.graph 179 180 @property 181 def device(self): 182 return self._v.device 183 184 @property 185 def shape(self): 186 return self._v.shape 187 188 @property 189 def aggregation(self): 190 return self._aggregation 191 192 @property 193 def synchronization(self): 194 return self._v.synchronization 195 196 @property 197 def name(self): 198 return self._v.name 199 200 @property 201 def trainable(self): 202 return self._v.trainable 203 204 @property 205 def dtype(self): 206 return self._v.dtype 207 208 # TODO(josh11b): Test saving & restoring. 209 def _gather_saveables_for_checkpoint(self): 210 if isinstance(self._v, CachingVariable): 211 return self._v._gather_saveables_for_checkpoint() # pylint:disable=protected-access 212 return {trackable.VARIABLE_VALUE_KEY: self._v} 213 214 def _map_resources(self, save_options): 215 """For implementing `Trackable`.""" 216 # By delegating this method to the wrapped variable, SavedModel with 217 # AggregatingVariable are identical to SavedModel with normal variables. 218 obj_map, resource_map = self._v._map_resources(save_options) # pylint:disable=protected-access 219 obj_map[self] = obj_map[self._v] 220 return obj_map, resource_map 221 222 # pylint: disable=multiple-statements 223 def __add__(self, o): 224 return self._v + o 225 226 def __radd__(self, o): 227 return o + self._v 228 229 def __sub__(self, o): 230 return self._v - o 231 232 def __rsub__(self, o): 233 return o - self._v 234 235 def __mul__(self, o): 236 return self._v * o 237 238 def __rmul__(self, o): 239 return o * self._v 240 241 def __truediv__(self, o): 242 return self._v / o 243 244 def __rtruediv__(self, o): 245 return o / self._v 246 247 def __floordiv__(self, o): 248 return self._v // o 249 250 def __rfloordiv__(self, o): 251 return o // self._v 252 253 def __mod__(self, o): 254 return self._v % o 255 256 def __rmod__(self, o): 257 return o % self._v 258 259 def __lt__(self, o): 260 return self._v < o 261 262 def __le__(self, o): 263 return self._v <= o 264 265 def __gt__(self, o): 266 return self._v > o 267 268 def __ge__(self, o): 269 return self._v >= o 270 271 def __and__(self, o): 272 return self._v & o 273 274 def __rand__(self, o): 275 return o & self._v 276 277 def __or__(self, o): 278 return self._v | o 279 280 def __ror__(self, o): 281 return o | self._v 282 283 def __xor__(self, o): 284 return self._v ^ o 285 286 def __rxor__(self, o): 287 return o ^ self._v 288 289 def __getitem__(self, o): 290 return self._v[o] 291 292 def __pow__(self, o, modulo=None): 293 return pow(self._v, o, modulo) 294 295 def __rpow__(self, o): 296 return pow(o, self._v) 297 298 def __invert__(self): 299 return ~self._v 300 301 def __neg__(self): 302 return -self._v 303 304 def __abs__(self): 305 return abs(self._v) 306 307 def __div__(self, o): 308 try: 309 return self._v.__div__(o) 310 except AttributeError: 311 # See https://docs.python.org/3/library/constants.html#NotImplemented 312 return NotImplemented 313 314 def __rdiv__(self, o): 315 try: 316 return self._v.__rdiv__(o) 317 except AttributeError: 318 # See https://docs.python.org/3/library/constants.html#NotImplemented 319 return NotImplemented 320 321 def __matmul__(self, o): 322 try: 323 return self._v.__matmul__(o) 324 except AttributeError: 325 # See https://docs.python.org/3/library/constants.html#NotImplemented 326 return NotImplemented 327 328 def __rmatmul__(self, o): 329 try: 330 return self._v.__rmatmul__(o) 331 except AttributeError: 332 # See https://docs.python.org/3/library/constants.html#NotImplemented 333 return NotImplemented 334 335 def __str__(self): 336 return str(self._v) 337 338 def __repr__(self): 339 return repr(self._v) 340 341 def _should_act_as_resource_variable(self): 342 """Pass resource_variable_ops.is_resource_variable check.""" 343 pass 344 345 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 346 return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 347 348 349class CachingVariable(resource_variable_ops.BaseResourceVariable, core.Tensor): 350 """A wrapper around a variable that caches read value locally.""" 351 352 def __init__(self, v): 353 self._v = v 354 self._cache = None 355 self._current_new_cache_scope_count = 0 356 357 def get(self): 358 return self._v 359 360 def __getattr__(self, name): 361 return getattr(self._v, name) 362 363 def read_value(self): 364 if distribute_utils.caching_scope_local.in_caching_scope(): 365 return self.cached_read_value() 366 return self._v.read_value() 367 368 def sparse_read(self, indices, name=None): 369 return self._v.sparse_read(indices, name=name) 370 371 def cached_read_value(self): 372 if (distribute_utils.caching_scope_local.new_cache_scope_count > 373 self._current_new_cache_scope_count): 374 self._current_new_cache_scope_count += 1 375 self._cache = None 376 377 with ops.device("CPU:0"): 378 if self._cache is not None: 379 return self._cache 380 else: 381 self._cache = array_ops.identity(self._v) 382 return self._cache 383 384 def assign_sub(self, *args, **kwargs): 385 return self._v.assign_sub(*args, **kwargs) 386 387 def assign_add(self, *args, **kwargs): 388 return self._v.assign_add(*args, **kwargs) 389 390 def assign(self, *args, **kwargs): 391 return self._v.assign(*args, **kwargs) 392 393 @property 394 def initializer(self): 395 return self._v.initializer 396 397 def initialized_value(self): 398 return self._v.initialized_value() 399 400 @property 401 def initial_value(self): 402 return self._v.initial_value 403 404 @property 405 def op(self): 406 return self._v.op 407 408 def value(self): 409 if distribute_utils.caching_scope_local.in_caching_scope(): 410 return self.cached_read_value() 411 return self._v.value() 412 413 def eval(self, session=None): 414 return self._v.eval(session) 415 416 @property 417 def graph(self): 418 return self._v.graph 419 420 @property 421 def device(self): 422 return self._v.device 423 424 @property 425 def shape(self): 426 return self._v.shape 427 428 @property 429 def synchronization(self): 430 return self._v.synchronization 431 432 @property 433 def name(self): 434 return self._v.name 435 436 @property 437 def trainable(self): 438 return self._v.trainable 439 440 @property 441 def dtype(self): 442 return self._v.dtype 443 444 @property 445 def constraint(self): 446 return self._v.constraint 447 448 def __array__(self): 449 return np.asarray(self.numpy()) 450 451 def __complex__(self): 452 return complex(self.value().numpy()) 453 454 def __int__(self): 455 return int(self.value().numpy()) 456 457 def __float__(self): 458 return float(self.value().numpy()) 459 460 def numpy(self): 461 if context.executing_eagerly(): 462 return self.read_value().numpy() 463 else: 464 raise NotImplementedError( 465 "numpy() is only available when eager execution is enabled.") 466 467 def __str__(self): 468 return str(self._v) 469 470 def __repr__(self): 471 return repr(self._v) 472 473 def _should_act_as_resource_variable(self): 474 """Pass resource_variable_ops.is_resource_variable check.""" 475 pass 476 477 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 478 if distribute_utils.caching_scope_local.in_caching_scope(): 479 return self.cached_read_value() 480 return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=False) # pylint: disable=protected-access 481 482 @classmethod 483 def _overload_overloadable_operators(cls): 484 """Register overloads for all operators.""" 485 for operator in ops.Tensor.OVERLOADABLE_OPERATORS: 486 # Overloading __eq__ or __ne__ does not work as expected. 487 if operator == "__eq__" or operator == "__ne__": 488 continue 489 cls._tensor_overload_operator(operator) 490 491 @classmethod 492 def _tensor_overload_operator(cls, operator): 493 """Delegate an operator overload to `ops.Tensor`.""" 494 tensor_operator = getattr(ops.Tensor, operator) 495 496 def _operator(v, *args, **kwargs): 497 return tensor_operator(v.value(), *args, **kwargs) # pylint: disable=protected-access 498 setattr(cls, operator, _operator) 499 500 def _gather_saveables_for_checkpoint(self): 501 return {trackable.VARIABLE_VALUE_KEY: self._v} 502 503 def _map_resources(self, save_options): 504 """For implementing `Trackable`.""" 505 # By delegating this method to the wrapped variable, SavedModel with 506 # AggregatingVariable are identical to SavedModel with normal variables. 507 obj_map, resource_map = self._v._map_resources(save_options) # pylint:disable=protected-access 508 obj_map[self] = obj_map[self._v] 509 return obj_map, resource_map 510 511 512# Register a conversion function which reads the value of the variable, 513# allowing instances of the class to be used as tensors. 514def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False): 515 return var._dense_var_to_tensor(dtype, name, as_ref) # pylint: disable=protected-access 516 517 518ops.register_tensor_conversion_function(AggregatingVariable, 519 _tensor_conversion_aggregate) 520 521 522# Register a conversion function which reads the value of the variable, 523# allowing instances of the class to be used as tensors. 524def _tensor_conversion_caching(var, dtype=None, name=None, as_ref=False): 525 return var._dense_var_to_tensor(dtype, name, as_ref) # pylint: disable=protected-access 526 527 528ops.register_tensor_conversion_function(CachingVariable, 529 _tensor_conversion_caching) 530 531CachingVariable._overload_overloadable_operators() # pylint: disable=protected-access 532