# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Jacobian ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import check_ops
23from tensorflow.python.ops import gradients_impl as gradient_ops
24from tensorflow.python.ops.parallel_for import control_flow_ops
25from tensorflow.python.util import nest
26
27
def jacobian(output, inputs, use_pfor=True, parallel_iterations=None):
  """Computes the jacobian of `output` w.r.t. `inputs`.

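  For example (an illustrative sketch; assumes `tf` is the imported
  TensorFlow module):
  x = tf.constant([1.0, 2.0], dtype=tf.float32)
  y = x * x
  jacobian_value = jacobian(y, x)
  # => [[2., 0.], [0., 4.]]
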
  Args:
    output: A tensor.
    inputs: A tensor or a nested structure of tensor objects.
    use_pfor: If true, uses pfor for computing the jacobian. Else uses
      tf.while_loop.
    parallel_iterations: A knob to control how many iterations are dispatched
      in parallel. This knob can be used to control the total memory usage.

  Returns:
    A tensor or a nested structure of tensors with the same structure as
    `inputs`. Each entry is the jacobian of `output` w.r.t. the corresponding
    value in `inputs`. If output has shape [y_1, ..., y_n] and inputs_i has
    shape [x_1, ..., x_m], the corresponding jacobian has shape
    [y_1, ..., y_n, x_1, ..., x_m]. Note that in cases where the gradient is
    sparse (IndexedSlices), the jacobian function currently makes it dense and
    returns a Tensor instead. This may change in the future.
  """
  flat_inputs = nest.flatten(inputs)
  output_tensor_shape = output.shape
  output_shape = array_ops.shape(output)
  output = array_ops.reshape(output, [-1])

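  # Each loop iteration computes the gradient of one scalar element of the
  # flattened output w.r.t. all the flattened inputs.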
  def loop_fn(i):
    y = array_ops.gather(output, i)
    return gradient_ops.gradients(y, flat_inputs)

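  # Prefer a statically known output size; fall back to a dynamic shape op
  # when the first dimension is unknown at graph construction time.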
  try:
    output_size = int(output.shape[0])
  except TypeError:
    output_size = array_ops.shape(output)[0]

  if use_pfor:
    pfor_outputs = control_flow_ops.pfor(
        loop_fn, output_size, parallel_iterations=parallel_iterations)
  else:
    pfor_outputs = control_flow_ops.for_loop(
        loop_fn,
        [output.dtype] * len(flat_inputs),
        output_size,
        parallel_iterations=parallel_iterations)

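  # Each stacked gradient has shape [output_size] + input_shape; reshape it so
  # the original output dimensions come first, then that input's dimensions.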
  for i, out in enumerate(pfor_outputs):
    if isinstance(out, ops.Tensor):
      new_shape = array_ops.concat(
          [output_shape, array_ops.shape(out)[1:]], axis=0)
      out = array_ops.reshape(out, new_shape)
      out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape))
      pfor_outputs[i] = out

  return nest.pack_sequence_as(inputs, pfor_outputs)


def batch_jacobian(output, inp, use_pfor=True, parallel_iterations=None):
  """Computes and stacks jacobians of `output[i, ...]` w.r.t. `inp[i, ...]`.

  e.g.
  x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  y = x * x
  jacobian = batch_jacobian(y, x)
  # => [[[2,  0], [0,  4]], [[6,  0], [0,  8]]]

  Args:
    output: A tensor with shape [b, y_1, ..., y_n]. `output[i, ...]` should
      only depend on `inp[i, ...]`.
    inp: A tensor with shape [b, x_1, ..., x_m].
    use_pfor: If true, uses pfor for computing the jacobian. Else uses a
      tf.while_loop.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None, when use_pfor is
      true, corresponds to vectorizing all the iterations. When use_pfor is
      false, the default value of None corresponds to parallel_iterations=10.
      This knob can be used to control the total memory usage.

  Returns:
    A tensor `t` with shape [b, y_1, ..., y_n, x_1, ..., x_m] where `t[i, ...]`
    is the jacobian of `output[i, ...]` w.r.t. `inp[i, ...]`, i.e. stacked
    per-example jacobians.

  Raises:
    ValueError: if the first dimensions of `output` and `inp` do not match.
  """
  output_shape = output.shape
  if not output_shape[0].is_compatible_with(inp.shape[0]):
    raise ValueError("Need first dimension of output shape (%s) and inp shape "
                     "(%s) to match." % (output.shape, inp.shape))
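  # Use static shapes when they are fully defined; otherwise fall back to
  # dynamic shape ops evaluated at runtime.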
  if output_shape.is_fully_defined():
    batch_size = int(output_shape[0])
    output_row_size = output_shape.num_elements() // batch_size
  else:
    output_shape = array_ops.shape(output)
    batch_size = output_shape[0]
    output_row_size = array_ops.size(output) // batch_size
  inp_shape = array_ops.shape(inp)
  # Flatten output to 2-D.
  with ops.control_dependencies(
      [check_ops.assert_equal(batch_size, inp_shape[0])]):
    output = array_ops.reshape(output, [batch_size, output_row_size])

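  # Each loop iteration computes, for one column of the flattened output, its
  # gradient w.r.t. the full batch of inputs.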
  def loop_fn(i):
    y = array_ops.gather(output, i, axis=1)
    return gradient_ops.gradients(y, inp)[0]

  if use_pfor:
    pfor_output = control_flow_ops.pfor(loop_fn, output_row_size,
                                        parallel_iterations=parallel_iterations)
  else:
    pfor_output = control_flow_ops.for_loop(
        loop_fn, output.dtype,
        output_row_size,
        parallel_iterations=parallel_iterations)
  if pfor_output is None:
    return None
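  # The loop stacks per-column gradients along axis 0, giving shape
  # [output_row_size, batch_size, ...]; move the batch dimension back to the
  # front and restore the full per-example jacobian shape.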
  pfor_output = array_ops.reshape(pfor_output,
                                  [output_row_size, batch_size, -1])
  output = array_ops.transpose(pfor_output, [1, 0, 2])
  new_shape = array_ops.concat([output_shape, inp_shape[1:]], axis=0)
  return array_ops.reshape(output, new_shape)