# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Decorator to overrides the gradient for a function."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


VAR_OP_TYPES = [
    "VariableV2",
    "VarHandleOp",
]


def copy_handle_data(source_t, target_t):
  """Copies HandleData for variant and resource type tensors if available.

  The CppShapeInferenceResult::HandleData proto contains information about the
  shapes and types of the element tensors of resource/variant type tensors.
  We need to copy this across function boundaries, i.e., when capturing a
  placeholder or when returning a function tensor as output. If we don't do this
  the element tensors will have unknown shapes, e.g., if a TensorList variant
  tensor is captured as a placeholder, elements popped from that list would have
  unknown shape.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  if (target_t.dtype == dtypes.resource or
      target_t.dtype == dtypes.variant):
    if isinstance(source_t, ops.EagerTensor):
      handle_data = source_t._handle_data  # pylint: disable=protected-access
    else:
      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
    if (handle_data is not None
        and handle_data.is_set
        and handle_data.shape_and_type):
      # pylint: disable=protected-access
      pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
                                              target_t._as_tf_output(),
                                              handle_data.SerializeToString())
      # pylint: enable=protected-access
      # Ensure that shapes and dtypes are propagated.
      shapes, types = zip(*[(pair.shape, pair.dtype)
                            for pair in handle_data.shape_and_type])
      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
      shapes = [[d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
                if not s.unknown_rank else None for s in shapes]
      pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
          target_t._op._graph._c_graph,  # pylint: disable=protected-access
          target_t._as_tf_output(),  # pylint: disable=protected-access
          shapes,
          ranks,
          types)


@tf_export("custom_gradient")
def custom_gradient(f=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine-grained control over the gradients of a sequence
  of operations.  This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at
  x=100 is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

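  The stabilized gradient can be checked with `tf.GradientTape` (a minimal
  sketch of such a check):

  ```python
  x = tf.constant(100.)
  with tf.GradientTape() as tape:
    tape.watch(x)  # `x` is a constant, so it must be watched explicitly.
    y = log1pexp(x)
  dy_dx = tape.gradient(y, x)  # Evaluates to 1.0 instead of NaN.
  ```
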
  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` computes the first-order gradient of `grad_fn`
  with respect to `dy`, which is used to generate forward-mode gradient graphs
  from backward-mode gradient graphs, but is not the same as the second-order
  gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradient` functions in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  See also `tf.RegisterGradient`, which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient`, on the other hand,
  allows fine-grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must use `ResourceVariable`s.

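  When the decorated function closes over `tf.Variable`s, `grad_fn` must
  accept a `variables` keyword argument and also return gradients for those
  variables (see the `Args` section below). A minimal sketch, where the
  variable `v` is illustrative:

  ```python
  v = tf.Variable(2.0)

  @tf.custom_gradient
  def scale(x):
    def grad(dy, variables=None):
      # One gradient for the input, plus a list with one entry per variable.
      return dy * v, [dy * x]
    return v * x, grad
  ```
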
  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of `Tensor` inputs to the function.
       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
         TensorFlow operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
         to the `Tensor`s in `x`.  `grad_ys` is a `Tensor` or sequence of
         `Tensor`s the same size as `y` holding the initial value gradients for
         each `Tensor` in `y`. In a pure mathematical sense, a vector-argument
         vector-valued function `f`'s derivatives should be its Jacobian matrix
         `J`. Here we are expressing the Jacobian `J` as a function `grad_fn`
         which defines how `J` will transform a vector `grad_ys` when
         left-multiplied with it (`grad_ys * J`). This functional representation
         of a matrix is convenient to use for chain-rule calculation
         (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, `grad_vars` has one `Tensor` per variable in `variables`).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  if f is None:
    return lambda f: custom_gradient(f=f)

  @Bind.decorator
  def decorated(wrapped, args, kwargs):
    """Decorated function with custom gradient."""
    if context.executing_eagerly():
      return _eager_mode_decorator(wrapped, args, kwargs)
    else:
      return _graph_mode_decorator(wrapped, args, kwargs)

  return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter


class Bind(object):
  """When called, evaluates `d(f, args, kwargs)` but supports binding `f`.

  >>> @Bind.decorator
  ... def my_decorator(f, args, kwargs):
  ...   print("my_decorator called with", args, kwargs)
  ...   return f(*args, **kwargs)

  >>> class Foo(object):
  ...   @my_decorator
  ...   def bar(self, a, b, c):
  ...     return a * b * c

  >>> Foo.bar(None, 1, 2, c=3)
  my_decorator called with (None, 1, 2) {'c': 3}
  6

  >>> foo = Foo()
  >>> foo.bar(1, 2, c=3)
  my_decorator called with (1, 2) {'c': 3}
  6
  """

  @classmethod
  def decorator(cls, d):
    return lambda f: Bind(f, d)

  def __init__(self, f, d):
    self._f = f
    self._d = d

  def __get__(self, instance, owner):
    if instance is not None:
      # Accessed through an instance: bind `f` to it so `self` is passed
      # through to the decorated method automatically.
      f = self._f.__get__(instance, owner)
      return tf_decorator.make_decorator(f, Bind(f, self._d))
    else:
      return self

  def __call__(self, *a, **k):
    return self._d(self._f, a, k)


def get_variable_by_name(var_name):
  """Given a variable name, retrieves a handle on the TensorFlow `Variable`."""

  candidate_vars = ops.get_collection(
      ops.GraphKeys.GLOBAL_VARIABLES, scope="{}:0".format(var_name))
  if len(candidate_vars) >= 1:
    # Filter out non-trainable variables.
    candidate_vars = [v for v in candidate_vars if v.trainable]
  else:
    raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

  if len(candidate_vars) == 1:
    return candidate_vars[0]
  elif len(candidate_vars) > 1:
    raise ValueError(
        "Unsuccessful at finding trainable variable {}. "
        "Number of candidates: {}. "
        "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars))
  else:
    # The variable is not trainable.
    return None


def get_dependent_variables(input_ops, output_ops):
  """Finds variables involved in the subgraph between input and output ops."""

  # Avoids the edge case where input_ops == output_ops.
  output_ops = nest.map_structure(gen_array_ops.identity, output_ops)
  inbetween_ops = op_selector.get_backward_walk_ops(
      seed_ops=nest.flatten(output_ops),
      stop_at_ts=nest.flatten(input_ops),
      inclusive=False,
      only_differentiable=True)
  var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES)
  var_names = (op.name for op in var_ops)
  tf_vars = (get_variable_by_name(var_name) for var_name in var_names)
  tf_vars = [v for v in tf_vars if v is not None]
  return tf_vars


def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keyword "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(
      [v.experimental_ref() for v in current_var_scope.global_variables() +
       current_var_scope.local_variables()])
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(
      [v.experimental_ref() for v in current_var_scope.global_variables() +
       current_var_scope.local_variables()])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  inputs = args
  variables_in_tape = frozenset([
      v.experimental_ref() for v in tape.watched_variables()
  ]) - frozenset(v.experimental_ref() for v in inputs)
  variables_in_subgraph = frozenset([
      v.experimental_ref()
      for v in get_dependent_variables(input_ops=inputs, output_ops=result)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # IdentityN needs one returned gradient per input in `all_tensors`
    # (outputs, then inputs, then variables), so pad with `None` placeholders
    # for the forward outputs before the input and variable gradients.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])


def _eager_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for eager mode."""
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in set(v.experimental_ref() for v in tape.watched_variables())
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      not grad_argspec.varkw):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]

  recorded_inputs = input_tensors
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return {} gradients, but "
          "returned {} instead.".format(arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)


@tf_export("recompute_grad")
def recompute_grad(f):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args or
  kwargs, but kwargs are currently only supported in eager mode. Note that for
  Keras layer and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`.  When recomputing gradients through
  objects that inherit from Keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

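  For example, a minimal sketch of recomputing a block's activations during
  the backward pass instead of storing them (`block` below is illustrative):

  ```python
  @tf.recompute_grad
  def block(x):
    return tf.nn.relu(tf.matmul(x, x))

  x = tf.ones((2, 2))
  with tf.GradientTape() as tape:
    tape.watch(x)  # `x` is a constant, so watch it explicitly.
    y = block(x)
  # The forward pass of `block` is re-run here to compute the gradient.
  grads = tape.gradient(y, x)
  ```
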
  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.

  Returns:
   A function `g` that wraps `f`, but which recomputes `f` on the backward
   pass of a gradient call.
  """
  # TODO(cdfreeman) Add is_recomputing functionality from graph mode version

  @custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    current_var_scope = variable_scope.get_variable_scope()

    result = f(*args, **kwargs)

    def grad(*dresult, **grad_kwargs):
      """Gradient function calculation for inner function."""
      variables = grad_kwargs.get("variables")
      with backprop.GradientTape() as t:
        id_args = [gen_array_ops.identity(x) for x in args]
        t.watch(id_args)
        if variables is not None:
          t.watch(variables)
        with ops.control_dependencies(dresult):
          with variable_scope.variable_scope(current_var_scope):
            # Re-run the forward pass under the tape so gradients of the
            # recomputed activations can be taken below.
            result = f(*id_args, **kwargs)
      kw_vars = []
      if variables is not None:
        kw_vars = list(variables)
      grads = t.gradient(
          result, list(id_args) + kw_vars, output_gradients=dresult)
      return grads[:len(id_args)], grads[len(id_args):]

    return result, grad

  return inner


@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in `f`.

  Use this function to wrap any op, maintaining its behavior in the forward
  pass, but replacing the original op in the backward graph with an identity.
  For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation, where
  gradients are allowed to flow into the last value fed to the moving average,
  but the moving average is still used for the forward pass:

  ```python
  x = ... # Some scalar value
  # A moving average object, we don't need to know how this is implemented
  moving_average = MovingAverage()
  with tf.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0
  ```

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
   A function `h(x)` which returns the same values as `f(x)` and whose
   gradients are the same as those of an identity function.
  """
  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):
    def grad(*args, **kwargs):
      variables = kwargs.get("variables")
      if variables is not None:
        # Variables involved in the wrapped op will not receive gradients.
        return args, [None] * len(variables)
      return args
    return f(*args, **kwargs), grad
  return tf_decorator.make_decorator(f, _grad_pass_through_op)
