# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing TPU distributed values.

Note that the tests are in values_test.py.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph()):
    yield
  else:
    with tensor.graph.as_default():
      yield


@contextlib.contextmanager
def _maybe_on_device(var):
  # Add a device scope for packed variables.
  if isinstance(var, packed.PackedVarAndDevice):
    with ops.device(var.device):
      yield
  else:
    yield


def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    handle = var.handle
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_assign_fn(
          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
      with ops.control_dependencies([op]):
        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access

  return assign_fn
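

# For illustration: `_make_raw_assign_fn` wraps a raw resource-variable op so
# that it can be applied to a TPU-distributed variable's handle. A minimal
# sketch of the resulting call pattern (`my_tpu_var` is a hypothetical
# TPU-distributed variable, not defined in this module):
#
#   raw_assign = _make_raw_assign_fn(
#       gen_resource_variable_ops.assign_variable_op)
#   raw_assign(my_tpu_var, 1.0, read_value=False)  # returns the assign op
#   raw_assign(my_tpu_var, 2.0)                    # returns the updated value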


class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the
    # variables correctly since in eager mode different variables can have
    # the same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self._primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          "'{}' not accessible within a TPU context.".format(name))

  def get(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  def _is_mirrored(self):
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context is None or context.executing_eagerly():
      var = self._get_on_device_or_primary()
      if isinstance(var, packed.PackedVarAndDevice):
        return var.on_device_handle()
      else:
        return var.handle
    else:
      is_packed = self._packed_var is not None
      val = self._values
      if is_packed:
        val = [self._packed_var]

      return tpu_context.get_replicated_var_handle(self._handle_id, val,
                                                   self._is_mirrored(),
                                                   is_packed)

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    handle = self.handle
    if getattr(handle, "is_packed", False):
      # Add a device scope for a packed variable handle.
      with ops.device(self._get_on_device_or_primary().device):
        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
    else:
      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)

  def read_value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  def value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).value()
    else:
      return self._read_variable_op()

  def _as_graph_element(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    return values.DistributedVarOp(self._primary.op.name,
                                   self._primary.op.graph,
                                   self._primary.op.traceback,
                                   self._primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()
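

# For illustration (a sketch; `strategy` is assumed to be a
# `tf.distribute.TPUStrategy` and `v` a variable created under its scope):
# outside a TPU context, `v.handle` is the per-device (or primary) handle and
# reads fall back to the parent class, while inside `strategy.run` the mixin
# above routes reads through the replicated handle returned by
# `get_replicated_var_handle`:
#
#   @tf.function
#   def read_fn():
#     return v.read_value()   # uses the replicated handle on TPU
#
#   strategy.run(read_fn)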


class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
  """DistributedVariable subclass for TPUStrategy."""

  def _is_mirrored(self):
    return self._policy._is_mirrored()  # pylint: disable=protected-access

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    return self._policy.assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    return self._policy.assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    return self._policy.assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    return self._policy.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    return self._policy.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    return self._policy.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    return self._policy.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    return self._policy.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    return self._policy.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)
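

# For illustration: `TPUDistributedVariable` holds no update logic of its own;
# each method above either short-circuits to the primary variable when saving a
# non-distributed checkpoint or forwards to the variable's policy. A sketch
# (`v` is a hypothetical `TPUDistributedVariable`, not defined in this module):
#
#   v.assign(3.0)       # -> v._policy.assign(v, 3.0, ...)
#   v.assign_add(1.0)   # -> v._policy.assign_add(v, 1.0, ...)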


class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    raise NotImplementedError

  def _is_mirrored(self):
    return True
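

# For illustration: inside a TPU context, a `TPUMirroredVariable` with
# `VariableAggregation.NONE` writes directly through the raw resource op on its
# handle; any other aggregation is routed through the module-level `assign*`
# helpers below, which dispatch via `var._update` so the strategy can apply
# cross-replica update semantics. A sketch (`v` is a hypothetical
# `TPUMirroredVariable`, not defined in this module):
#
#   v.assign(1.0)  # raw assign if aggregation == NONE and inside a TPU context
#   v.assign(1.0)  # otherwise routed through assign() -> v._update(...)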


class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self, *args, **kwargs)

  def _is_mirrored(self):
    return False


# Common assign functions shared between OnWrite and Mirrored variables.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  assign_sub_fn = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_sub_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_add(var, value, use_locking=False, name=None, read_value=True):
  assign_add_fn = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_add_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


class TPUOnWritePolicy(values.OnWritePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.
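
  For illustration (a sketch; `strategy` below is assumed to be a
  `tf.distribute.TPUStrategy` created elsewhere), a variable built under the
  strategy scope with the default `AUTO` synchronization is backed by this
  policy:

    with strategy.scope():
      v = tf.Variable(1.0)  # `synchronization=AUTO` resolves to `ON_WRITE`.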
  """

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          var,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError

  def _is_mirrored(self):
    return True


class TPUOnReadPolicy(values.OnReadPolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in
  `tf.distribute` scope.
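
  For illustration (a sketch; `strategy` below is assumed to be a
  `tf.distribute.TPUStrategy` created elsewhere), a metric-style accumulator
  that is summed across replicas when read would use this policy:

    with strategy.scope():
      total = tf.Variable(
          0.0,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.SUM)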
  """

  def assign_sub(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
                                                            **kwargs)

  def assign_add(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(var, *args,
                                                            **kwargs)

  def assign(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          var, *args, **kwargs)

  def _is_mirrored(self):
    return False

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError