1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Decorator to overrides the gradient for a function.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20from tensorflow.python.client import pywrap_tf_session 21from tensorflow.python.eager import backprop 22from tensorflow.python.eager import context 23from tensorflow.python.eager import tape as tape_lib 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import gen_array_ops 28from tensorflow.python.ops import op_selector 29from tensorflow.python.ops import resource_variable_ops 30from tensorflow.python.ops import variable_scope 31from tensorflow.python.platform import tf_logging as logging 32from tensorflow.python.util import nest 33from tensorflow.python.util import tf_decorator 34from tensorflow.python.util import tf_inspect 35from tensorflow.python.util.tf_export import tf_export 36 37 38VAR_OP_TYPES = [ 39 "VariableV2", 40 "VarHandleOp", 41] 42 43 44def copy_handle_data(source_t, target_t): 45 """Copies HandleData for variant and resource type tensors if available. 46 47 The CppShapeInferenceResult::HandleData proto contains information about the 48 shapes and types of the element tensors of resource/variant type tensors. 49 We need to copy this across function boundaries, i.e., when capturing a 50 placeholder or when returning a function tensor as output. If we don't do this 51 the element tensors will have unknown shapes, e.g., if a TensorList variant 52 tensor is captured as a placeholder, elements popped from that list would have 53 unknown shape. 54 55 Args: 56 source_t: The tensor to copy HandleData from. 57 target_t: The tensor to copy HandleData to. 58 """ 59 if (target_t.dtype == dtypes.resource or 60 target_t.dtype == dtypes.variant): 61 if isinstance(source_t, ops.EagerTensor): 62 handle_data = source_t._handle_data # pylint: disable=protected-access 63 else: 64 handle_data = resource_variable_ops.get_resource_handle_data(source_t) 65 if (handle_data is not None 66 and handle_data.is_set 67 and handle_data.shape_and_type): 68 # pylint: disable=protected-access 69 pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph, 70 target_t._as_tf_output(), 71 handle_data.SerializeToString()) 72 # pylint: enable=protected-access 73 # Ensure that shapes and dtypes are propagated. 74 shapes, types = zip(*[(pair.shape, pair.dtype) 75 for pair in handle_data.shape_and_type]) 76 ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] 77 shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension 78 if not s.unknown_rank else None for s in shapes] 79 pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper( 80 target_t._op._graph._c_graph, # pylint: disable=protected-access 81 target_t._as_tf_output(), # pylint: disable=protected-access 82 shapes, 83 ranks, 84 types) 85 86 87@tf_export("custom_gradient") 88def custom_gradient(f=None): 89 """Decorator to define a function with a custom gradient. 90 91 This decorator allows fine grained control over the gradients of a sequence 92 for operations. This may be useful for multiple reasons, including providing 93 a more efficient or numerically stable gradient for a sequence of operations. 94 95 For example, consider the following function that commonly occurs in the 96 computation of cross entropy and log likelihoods: 97 98 ```python 99 def log1pexp(x): 100 return tf.math.log(1 + tf.exp(x)) 101 ``` 102 103 Due to numerical instability, the gradient this function evaluated at x=100 is 104 NaN. For example: 105 106 ```python 107 x = tf.constant(100.) 108 y = log1pexp(x) 109 dy = tf.gradients(y, x) # Will be NaN when evaluated. 110 ``` 111 112 The gradient expression can be analytically simplified to provide numerical 113 stability: 114 115 ```python 116 @tf.custom_gradient 117 def log1pexp(x): 118 e = tf.exp(x) 119 def grad(dy): 120 return dy * (1 - 1 / (1 + e)) 121 return tf.math.log(1 + e), grad 122 ``` 123 124 With this definition, the gradient at x=100 will be correctly evaluated as 125 1.0. 126 127 Nesting custom gradients can lead to unintuitive results. The default 128 behavior does not correspond to n-th order derivatives. For example 129 130 ```python 131 @tf.custom_gradient 132 def op(x): 133 y = op1(x) 134 @tf.custom_gradient 135 def grad_fn(dy): 136 gdy = op2(x, y, dy) 137 def grad_grad_fn(ddy): # Not the 2nd order gradient of op w.r.t. x. 138 return op3(x, y, dy, ddy) 139 return gdy, grad_grad_fn 140 return y, grad_fn 141 ``` 142 143 The function `grad_grad_fn` will be calculating the first order gradient 144 of `grad_fn` with respect to `dy`, which is used to generate forward-mode 145 gradient graphs from backward-mode gradient graphs, but is not the same as 146 the second order gradient of `op` with respect to `x`. 147 148 Instead, wrap nested `@tf.custom_gradients` in another function: 149 150 ```python 151 @tf.custom_gradient 152 def op_with_fused_backprop(x): 153 y, x_grad = fused_op(x) 154 def first_order_gradient(dy): 155 @tf.custom_gradient 156 def first_order_custom(unused_x): 157 def second_order_and_transpose(ddy): 158 return second_order_for_x(...), gradient_wrt_dy(...) 159 return x_grad, second_order_and_transpose 160 return dy * first_order_custom(x) 161 return y, first_order_gradient 162 ``` 163 164 Additional arguments to the inner `@tf.custom_gradient`-decorated function 165 control the expected return values of the innermost function. 166 167 See also `tf.RegisterGradient` which registers a gradient function for a 168 primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows 169 for fine grained control over the gradient computation of a sequence of 170 operations. 171 172 Note that if the decorated function uses `Variable`s, the enclosing variable 173 scope must be using `ResourceVariable`s. 174 175 Args: 176 f: function `f(*x)` that returns a tuple `(y, grad_fn)` where: 177 - `x` is a sequence of `Tensor` inputs to the function. 178 - `y` is a `Tensor` or sequence of `Tensor` outputs of applying 179 TensorFlow operations in `f` to `x`. 180 - `grad_fn` is a function with the signature `g(*grad_ys)` which returns 181 a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect 182 to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of 183 `Tensor`s the same size as `y` holding the initial value gradients for 184 each `Tensor` in `y`. In a pure mathematical sense, a vector-argument 185 vector-valued function `f`'s derivatives should be its Jacobian matrix 186 `J`. Here we are expressing the Jacobian `J` as a function `grad_fn` 187 which defines how `J` will transform a vector `grad_ys` when 188 left-multiplied with it (`grad_ys * J`). This functional representation 189 of a matrix is convenient to use for chain-rule calculation 190 (in e.g. the back-propagation algorithm). 191 192 If `f` uses `Variable`s (that are not part of the 193 inputs), i.e. through `get_variable`, then `grad_fn` should have 194 signature `g(*grad_ys, variables=None)`, where `variables` is a list of 195 the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where 196 `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>` 197 with the derivatives of `Tensor`s in `y` with respect to the variables 198 (that is, grad_vars has one Tensor per variable in variables). 199 200 Returns: 201 A function `h(x)` which returns the same value as `f(x)[0]` and whose 202 gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`. 203 """ 204 205 if f is None: 206 return lambda f: custom_gradient(f=f) 207 208 @Bind.decorator 209 def decorated(wrapped, args, kwargs): 210 """Decorated function with custom gradient.""" 211 if context.executing_eagerly(): 212 return _eager_mode_decorator(wrapped, args, kwargs) 213 else: 214 return _graph_mode_decorator(wrapped, args, kwargs) 215 216 return tf_decorator.make_decorator(f, decorated(f)) # pylint: disable=no-value-for-parameter 217 218 219class Bind(object): 220 """When called evaluates `d(f, args, kwargs)` but supports binding `f`. 221 222 >>> @Bind.decorator 223 ... def my_decorator(f, args, kwargs): 224 ... print("my_decorator called with", args, kwargs) 225 ... return f(*args, **kwargs) 226 227 >>> class Foo(object): 228 ... @my_decorator 229 ... def bar(self, a, b, c): 230 ... return a * b * c 231 232 >>> Foo.bar(None, 1, 2, c=3) 233 my_decorator called with (None, 1, 2) {'c': 3} 234 6 235 236 >>> foo = Foo() 237 >>> foo.bar(1, 2, c=3) 238 my_decorator called with (1, 2) {'c': 3} 239 6 240 """ 241 242 @classmethod 243 def decorator(cls, d): 244 return lambda f: Bind(f, d) 245 246 def __init__(self, f, d): 247 self._f = f 248 self._d = d 249 250 def __get__(self, instance, owner): 251 if instance is not None: 252 f = self._f.__get__(instance, owner) 253 return tf_decorator.make_decorator(f, Bind(f, self._d)) 254 else: 255 return self 256 257 def __call__(self, *a, **k): 258 return self._d(self._f, a, k) 259 260 261def get_variable_by_name(var_name): 262 """Given a variable name, retrieves a handle on the tensorflow Variable.""" 263 264 candidate_vars = ops.get_collection( 265 ops.GraphKeys.GLOBAL_VARIABLES, scope="{}:0".format(var_name)) 266 if len(candidate_vars) >= 1: 267 # Filter out non-trainable variables. 268 candidate_vars = [v for v in candidate_vars if v.trainable] 269 else: 270 raise ValueError("Unsuccessful at finding variable {}.".format(var_name)) 271 272 if len(candidate_vars) == 1: 273 return candidate_vars[0] 274 elif len(candidate_vars) > 1: 275 raise ValueError( 276 "Unsuccessful at finding trainable variable {}. " 277 "Number of candidates: {}. " 278 "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars)) 279 else: 280 # The variable is not trainable. 281 return None 282 283 284def get_dependent_variables(input_ops, output_ops): 285 """Finds variables involved in the subgraph b/w input_ops and output_ops.""" 286 287 # avoids the edge-case when input_ops == output_ops. 288 output_ops = nest.map_structure(gen_array_ops.identity, output_ops) 289 inbetween_ops = op_selector.get_backward_walk_ops( 290 seed_ops=nest.flatten(output_ops), 291 stop_at_ts=nest.flatten(input_ops), 292 inclusive=False, 293 only_differentiable=True) 294 var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES) 295 var_names = (op.name for op in var_ops) 296 tf_vars = (get_variable_by_name(var_name) for var_name in var_names) 297 tf_vars = [v for v in tf_vars if v is not None] 298 return tf_vars 299 300 301def _graph_mode_decorator(f, args, kwargs): 302 """Implement custom gradient decorator for graph mode.""" 303 # TODO(rsepassi): Add support for kwargs 304 if kwargs: 305 raise ValueError( 306 "The custom_gradient decorator currently supports keywords " 307 "arguments only when eager execution is enabled.") 308 name = "CustomGradient-%s" % ops.uid() 309 args = [ops.convert_to_tensor(x) for x in args] 310 311 # Checking global and local variables attempts to ensure that no non-resource 312 # Variables are added to the graph. 313 current_var_scope = variable_scope.get_variable_scope() 314 before_vars = set( 315 [v.experimental_ref() for v in current_var_scope.global_variables() + 316 current_var_scope.local_variables()]) 317 with backprop.GradientTape() as tape: 318 result, grad_fn = f(*args) 319 after_vars = set( 320 [v.experimental_ref() for v in current_var_scope.global_variables() + 321 current_var_scope.local_variables()]) 322 new_vars = after_vars - before_vars 323 new_vars_list = [v.deref() for v in new_vars] 324 for v in new_vars_list: 325 if not resource_variable_ops.is_resource_variable(v): 326 raise TypeError( 327 "All variables used by a function wrapped with @custom_gradient must " 328 "be `ResourceVariable`s. Ensure that no `variable_scope` is created " 329 "with `use_resource=False`.") 330 # The variables that grad_fn needs to return gradients for are the set of 331 # variables used that are *not* part of the inputs. 332 inputs = args 333 variables_in_tape = frozenset([ 334 v.experimental_ref() for v in tape.watched_variables() 335 ]) - frozenset(v.experimental_ref() for v in inputs) 336 variables_in_subgraph = frozenset([ 337 v.experimental_ref() 338 for v in get_dependent_variables(input_ops=inputs, output_ops=result) 339 ]) 340 variables = list( 341 [v.deref() for v in variables_in_subgraph.union(variables_in_tape)]) 342 343 grad_argspec = tf_inspect.getfullargspec(grad_fn) 344 variables_in_signature = ("variables" in grad_argspec.args or 345 grad_argspec.varkw) 346 if variables and not variables_in_signature: 347 raise TypeError("If using @custom_gradient with a function that " 348 "uses variables, then grad_fn must accept a keyword " 349 "argument 'variables'.") 350 if variables_in_signature and not variables: 351 # User seems to intend to use variables but none were captured. 352 if not variable_scope.get_variable_scope().use_resource: 353 raise TypeError("If using @custom_gradient with a function that " 354 "uses variables, the enclosing variable scope must " 355 "have use_resource=True.") 356 else: 357 logging.warn("@custom_gradient grad_fn has 'variables' in signature, but " 358 "no ResourceVariables were used on the forward pass.") 359 flat_result = nest.flatten(result) 360 flat_result_len = len(flat_result) 361 362 all_tensors = flat_result + args + variables 363 364 def tape_grad_fn(*result_grads): 365 """Custom grad fn wrapper.""" 366 result_grads = result_grads[:flat_result_len] 367 if variables: 368 input_grads, variable_grads = grad_fn(*result_grads, variables=variables) 369 if len(variable_grads) != len(variables): 370 raise ValueError("Must return gradient for each variable from " 371 "@custom_gradient grad_fn.") 372 else: 373 input_grads = grad_fn(*result_grads) 374 variable_grads = [] 375 376 # Need to return one value per input to the IdentityN, so pad the 377 # gradients of the inputs of the custom_gradient function with the 378 # gradients of the outputs as well. 379 input_grads = nest.flatten(input_grads) 380 return ([None] * flat_result_len) + input_grads + variable_grads 381 382 @ops.RegisterGradient(name) 383 def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable 384 """Custom grad fn wrapper.""" 385 return tape_grad_fn(*result_grads) 386 387 original_tensors = all_tensors 388 with ops.get_default_graph().gradient_override_map({"IdentityN": name}): 389 all_tensors = array_ops.identity_n(all_tensors) 390 391 original_tensors = [ops.convert_to_tensor(x) for x in original_tensors] 392 393 # Propagate handle data for happier shape inference for resource variables. 394 for i, t in enumerate(original_tensors): 395 if t.dtype == dtypes.resource and hasattr(t, "_handle_data"): 396 all_tensors[i]._handle_data = t._handle_data # pylint: disable=protected-access 397 tape_lib.record_operation( 398 f.__name__, all_tensors, original_tensors, tape_grad_fn) 399 for ot, t in zip(original_tensors, all_tensors): 400 copy_handle_data(ot, t) 401 return nest.pack_sequence_as( 402 structure=result, flat_sequence=all_tensors[:flat_result_len]) 403 404 405def _eager_mode_decorator(f, args, kwargs): 406 """Implement custom gradient decorator for eager mode.""" 407 with backprop.GradientTape() as tape: 408 result, grad_fn = f(*args, **kwargs) 409 all_inputs = list(args) + list(kwargs.values()) 410 # The variables that grad_fn needs to return gradients for are the set of 411 # variables used that are *not* part of the inputs. 412 variables = [ 413 v.deref() # pylint: disable=g-complex-comprehension 414 for v in set(v.experimental_ref() for v in tape.watched_variables()) 415 if all(v.deref() is not i for i in all_inputs) 416 ] 417 grad_argspec = tf_inspect.getfullargspec(grad_fn) 418 if (variables and ("variables" not in grad_argspec.args) and 419 not grad_argspec.varkw): 420 raise TypeError("If using @custom_gradient with a function that " 421 "uses variables, then grad_fn must accept a keyword " 422 "argument 'variables'.") 423 flat_result = nest.flatten(result) 424 # TODO(apassos) consider removing the identity below. 425 flat_result = [gen_array_ops.identity(x) for x in flat_result] 426 427 input_tensors = [ops.convert_to_tensor(x) for x 428 in list(args) + list(variables)] 429 430 recorded_inputs = input_tensors 431 arg_count = len(args) 432 433 def actual_grad_fn(*result_grads): 434 """Custom grad fn wrapper.""" 435 if variables: 436 input_grads, variable_grads = grad_fn(*result_grads, variables=variables) 437 if len(variable_grads) != len(variables): 438 raise ValueError("Must return gradient for each variable from " 439 "@custom_gradient grad_fn.") 440 else: 441 input_grads = grad_fn(*result_grads) 442 variable_grads = [] 443 flat_grads = nest.flatten(input_grads) 444 if len(flat_grads) != arg_count: 445 raise ValueError( 446 "custom_gradient function expected to return", arg_count, 447 "gradients but returned", len(flat_grads), "instead.") 448 return nest.flatten(input_grads) + variable_grads 449 450 tape_lib.record_operation(f.__name__, flat_result, recorded_inputs, 451 actual_grad_fn) 452 flat_result = list(flat_result) 453 return nest.pack_sequence_as(result, flat_result) 454 455 456@tf_export("recompute_grad") 457def recompute_grad(f): 458 """An eager-compatible version of recompute_grad. 459 460 For f(*args, **kwargs), this supports gradients with respect to args or 461 kwargs, but kwargs are currently only supported in eager-mode. 462 Note that for keras layer and model objects, this is handled automatically. 463 464 Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not 465 be able to access the member variables of that object, because `g` returns 466 through the wrapper function `inner`. When recomputing gradients through 467 objects that inherit from keras, we suggest keeping a reference to the 468 underlying object around for the purpose of accessing these variables. 469 470 Args: 471 f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs. 472 473 Returns: 474 A function `g` that wraps `f`, but which recomputes `f` on the backwards 475 pass of a gradient call. 476 """ 477 # TODO(cdfreeman) Add is_recomputing functionality from graph mode version 478 479 @custom_gradient 480 def inner(*args, **kwargs): 481 """Inner function closure for calculating gradients.""" 482 current_var_scope = variable_scope.get_variable_scope() 483 484 result = f(*args, **kwargs) 485 486 def grad(*dresult, **grad_kwargs): 487 """Gradient function calculation for inner function.""" 488 variables = grad_kwargs.get("variables") 489 with backprop.GradientTape() as t: 490 id_args = [gen_array_ops.identity(x) for x in args] 491 t.watch(id_args) 492 if variables is not None: 493 t.watch(variables) 494 with ops.control_dependencies(dresult): 495 with variable_scope.variable_scope(current_var_scope): 496 result = f(*id_args, **kwargs) 497 kw_vars = [] 498 if variables is not None: 499 kw_vars = list(variables) 500 grads = t.gradient( 501 result, list(id_args) + kw_vars, output_gradients=dresult) 502 return grads[:len(id_args)], grads[len(id_args):] 503 504 return result, grad 505 506 return inner 507 508 509@tf_export("grad_pass_through") 510def grad_pass_through(f): 511 """Creates a grad-pass-through op with the forward behavior provided in f. 512 513 Use this function to wrap any op, maintaining its behavior in the forward 514 pass, but replacing the original op in the backward graph with an identity. 515 For example: 516 517 ```python 518 x = tf.Variable(1.0, name="x") 519 z = tf.Variable(3.0, name="z") 520 521 with tf.GradientTape() as tape: 522 # y will evaluate to 9.0 523 y = tf.grad_pass_through(x.assign)(z**2) 524 # grads will evaluate to 6.0 525 grads = tape.gradient(y, z) 526 ``` 527 528 Another example is a 'differentiable' moving average approximation, where 529 gradients are allowed to flow into the last value fed to the moving average, 530 but the moving average is still used for the forward pass: 531 532 ```python 533 x = ... # Some scalar value 534 # A moving average object, we don't need to know how this is implemented 535 moving_average = MovingAverage() 536 with backprop.GradientTape() as tape: 537 # mavg_x will evaluate to the current running average value 538 mavg_x = tf.grad_pass_through(moving_average)(x) 539 grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0 540 ``` 541 542 Args: 543 f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor` 544 outputs. 545 546 Returns: 547 A function `h(x)` which returns the same values as `f(x)` and whose 548 gradients are the same as those of an identity function. 549 """ 550 @custom_gradient 551 def _grad_pass_through_op(*args, **kwargs): 552 def grad(*args, **kwargs): 553 variables = kwargs.get("variables") 554 if variables is not None: 555 # Variables involved in the wrapped op will not receive gradients. 556 return args, [None] * len(variables) 557 return args 558 return f(*args, **kwargs), grad 559 return tf_decorator.make_decorator(f, _grad_pass_through_op) 560