# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Gradients for operators defined in linalg_ops.py.
16
17Useful reference for derivative formulas is
18An extended collection of matrix derivative results for forward and reverse
19mode algorithmic differentiation by Mike Giles:
20http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
21
22A detailed derivation of formulas for backpropagating through spectral layers
23(SVD and Eig) by Ionescu, Vantzos & Sminchisescu:
24https://arxiv.org/pdf/1509.07838v4.pdf
25"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
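  # This follows from the identity d(A^-1) = -A^-1 * dA * A^-1, whose
  # reverse-mode (adjoint) form is grad_A = -A^-H * grad * A^-H.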
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
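  # By Jacobi's formula, d(det(A)) = det(A) * trace(A^-1 * dA), so the
  # gradient is grad * det(A) * A^-H; `multipliers` below carries the
  # scalar grad * det(A) per batch element.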
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices"""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
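  # In the row-major vectorization used by reshape below, k1 + k2 is the
  # matrix of the Sylvester operator X -> R*X + X*R, so a single
  # matrix_solve against their sum inverts that operator.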
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
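  # d(log|det(A)|) = trace(A^-1 * dA), so the gradient is simply
  # grad_b * A^-H, broadcast over the batch dimensions.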
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
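  # The next two lines implement the elementwise factor (tril(ones)-1/2*eye)
  # from the formula above: halve the diagonal of l^{H} @ grad and keep only
  # its lower triangle.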
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
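  # dq and dr are the incoming cotangents dL/dQ and dL/dR; the return value
  # is dL/dA for A = Q @ R.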
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape.dims[-2].value != r.shape.dims[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
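  # For c = A^-1 * b: grad_b = A^-H * grad and grad_a = -grad_b * c^H.
  # When the forward op used adjoint_a=True the roles of A and A^H swap,
  # which is why the solve below flips the adjoint flag.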
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from the forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs in other places in
  #      TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
       min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
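    # zx_sym = x*z^T + (x*z^T)^T is the symmetric term referred to in the
    # TODO above; a dedicated symmetric rank-k update would avoid forming it
    # explicitly.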
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A^T * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
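  # The forward op only reads the triangle of a selected by `lower`, so mask
  # the gradient to that triangle below.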
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_j - e_i) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
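      # In matrix form, the expression below is
      #   grad_a = v @ (diag(grad_e) + f * (v^H @ grad_v)) @ v^H,
      # before the symmetrization applied at the end of this function.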
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file).  A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
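    # When only the singular values were requested, grad_a = U diag(grad_s) V^H;
    # U and V are recomputed here because the forward op did not return them.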
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    # Also, even with distinct singular values, the diagonal of f can have Inf
    # values before setting to zero, which hurts when differentiating through
    # this op. To avoid that, we add eye to the matrix before taking
    # the reciprocal.
    s_shape = array_ops.shape(s)
    eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=s.dtype)
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1) +
            eye), array_ops.zeros_like(s))
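    # Off the diagonal, f[..., i, j] = 1 / (s_j^2 - s_i^2); the diagonal is
    # zeroed out, and the eye term only shifts the (discarded) diagonal away
    # from Inf.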
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a