# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains AutoCastVariable, a variable which automatically casts itself."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import values as distribute_values
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


# TODO(reedwm): Make checkpointable?
class AutoCastVariable(object):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point tf.Variable. It emulates the variable
  interface and delegates to the wrapped variable, but it will additionally
  cast the wrapped variable under a `Graph._enable_variable_auto_cast(dtype)`
  context manager.

  For example:

  ```
  v = tf.Variable(1.0, dtype=tf.float32)
  v = AutoCastVariable(v)
  print(tf.identity(v).dtype)  # tf.float32
  with ops.get_default_graph()._enable_variable_auto_cast(tf.float16):
    print(tf.identity(v).dtype)  # tf.float16, as v will cast itself to float16
    print(v.dtype)  # tf.float16, as v.dtype also changes under the ctx manager
  ```

  The purpose of this class is to allow Keras layers to create variables in
  float32, and automatically cast them to float16 or bfloat16 when the layer is
  called.
  """

  def __init__(self, variable):
    """Creates an AutoCastVariable instance.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a floating-point resource variable.
    """
    if not resource_variable_ops.is_resource_variable(variable):
      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
                       '%s' % variable)
    if not variable.dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable

  @property
  def name(self):
    return self._variable.name

  def _should_cast(self):
    """Returns True if this variable should be cast when accessed."""
    g = ops.get_default_graph()
    # pylint:disable=protected-access
    return (g._auto_cast_variable_read_dtype is not None and
            self.true_dtype != g._auto_cast_variable_read_dtype)
    # pylint:enable=protected-access

  @property
  def dtype(self):
    """The dtype this variable will be cast to when read."""
    if self._should_cast():
      return ops.get_default_graph()._auto_cast_variable_read_dtype  # pylint:disable=protected-access
    else:
      return self._variable.dtype

  @property
  def true_dtype(self):
    """The dtype of the underlying variable, before any casts are done."""
    return self._variable.dtype

  def value(self):
    val = self._variable.value()
    if not self._should_cast():
      return val
    # We colocate_with(None) to ignore the existing device constraints, so that
    # the cast is always done on the variable's device.
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(val.device):
        return math_ops.cast(val, self.dtype)

  def read_value(self):
    val = self._variable.read_value()
    if not self._should_cast():
      return val
    return math_ops.cast(val, self.dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    if not self._should_cast():
      return val
    return math_ops.cast(val, self.dtype)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._variable.assign(
        value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._variable.assign_add(
        delta, use_locking=use_locking, name=name, read_value=read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._variable.assign_sub(
        delta, use_locking=use_locking, name=name, read_value=read_value)

  # TODO(reedwm): Support assigning variables with tf.assign(),
  # var.scatter_add, etc.

  def __getattr__(self, name):
    return getattr(self._variable, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor."""
    if not self._should_cast():
      return ops.internal_convert_to_tensor(self._variable, dtype, name,
                                            as_ref)
    # TODO(reedwm): Support as_ref?
    assert not as_ref
    if dtype is not None and not dtype.is_compatible_with(self.dtype):
      raise ValueError(
          'Incompatible type conversion requested to type {!r} for variable '
          'of type {!r}'.format(dtype.name, self.dtype.name))
    val = ops.internal_convert_to_tensor(self._variable,
                                         self._variable.dtype, name,
                                         as_ref=False)
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(val.device):
        return math_ops.cast(val, self.dtype)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  # TODO(reedwm): Define operator overloads.


ops.register_tensor_conversion_function(
    AutoCastVariable, AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
ops.register_dense_tensor_like_type(AutoCastVariable)


# We subclass DistributedVariable so that instances pass
# isinstance(..., DistributedVariable) checks when wrapping a
# DistributedVariable.
# TODO(reedwm): We should not wrap DistributedVariable, but instead have
# DistributedVariable wrap AutoCastVariable. Subclassing DistributedVariable is
# messy, because we do not fully implement its interface.
class AutoCastDistributedVariable(AutoCastVariable,
                                  distribute_values.DistributedVariable):
  """Version of AutoCastVariable that subclasses DistributedVariable."""

  def __init__(self, variable):
    if not isinstance(variable, distribute_values.DistributedValues):
      raise ValueError('variable must be of type DistributedValues, '
                       'but got: %s' % variable)
    super(AutoCastDistributedVariable, self).__init__(variable)
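

# The block below is a minimal usage sketch, kept as a comment so importing
# this module has no side effects. It is illustrative only: it assumes the
# private `Graph._enable_variable_auto_cast` context manager described in the
# AutoCastVariable docstring, whose exact entry point may differ across TF
# versions.
#
#   from tensorflow.python.framework import dtypes
#
#   var = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32)
#   auto_var = AutoCastVariable(var)
#   t = ops.convert_to_tensor(auto_var)        # t.dtype == dtypes.float32
#   with ops.get_default_graph()._enable_variable_auto_cast(dtypes.float16):
#     t16 = ops.convert_to_tensor(auto_var)    # cast on read: dtypes.float16
#     assert auto_var.dtype == dtypes.float16  # dtype follows the read dtype
#   assert auto_var.true_dtype == dtypes.float32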