# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing TPU distributed values.

Note that the tests are in values_test.py.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph()):
    yield
  else:
    with tensor.graph.as_default():
      yield


@contextlib.contextmanager
def _maybe_on_device(var):
  # Add a device scope for packed variables.
  if isinstance(var, packed.PackedVarAndDevice):
    with ops.device(var.device):
      yield
  else:
    yield


def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    handle = var.handle
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_assign_fn(
          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
      with ops.control_dependencies([op]):
        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access

  return assign_fn


_scatter_error_msg = ("{op_name} is only supported for distributed "
                      "variables (variables created within a "
                      "`tf.distribute.Strategy` scope) with NONE "
                      "aggregation, got: {aggregation}.")


def _make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Wrap `raw_scatter_xxx_fn` so that it can be called w/ and w/o packed handle."""

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    handle = var.handle
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_scatter_xxx_fn(
          handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, var.dtype),
          name=name)
      with ops.control_dependencies([op]):
        return var._read_variable_op()  # pylint: disable=protected-access

  return scatter_xxx_fn


class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the
    # variables correctly since in eager mode different variables can have the
    # same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self._primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          f"`TPUVariableMixin.{name}` not accessible within a TPU context.")

  def get(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  def _is_mirrored(self):
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context is None or context.executing_eagerly():
      var = self._get_on_device_or_primary()
      if isinstance(var, packed.PackedVarAndDevice):
        return var.on_device_handle()
      else:
        return var.handle
    else:
      is_packed = self._packed_var is not None
      val = self._values
      if is_packed:
        val = [self._packed_var]

      return tpu_context.get_replicated_var_handle(self._handle_id, val,
                                                   self._is_mirrored(),
                                                   is_packed)

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    handle = self.handle
    if getattr(handle, "is_packed", False):
      # Add a device scope for a packed variable handle.
      with ops.device(self._get_on_device_or_primary().device):
        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
    else:
      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)

  def read_value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  def value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).value()
    else:
      return self._read_variable_op()

  def _as_graph_element(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    return values.DistributedVarOp(self._primary.op.name,
                                   self._primary.op.graph,
                                   self._primary.op.traceback,
                                   self._primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()


class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
  """DistributedVariable subclass for TPUStrategy."""

  def _is_mirrored(self):
    return self._policy._is_mirrored()  # pylint: disable=protected-access

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    return self._policy.assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    return self._policy.assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    return self._policy.assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    return self._policy.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    return self._policy.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    return self._policy.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    return self._policy.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    return self._policy.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    return self._policy.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)


class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    raise NotImplementedError

  def _is_mirrored(self):
    return True


class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self, *args, **kwargs)

  def _is_mirrored(self):
    return False


# Common methods shared by OnWrite and Mirrored variables.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  assign_sub_fn = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_sub_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_add(var, value, use_locking=False, name=None, read_value=True):
  assign_add_fn = _make_raw_assign_fn(
      gen_resource_variable_ops.assign_add_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


class TPUOnWritePolicy(values.OnWritePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.
  """

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          var,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def _scatter_xxx(self,
                   raw_scatter_xxx_fn,
                   op_name,
                   var,
                   sparse_delta,
                   use_locking=False,
                   name=None):
    scatter_xxx_fn = _make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)
    if tpu_util.enclosing_tpu_context():
      if self._aggregation != variable_scope.VariableAggregation.NONE:
        raise NotImplementedError(
            _scatter_error_msg.format(
                op_name=op_name, aggregation=self._aggregation))
      return scatter_xxx_fn(
          var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
    else:
      return var._update(  # pylint: disable=protected-access
          update_fn=scatter_xxx_fn,
          value=sparse_delta,
          use_locking=use_locking,
          name=name)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_sub,
                             "scatter_sub", var, sparse_delta, use_locking,
                             name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_add,
                             "scatter_add", var, sparse_delta, use_locking,
                             name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_max,
                             "scatter_max", var, sparse_delta, use_locking,
                             name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_min,
                             "scatter_min", var, sparse_delta, use_locking,
                             name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_mul,
                             "scatter_mul", var, sparse_delta, use_locking,
                             name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_div,
                             "scatter_div", var, sparse_delta, use_locking,
                             name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_update,
                             "scatter_update", var, sparse_delta, use_locking,
                             name)

  def _is_mirrored(self):
    return True


class TPUOnReadPolicy(values.OnReadPolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in
  `tf.distribute` scope.
  """

  def assign_sub(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
                                                            **kwargs)

  def assign_add(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(var, *args,
                                                            **kwargs)

  def assign(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          var, *args, **kwargs)

  def _is_mirrored(self):
    return False

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError
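

# Illustrative usage sketch (comments only, not executed): the classes above
# are normally created on the user's behalf by `tf.distribute.TPUStrategy`
# rather than instantiated directly. This is a minimal sketch assuming a
# reachable TPU worker and the public `tf.distribute` / `tf.tpu` APIs; names
# such as `resolver` and `add_one` are hypothetical.
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.config.experimental_connect_to_cluster(resolver)
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)
#
#   with strategy.scope():
#     v = tf.Variable(1.0)  # created as a TPU distributed variable
#
#   @tf.function
#   def add_one():
#     v.assign_add(1.0)  # routed through the TPU assign paths defined above
#
#   strategy.run(add_one)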