# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to override the gradient for a function."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


def copy_handle_data(source_t, target_t):
  """Copies HandleData for variant and resource type tensors if available.

  The CppShapeInferenceResult::HandleData proto contains information about the
  shapes and types of the element tensors of resource/variant type tensors.
  We need to copy this across function boundaries, i.e., when capturing a
  placeholder or when returning a function tensor as output. If we don't do
  this, the element tensors will have unknown shapes, e.g., if a TensorList
  variant tensor is captured as a placeholder, elements popped from that list
  would have unknown shape.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  if (target_t.dtype == dtypes.resource or
      target_t.dtype == dtypes.variant):
    if isinstance(source_t, ops.EagerTensor):
      handle_data = source_t._handle_data  # pylint: disable=protected-access
    else:
      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
    if (handle_data is not None
        and handle_data.is_set
        and handle_data.shape_and_type):
      # pylint: disable=protected-access
      pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
                                              target_t._as_tf_output(),
                                              handle_data.SerializeToString())
      # pylint: enable=protected-access
      # Ensure that shapes and dtypes are propagated.
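      # Flatten the HandleData proto into parallel lists of element shapes,
      # ranks and dtypes (rank -1 / shape None stand in for unknown rank),
      # which is the form the C API wrapper below consumes.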
      shapes, types = zip(*[(pair.shape, pair.dtype)
                            for pair in handle_data.shape_and_type])
      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
      shapes = [[d.size for d in s.dim]
                if not s.unknown_rank else None for s in shapes]
      pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
          target_t._op._graph._c_graph,  # pylint: disable=protected-access
          target_t._as_tf_output(),  # pylint: disable=protected-access
          shapes, ranks, types)


@tf_export("custom_gradient")
def custom_gradient(f):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations. This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at
  x=100 is NaN. For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x)  # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  See also `tf.RegisterGradient`, which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient`, on the other hand,
  allows for fine grained control over the gradient computation of a sequence
  of operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of `Tensor` inputs to the function.
       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
         TensorFlow operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
         to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
         `Tensor`s the same size as `y` holding the initial value gradients for
         each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
         vector-valued function `f`'s derivatives should be its Jacobian matrix
         `J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
         which defines how `J` will transform a vector `grad_ys` when
         left-multiplied with it (`grad_ys * J`). This functional
         representation of a matrix is convenient to use for chain-rule
         calculation (in e.g. the back-propagation algorithm).

       If `f` uses `Variable`s (that are not part of the
       inputs), i.e. through `get_variable`, then `grad_fn` should have
       signature `g(*grad_ys, variables=None)`, where `variables` is a list of
       the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
       `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
       with the derivatives of `Tensor`s in `y` with respect to the variables
       (that is, grad_vars has one Tensor per variable in variables).
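
       A sketch of this variant (the function `linear` and the variable
       `scale` below are illustrative only):

       ```python
       @tf.custom_gradient
       def linear(x):
         scale = tf.get_variable("scale", shape=(), use_resource=True)

         def grad(dy, variables=None):
           # dy/dx = scale; dy/dscale = x, summed because the scalar `scale`
           # is broadcast against `x`.
           return dy * scale, [tf.reduce_sum(dy * x)]

         return x * scale, grad
       ```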

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.executing_eagerly():
      return _eager_mode_decorator(f, *args, **kwargs)
    else:
      return _graph_mode_decorator(f, *args, **kwargs)

  return tf_decorator.make_decorator(f, decorated)


def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keyword "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
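    # The forward pass wraps `flat_result + args + variables` in a single
    # IdentityN, so this gradient must line up with that input list: `None`s
    # for the original outputs, then the input gradients, then the variable
    # gradients.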
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])


def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode."""
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      not grad_argspec.varkw):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return %d gradients, but "
          "returned %d instead." % (arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)