# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains AutoCastVariable, a variable which automatically casts itself."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values as ps_distribute_values
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.types import core


# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or
# None if AutoCastVariables should not be cast.
_autocast_dtype = threading.local()


def numpy_text(tensor, is_repr=False):
  """Human readable representation of a tensor's numpy value."""
  if tensor.dtype.is_numpy_compatible:
    # pylint: disable=protected-access
    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    # pylint: enable=protected-access
  else:
    text = '<unprintable>'
  if '\n' in text:
    text = '\n' + text
  return text


class AutoCastVariable(variables.Variable, core.Tensor):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point `tf.Variable`. It emulates the variable
  interface and delegates to the wrapped variable, but it additionally will cast
  the wrapped variable under an `enable_auto_cast_variables(dtype)` context
  manager.

  For example:

  >>> v = tf.Variable(1.0, dtype=tf.float32)
  >>> v = AutoCastVariable(v)
  >>> tf.identity(v).dtype
  tf.float32
  >>> with enable_auto_cast_variables(tf.float16):
  ...   tf.identity(v).dtype
  tf.float16

  The purpose of this class is to allow Keras layers to create variables in
  float32, and automatically cast them to float16 or bfloat16 when the layer is
  called.
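
  Note that `dtype` always reports the dtype of the wrapped variable, even
  under the context manager; only reads of the value are cast:

  >>> with enable_auto_cast_variables(tf.float16):
  ...   v.dtype
  tf.float32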
  """

  def __init__(self, variable):
    """Creates an AutoCastVariable instance.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a floating-point resource variable.
    """
    if not isinstance(variable, variables.Variable):
      raise ValueError('variable must be of type tf.Variable, but got: '
                       '%s' % variable)
    if not variable.dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable
    # 'delegate' means AutoCastVariable.op returns self._variable.op, which
    # will raise an AttributeError in Eager (as intended). If set to any other
    # value, AutoCastVariable.op returns that value instead, which is used to
    # set the op attribute in AutoCastVariable.assign().
    self._op = 'delegate'

  def _should_cast(self):
    """Returns True if this variable should be cast when accessed."""
    autocast_dtype = getattr(_autocast_dtype, 'dtype', None)
    return autocast_dtype is not None and self.dtype != autocast_dtype

  @property
  def dtype(self):
    """The dtype of the underlying variable, before any casts are done."""
    return self._variable.dtype

  @property
  def true_dtype(self):
    """Deprecated alias of `dtype`."""
    return self._variable.dtype

  @property
  def _cast_dtype(self):
    dtype = getattr(_autocast_dtype, 'dtype', None)
    return dtype or self._variable.dtype

  def value(self):
    val = self._variable.value()
    if not self._should_cast():
      return val
    return math_ops.cast(val, self._cast_dtype)

  def read_value(self):
    val = self._variable.read_value()
    return math_ops.cast(val, self._cast_dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def gather_nd(self, indices, name=None):
    """Gather slices of the variable into a Tensor."""
    val = self._variable.gather_nd(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def __getattr__(self, name):
    return getattr(self._variable, name)
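
  # Note: any attribute not explicitly defined on AutoCastVariable (e.g.
  # `numpy()` or `_in_graph_mode`) falls through to the wrapped variable via
  # `__getattr__` above.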

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor."""
    if as_ref:
      # This ValueError should not occur in practice since it is impossible to
      # pass as_ref=True using public APIs.
      raise ValueError('Cannot convert AutoCastVariable to a tensor if '
                       'as_ref=True is passed to convert_to_tensor')
    if not self._should_cast():
      return ops.convert_to_tensor_v2_with_dispatch(self._variable, dtype=dtype,
                                                    name=name)
    if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
      raise ValueError(
          'Incompatible type conversion requested to type {!r} for '
          'AutoCastVariable which is cast to type {!r}'.format(
              dtype.name, self._cast_dtype.name))
    val = ops.convert_to_tensor_v2_with_dispatch(
        self._variable, dtype=self._variable.dtype, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, '
                  'numpy={np_repr}>')
      return repr_str.format(
          v=self, np_repr=numpy_text(self.read_value(), is_repr=True))
    else:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>')
      return repr_str.format(v=self)

  # Method delegations: We delegate the following methods to self._variable.
  # Each of these methods simply calls the same method on self._variable. The
  # base Variable raises NotImplementedError for most of these, so we must
  # override them.
  #
  # We do not define the following methods from Variable for the following
  # reasons:
  #   * 'count_up_to': This method only applies to int variables, which cannot
  #     be wrapped with an AutoCastVariable.
  #   * 'ref': Instead we inherit the definition from Variable.
  #     If we defined and delegated to Variable, the ref of an AutoCastVariable
  #     would be the same as the ref of the underlying variable, which would be
  #     strange as they are different Python objects.

  def set_shape(self, shape):
    return self._variable.set_shape(shape)

  @property
  def trainable(self):
    return self._variable.trainable

  @property
  def synchronization(self):
    return self._variable.synchronization

  @property
  def aggregation(self):
    return self._variable.aggregation

  def eval(self, session=None):
    return self._variable.eval(session)

  def initialized_value(self):
    return self._variable.initialized_value()

  @property
  def initial_value(self):
    return self._variable.initial_value

  @property
  def constraint(self):
    return self._variable.constraint
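
  # The `_apply_assign_update` and `_apply_update` helpers below wrap the
  # update methods of self._variable. In eager mode, assign updates return a
  # new AutoCastVariable (or the update op if read_value=False) and scatter
  # updates return `self`; in graph mode, the returned variable is re-wrapped
  # in an AutoCastVariable when possible. This matches the behavior of the
  # corresponding tf.Variable methods.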

  def _apply_assign_update(self,
                           update_fn,
                           value,
                           use_locking=None,
                           name=None,
                           read_value=True):
    # TODO(b/146181571): This logic can be simplified once
    # DistributedVariable.assign returns a DistributedVariable. Currently for
    # MirroredStrategy, it returns a Mirrored value.
    if ops.executing_eagerly_outside_functions():
      assign_op = update_fn(value, use_locking, name, False)
      if read_value:
        # We create a new AutoCastVariable with the same underlying
        # tf.Variable. The new AutoCastVariable is identical except the 'op'
        # attribute is defined. This matches the behavior of tf.Variable.assign.
        var = create_autocast_variable(self._variable)
        var._op = assign_op  # pylint:disable=protected-access
        return var
      return assign_op

    # Fallback to wrapping the returned variable in graph mode if possible
    assign_var = update_fn(value, use_locking, name, read_value)
    if read_value and resource_variable_ops.is_resource_variable(assign_var):
      return create_autocast_variable(assign_var)
    return assign_var

  def _apply_update(self, update_fn, *args, **kwargs):
    update_var = update_fn(*args, **kwargs)
    if ops.executing_eagerly_outside_functions():
      return self

    # Fallback to wrapping the returned variable in graph mode if possible
    if resource_variable_ops.is_resource_variable(update_var):
      return create_autocast_variable(update_var)
    return update_var

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign, value, use_locking,
                                     name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_add, delta,
                                     use_locking, name, read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_sub, delta,
                                     use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_sub, sparse_delta,
                              use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_add, sparse_delta,
                              use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_max, sparse_delta,
                              use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_min, sparse_delta,
                              use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_mul, sparse_delta,
                              use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_div, sparse_delta,
                              use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_update, sparse_delta,
                              use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.batch_scatter_update,
                              sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_sub, indices, updates,
                              name)

  def scatter_nd_add(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_add, indices, updates,
                              name)

  def scatter_nd_update(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_update, indices,
                              updates, name)

  def load(self, value, session=None):
    return self._variable.load(value, session)

  @property
  def name(self):
    return self._variable.name

  @property
  def _shared_name(self):
    return self._variable._shared_name  # pylint:disable=protected-access

  @property
  def initializer(self):
    return self._variable.initializer

  @property
  def device(self):
    return self._variable.device

  @property
  def op(self):
    if self._op == 'delegate':
      return self._variable.op
    return self._op

  def _as_graph_element(self):
    graph_element = self._variable._as_graph_element()  # pylint:disable=protected-access
    if graph_element is None:
      return self._op
    return graph_element

  @property
  def graph(self):
    return self._variable.graph

  @property
  def shape(self):
    return self._variable.shape

  def get_shape(self):
    return self._variable.get_shape()

  def _gather_saveables_for_checkpoint(self):
    # By delegating this method to the wrapped variable, checkpoints with
    # AutoCastVariables are identical to checkpoints with normal variables.
    # Therefore models checkpointed with AutoCastVariables can be restored on
    # models with normal variables, and vice versa.
    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access

  def _map_resources(self, save_options):
    # By delegating this method to the wrapped variable, SavedModels with
    # AutoCastVariables are identical to SavedModels with normal variables.
    obj_map, resource_map = self._variable._map_resources(save_options)  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._variable]
    return obj_map, resource_map

  # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
  # to_proto().
  def to_proto(self, export_scope=None):
    return self._variable.to_proto(export_scope)

  def from_proto(self, variable_def, import_scope=None):
    return self._variable.from_proto(variable_def, import_scope)

  # Delegate the private attributes _handle_name and _initializer_op to
  # self._variable. SavedModel sets these attributes when loading a model. For
  # example, it sets _handle_name here:
  # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
  # We need to expose these attributes on AutoCastVariable as well for
  # SavedModel to work properly.
  # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
  # private attributes is hacky and difficult to maintain.
  @property
  def _handle_name(self):
    return self._variable._handle_name  # pylint: disable=protected-access

  @_handle_name.setter
  def _handle_name(self, handle_name):
    self._variable._handle_name = handle_name  # pylint: disable=protected-access

  @property
  def _initializer_op(self):
    return self._variable._initializer_op  # pylint: disable=protected-access

  @_initializer_op.setter
  def _initializer_op(self, initializer_op):
    self._variable._initializer_op = initializer_op  # pylint: disable=protected-access

  # Operator overloads:
  # Note we only overload operators that support floating-point types, as
  # non-float variables cannot be wrapped with an AutoCastVariable.
  # Also note: We call read_value() instead of value(), because value() causes
  # gradients not to work properly when TPUStrategy is used: b/143380936
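  #
  # For example, under `enable_auto_cast_variables(tf.float16)`, `v + 1.`
  # reads the variable as a float16 tensor before adding, so the result is
  # float16 (an illustrative sketch; assumes `v` wraps a float32 variable):
  #
  #   with enable_auto_cast_variables(tf.float16):
  #     (v + 1.).dtype  # tf.float16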

  def __add__(self, o):
    return self.read_value() + o

  def __radd__(self, o):
    return o + self.read_value()

  def __sub__(self, o):
    return self.read_value() - o

  def __rsub__(self, o):
    return o - self.read_value()

  def __mul__(self, o):
    return self.read_value() * o

  def __rmul__(self, o):
    return o * self.read_value()

  def __truediv__(self, o):
    return self.read_value() / o

  def __rtruediv__(self, o):
    return o / self.read_value()

  def __floordiv__(self, o):
    return self.read_value() // o

  def __rfloordiv__(self, o):
    return o // self.read_value()

  def __mod__(self, o):
    return self.read_value() % o

  def __rmod__(self, o):
    return o % self.read_value()

  def __lt__(self, o):
    return self.read_value() < o

  def __le__(self, o):
    return self.read_value() <= o

  def __gt__(self, o):
    return self.read_value() > o

  def __ge__(self, o):
    return self.read_value() >= o

  def __getitem__(self, o):
    return self.read_value()[o]

  def __pow__(self, o, modulo=None):
    return pow(self.read_value(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self.read_value())

  def __neg__(self):
    return -self.read_value()

  def __abs__(self):
    return abs(self.read_value())

  def __div__(self, o):
    try:
      return self.read_value().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.read_value().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.read_value().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.read_value().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented


ops.register_tensor_conversion_function(AutoCastVariable,
                                        AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access


def create_autocast_variable(variable):
  """Creates an AutoCastVariable that wraps another variable.

  This typically just returns `AutoCastVariable(variable)`. But if the variable
  is a DistributedVariable or one of its subclasses, we instead dynamically
  create a class that subclasses from both AutoCastVariable and
  variable.__class__. This is so the returned variable will still pass
  `isinstance(variable, variable.__class__)`, which is required for
  DistributedVariables and their subclasses to work properly.

  Args:
    variable: A floating-point resource variable to wrap.

  Returns:
    An AutoCastVariable that wraps the variable.
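
  For example, wrapping a `MirroredVariable` returns an object that is both an
  `AutoCastVariable` and a `MirroredVariable`, so `isinstance` checks against
  either class succeed.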
  """
  if (not distribute_utils.is_distributed_variable(variable) and
      not isinstance(variable, ps_distribute_values.AggregatingVariable)):
    return AutoCastVariable(variable)

  class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
    """An AutoCastVariable that also subclasses from variable.__class__.

    variable.__class__ is either a DistributedVariable or an
    AggregatingVariable.
    """

    def __repr__(self):
      if issubclass(ps_distribute_values.AggregatingVariable,
                    variable.__class__):
        # AggregatingVariable's __repr__ simply calls super.__repr__. So we do
        # the same here for consistency, which calls AutoCastVariable.__repr__.
        return super(AutoCastDistributedVariable, self).__repr__()

      # pylint: disable=missing-format-attribute
      return ('<AutoCastDistributedVariable dtype={v.dtype.name} '
              'dtype_to_cast_to={v._cast_dtype.name} '
              'inner_variable={v._variable}>'
             ).format(v=self)
      # pylint: enable=missing-format-attribute

  return AutoCastDistributedVariable(variable)


class enable_auto_cast_variables(object):  # pylint:disable=invalid-name
  """Context manager which enables the autocasting of `AutoCastVariable`s.

  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
  """

  __slots__ = ['_dtype', '_prev_dtype']

  def __init__(self, dtype):
    if dtype and not dtype.is_floating:
      dtype = None
    self._dtype = dtype

  def __enter__(self):
    self._prev_dtype = getattr(_autocast_dtype, 'dtype', None)
    _autocast_dtype.dtype = self._dtype

  def __exit__(self, type_arg, value_arg, traceback_arg):
    _autocast_dtype.dtype = self._prev_dtype
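

# Example usage of `enable_auto_cast_variables` (an illustrative sketch):
#
#   v = create_autocast_variable(tf.Variable(1.0, dtype=tf.float32))
#   with enable_auto_cast_variables(tf.float16):
#     v.read_value().dtype  # tf.float16: reads are cast inside the context
#   v.read_value().dtype  # tf.float32: casting is disabled again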