# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Gradients for operators defined in linalg_ops.py.
16
17Useful reference for derivative formulas is (Mike Giles, 2008).
18
19Ionescu et al. (2015) provide a detailed derivation of formulas for
20backpropagating through spectral layers (SVD and Eig).
21
22References:
23  An extended collection of matrix derivative results for
24  forward and reverse mode automatic differentiation:
25    [Mike Giles, 2008]
26    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
27    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
28  Matrix Backpropagation for Deep Networks with Structured Layers
29    [Ionescu et al., 2015]
30    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
31    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
32  Training Deep Networks with Structured Layers by Matrix Backpropagation:
33    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
34    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
35"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


50@ops.RegisterGradient("MatrixInverse")
51def _MatrixInverseGrad(op, grad):
52  """Gradient for MatrixInverse."""
53  ainv = op.outputs[0]
54  return -math_ops.matmul(
55      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)
56
57
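# Illustrative sketch, not part of TensorFlow: the formula above says that for
# Z = A^{-1} with upstream gradient G = dL/dZ, the input gradient is
# dL/dA = -A^{-H} G A^{-H}. The hypothetical helper below compares the real
# case (-A^{-T} G A^{-T}) against a finite-difference estimate using NumPy;
# its name, shapes and step size are assumptions made for illustration only.
def _MatrixInverseGradFiniteDifferenceCheck(a, g, delta=1e-6):
  """Returns the max abs difference between analytic and numeric gradients."""
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  a = np.asarray(a, dtype=np.float64)
  g = np.asarray(g, dtype=np.float64)
  ainv = np.linalg.inv(a)
  analytic = -ainv.T @ g @ ainv.T  # Real case: adjoint is just the transpose.
  numeric = np.zeros_like(a)
  for i in range(a.shape[0]):
    for j in range(a.shape[1]):
      da = np.zeros_like(a)
      da[i, j] = delta
      # L(A) = sum(G * A^{-1}); central difference of L w.r.t. A[i, j].
      numeric[i, j] = (np.sum(g * np.linalg.inv(a + da)) -
                       np.sum(g * np.linalg.inv(a - da))) / (2 * delta)
  return np.max(np.abs(analytic - numeric))

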
58@ops.RegisterGradient("Einsum")
59def _EinsumGrad(op, grad):
60  """Gradient for Einsum."""
61  ellipsis = "..."
62
63  def _GetAxisFromLabel(subscripts, label):
64    """Returns the axis (possibly negative) corresponding to a label.
65
66    Returns the axis index of the axis label if it is before an ellipsis (or if
67    the ellipsis is not present), and the negative index if it occurs after the
68    ellipsis. E.g. index of `b` in `ab...cd`, is `1`, but that of `c` is `-2`.
69
70    For multiple occurrences, returns the leftmost one. If not found, returns
71    None.
72
73    Args:
74      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
75      label: The single character axis label.
76    """
77    splits = subscripts.split(ellipsis)
78    index = splits[0].find(label)
79    if index != -1:
80      return index
81    if len(splits) < 2:
82      return None
83    index = splits[1].find(label)
84    if index != -1:
85      return index - len(splits[1])
86    return None
87
  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd` returns `[2, -2]`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.

    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and
    # 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced
    # subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. given the equation "aabbcd->ca" we'd
    # first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = (
        array_ops.concat([
            array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
            array_ops.shape(output_grad)
        ],
                         axis=0))
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz. "bdca->aabbcd")
    # since the output axis labels now appear in the input subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done
    separately on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the one which is
        not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    #   where the equation involves only Tensor contractions, generalized traces
    #   and transposes, the input gradients are given by the vector-jacobian
    #   products (VJPs):
    #
    #     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    #   where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    #   x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    #   traces ("ii->i"), the linear mapping's Jacobian at input x is given
    #   by the function itself. We can verify that the linear maps given by the
    #   VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    #   where the latter represents 'un-tracing', i.e. filling the diagonal
    #   with the input and setting the off-diagonal entries to zero.
    #        Furthermore, recall that matrix multiplication, which is
    #   represented by the equation "ab,bc->ac", has its VJPs given by the
    #   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example,
    #   https://math.stackexchange.com/a/2755680). Combined with transposes and
    #   traces we can rewrite Tensor contractions as regular matrix
    #   multiplication. Since each of these operations has its VJP described
    #   by an einsum of the required pattern, the result follows.
    #
    # Accordingly, einsum operations, except for those with reductions, e.g.
    # "abc,cd->ad", have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain the VJP for "ac,cd->ad", and then we compute the
    # VJP of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscripts. E.g. the set {'b'}
    # for the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

296  equation = op.get_attr("equation")
297  if isinstance(equation, bytes):
298    equation = equation.decode()
299  input_subs, output_subs = equation.split("->")
300
301  if len(op.inputs) == 1:
302    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt the
303    # input (VJP) is given by the reversed equation:
304    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
305    # (See the justification in _GetGradWrt). This is valid unless there are
306    # reduced axis labels; i.e. axis labels appearing in the input but not in
307    # the output subscripts.
308    input_shape = array_ops.shape(op.inputs[0])
309    # Find the axis labels which appear only in the input.
310    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
311    if not reduced_label_set:
312      # Return the einsum given by the reversed equation, since we don't have
313      # reduced axes.
314      return gen_linalg_ops.einsum([grad],
315                                   "{}->{}".format(output_subs, input_subs))
316    # We do have reduced axes, so we invoke the subroutine for reduced unary
317    # einsums.
318    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
319                           reduced_label_set)
320
321  x_subs, y_subs = input_subs.split(",")
322  # Add ellipsis for broadcasted dimensions if any operand does not have it.
323  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
324  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
325  # because only the output subscripts contain ellipsis.
326  if ellipsis in output_subs:
327    if ellipsis not in x_subs:
328      x_subs += ellipsis
329    if ellipsis not in y_subs:
330      y_subs += ellipsis
331
332  # Obtain the gradients wrt the inputs x and y, without taking into account
333  # the unbroadcasting.
334  x, y = op.inputs[0], op.inputs[1]
335  if grad.dtype.is_complex:
336    x = math_ops.conj(x)
337    y = math_ops.conj(y)
338
339  x_shape = array_ops.shape(x)
340  y_shape = array_ops.shape(y)
341  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
342  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)
343
344  if ellipsis not in output_subs:
345    # If no ellipsis in the output; then no need to unbroadcast.
346    return grad_x, grad_y
347
348  # Below we handle the case that broadcasting between x and y was necessary,
349  # with x and y having possibly different batch shapes.
350
351  # Obtain the range of axes which map to ellipsis. E.g. for subscripts 'ab...c'
352  # and shape of rank 10; the range [3:-1] denotes the broadcasted axes.
353  bx_start, bx_end = _GetBcastSubshape(x_subs)
354  by_start, by_end = _GetBcastSubshape(y_subs)
355  # If the static batch shapes are equal, we don't need to unbroadcast.
356  x_shape_static = x.get_shape()
357  y_shape_static = y.get_shape()
358  if (x_shape_static.is_fully_defined() and
359      y_shape_static.is_fully_defined() and
360      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
361    return grad_x, grad_y
362
363  # Sum the gradient across the broadcasted axes.
364  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
365                                             y_shape[by_start:by_end])
366  grad_x = array_ops.reshape(
367      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
368  grad_y = array_ops.reshape(
369      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
370  return grad_x, grad_y
371
372
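# Illustrative sketch, not part of TensorFlow: for a pure contraction such as
# "ij,jk->ik", the VJP rule used in _GetGradWrt reduces to the familiar matmul
# gradients dX = dZ @ Y^T and dY = X^T @ dZ. The hypothetical function below
# demonstrates this with NumPy; its name and shapes are assumptions made for
# illustration only.
def _EinsumBinaryVjpExample():
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  x = rng.standard_normal((2, 3))
  y = rng.standard_normal((3, 4))
  grad_z = rng.standard_normal((2, 4))  # Upstream gradient dL/dZ for Z = X @ Y.
  # VJP rule: "{eq_y},{eq_z}->{eq_x}" and "{eq_x},{eq_z}->{eq_y}".
  grad_x = np.einsum("jk,ik->ij", y, grad_z)
  grad_y = np.einsum("ij,ik->jk", x, grad_z)
  assert np.allclose(grad_x, grad_z @ y.T)
  assert np.allclose(grad_y, x.T @ grad_z)
  return grad_x, grad_y

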
373@ops.RegisterGradient("MatrixDeterminant")
374def _MatrixDeterminantGrad(op, grad):
375  """Gradient for MatrixDeterminant."""
376  a = op.inputs[0]
377  c = op.outputs[0]
378  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
379  multipliers = array_ops.reshape(grad * c,
380                                  array_ops.concat([array_ops.shape(c), [1, 1]],
381                                                   0))
382  return multipliers * a_adj_inv
383
384
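# Illustrative sketch, not part of TensorFlow: Jacobi's formula gives
# d(det A)/dA = det(A) * A^{-T}, which is exactly the multipliers * a_adj_inv
# expression above (per batch element). The hypothetical helper below checks
# it against a central difference with NumPy; its name and step size are
# assumptions made for illustration only.
def _MatrixDeterminantGradFiniteDifferenceCheck(a, delta=1e-6):
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  a = np.asarray(a, dtype=np.float64)
  analytic = np.linalg.det(a) * np.linalg.inv(a).T
  numeric = np.zeros_like(a)
  for i in range(a.shape[0]):
    for j in range(a.shape[1]):
      da = np.zeros_like(a)
      da[i, j] = delta
      numeric[i, j] = (np.linalg.det(a + da) -
                       np.linalg.det(a - da)) / (2 * delta)
  return np.max(np.abs(analytic - numeric))

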
385@ops.RegisterGradient("MatrixSquareRoot")
386def _MatrixSquareRootGrad(op, grad):
387  """Gradient for MatrixSquareRoot."""
388
389  # Let A be an m x m square matrix (or batch of matrices)
390  # Let R = sqrtm(A)
391  # By definition, A = RR
392  # Take the differential: dA = d(RR) = RdR + dRR
393  # Solve the resulting Sylvester equation for dR
394
395  # Used to find Kronecker products within the Sylvester equation
396  def _KroneckerProduct(b1, b2):
397    """Computes the Kronecker product of two batches of square matrices."""
398    b1_shape = array_ops.shape(b1)
399    b2_shape = array_ops.shape(b2)
400    b1_order = b1_shape[-1]
401    b2_order = b2_shape[-1]
402
403    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
404    shape_slice = array_ops.slice(b1_shape, [0],
405                                  shape_slice_size)  # Same for both batches
406    b1_reshape_shape = array_ops.concat(
407        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
408    b2_reshape_shape = array_ops.concat(
409        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)
410
411    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
412    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)
413
414    order_prod = b1_order * b2_order
415    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
416    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)
417
418  sqrtm = op.outputs[0]  # R
419  shape = array_ops.shape(sqrtm)
420  order = shape[-1]  # m
421  matrix_count = math_ops.reduce_prod(shape[0:-2])
422
423  # Get batch of m x m identity matrices
424  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
425  eye_flat = array_ops.reshape(eye, [-1])
426  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
427  eye_batch = array_ops.reshape(eye_tiled, shape)
428
429  # The transpose of R is taken in the k1 term instead of k2 in
430  # order to prevent redundant transposition of R (i.e. (R')' = R)
431  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
432  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
433  k2 = _KroneckerProduct(sqrtm, eye_batch)
434  ksum = math_ops.add(k1, k2)
435
436  # Vectorize dA
437  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
438  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
439  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
440  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)
441
442  # Solve for vec(dR)
443  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)
444
445  # Solve for dR by inverse vectorizing vec(dR)
446  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
447  return array_ops.matrix_transpose(dsqrtm_transpose)
448
449
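# Illustrative sketch, not part of TensorFlow: the Sylvester equation above,
# dA = R dR + dR R, becomes a linear system in vec(dR) through the identity
# vec(R X + X R) = (I kron R + R^T kron I) vec(X) with column-major
# vectorization; _MatrixSquareRootGrad solves an equivalent row-major system
# built from _KroneckerProduct. The hypothetical check below verifies the
# identity with NumPy; the name and sizes are assumptions for illustration
# only.
def _SylvesterVecIdentityCheck():
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  r = rng.standard_normal((3, 3))
  x = rng.standard_normal((3, 3))

  def vec(m):
    return m.reshape(-1, order="F")  # Column-major vectorization.

  lhs = vec(r @ x + x @ r)
  rhs = (np.kron(np.eye(3), r) + np.kron(r.T, np.eye(3))) @ vec(x)
  assert np.allclose(lhs, rhs)
  return lhs, rhs

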
450@ops.RegisterGradient("LogMatrixDeterminant")
451def _LogMatrixDeterminantGrad(op, _, grad_b):
452  """Gradient for LogMatrixDeterminant."""
453  a = op.inputs[0]
454  c = op.outputs[1]
455  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
456  multipliers = array_ops.reshape(
457      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
458  return multipliers * a_adj_inv
459
460
461@ops.RegisterGradient("Cholesky")
462def _CholeskyGrad(op, grad):
463  """Gradient for Cholesky."""
464
465  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
466  l = op.outputs[0]
467  num_rows = array_ops.shape(l)[-1]
468  batch_shape = array_ops.shape(l)[:-2]
469  l_inverse = linalg_ops.matrix_triangular_solve(l,
470                                                 linalg_ops.eye(
471                                                     num_rows,
472                                                     batch_shape=batch_shape,
473                                                     dtype=l.dtype))
474
475  middle = math_ops.matmul(l, grad, adjoint_a=True)
476  middle = array_ops.matrix_set_diag(middle,
477                                     0.5 * array_ops.matrix_diag_part(middle))
478  middle = array_ops.matrix_band_part(middle, -1, 0)
479
480  grad_a = math_ops.matmul(
481      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
482
483  grad_a += _linalg.adjoint(grad_a)
484  return grad_a * 0.5
485
486
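# Illustrative sketch, not part of TensorFlow: the reverse-mode formula above
# is the adjoint of the forward-mode identity dL = L * Phi(L^{-1} dA L^{-H}),
# where Phi keeps the lower triangle and halves the diagonal (the same masking
# as the tril(ones) - 1/2*eye factor in the comment above). The hypothetical
# helper below checks the forward-mode identity with NumPy finite differences;
# its name, sizes and step size are assumptions for illustration only.
def _CholeskyForwardModeCheck(delta=1e-6):
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  m = rng.standard_normal((4, 4))
  a = m @ m.T + 4.0 * np.eye(4)  # Symmetric positive definite input.
  da = rng.standard_normal((4, 4))
  da = da + da.T  # Symmetric perturbation direction.
  l = np.linalg.cholesky(a)
  l_inv = np.linalg.inv(l)
  phi = np.tril(l_inv @ da @ l_inv.T)
  phi[np.diag_indices_from(phi)] *= 0.5
  analytic = l @ phi
  numeric = (np.linalg.cholesky(a + delta * da) - l) / delta
  return np.max(np.abs(analytic - numeric))

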
487@ops.RegisterGradient("Qr")
488def _QrGrad(op, dq, dr):
489  """Gradient for Qr."""
490
491  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
492  # QR and LQ Decomposition Matrix Backpropagation Algorithms for
493  # Square, Wide, and Deep, Real and Complex, Matrices and Their Software Implementation
494  q, r = op.outputs
495  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
496      r.shape.as_list()[-1] is None):
497    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
498  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
499      q.shape.dims[-2].value == q.shape.dims[-1].value):
500    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
501                              "and full_matrices is true.")
502
503  def _TriangularSolve(x, r):
504    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
505    return _linalg.adjoint(
506        linalg_ops.matrix_triangular_solve(
507            r, _linalg.adjoint(x), lower=False, adjoint=False))
508
509  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
510    """Gradient for matrix orders num_rows >= num_cols
511    and full_matrices is false.
512    """
513    qdq = math_ops.matmul(q, dq, adjoint_a=True)
514    qdq_ = qdq - _linalg.adjoint(qdq)
515    rdr = math_ops.matmul(r, dr, adjoint_b=True)
516    rdr_ = rdr - _linalg.adjoint(rdr)
517    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)
518
519    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
520    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
521    ret = grad_a + grad_b
522
523    if q.dtype.is_complex:
524      # need to add a correction to the gradient formula for complex case
525      m = rdr - _linalg.adjoint(qdq)
526      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
527      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
528      ret = ret + _TriangularSolve(
529          math_ops.matmul(q, _linalg.adjoint(correction)), r)
530
531    return ret
532
533  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]
534
535  if num_rows >= num_cols:
536    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)
537
538  # Partition a = [x, y], r = [u, v] and reduce to the square case
539  a = op.inputs[0]
540  y = a[..., :, num_rows:]
541  u = r[..., :, :num_rows]
542  dv = dr[..., :, num_rows:]
543  du = dr[..., :, :num_rows]
544  dy = math_ops.matmul(q, dv)
545  dx = _QrGradSquareAndDeepMatrices(q, u,
546                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
547                                    du)
548  return array_ops.concat([dx, dy], axis=-1)
549
550
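# Illustrative sketch, not part of TensorFlow: _TriangularSolve(x, r) above is
# documented as matmul(x, adjoint(matrix_inverse(r))) for upper-triangular r.
# The hypothetical check below verifies the real-valued version of that
# identity with NumPy; its name and sizes are assumptions for illustration
# only.
def _TriangularSolveIdentityCheck():
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  # An upper-triangular, well-conditioned matrix.
  r = np.triu(rng.standard_normal((3, 3))) + 3.0 * np.eye(3)
  x = rng.standard_normal((5, 3))
  # adjoint(solve(r, adjoint(x))) == x @ inv(r)^H; in the real case ^H == ^T.
  lhs = np.linalg.solve(r, x.T).T
  rhs = x @ np.linalg.inv(r).T
  assert np.allclose(lhs, rhs)
  return lhs, rhs

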
551@ops.RegisterGradient("MatrixSolve")
552def _MatrixSolveGrad(op, grad):
553  """Gradient for MatrixSolve."""
554  a = op.inputs[0]
555  adjoint_a = op.get_attr("adjoint")
556  c = op.outputs[0]
557  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
558  if adjoint_a:
559    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
560  else:
561    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
562  return (grad_a, grad_b)
563
564
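# Illustrative sketch, not part of TensorFlow: for C = A^{-1} B (adjoint=False)
# with upstream gradient G = dL/dC, the formulas above give dL/dB = A^{-T} G
# and dL/dA = -(A^{-T} G) C^T in the real case. The hypothetical helper below
# checks dL/dA against a central difference with NumPy; its name, sizes and
# step size are assumptions for illustration only.
def _MatrixSolveGradFiniteDifferenceCheck(delta=1e-6):
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  a = rng.standard_normal((3, 3)) + 3.0 * np.eye(3)
  b = rng.standard_normal((3, 2))
  g = rng.standard_normal((3, 2))  # Upstream gradient dL/dC.
  c = np.linalg.solve(a, b)
  grad_b = np.linalg.solve(a.T, g)
  analytic_grad_a = -grad_b @ c.T
  numeric_grad_a = np.zeros_like(a)
  for i in range(3):
    for j in range(3):
      da = np.zeros_like(a)
      da[i, j] = delta
      # L(A) = sum(G * (A^{-1} B)); central difference w.r.t. A[i, j].
      plus = np.sum(g * np.linalg.solve(a + da, b))
      minus = np.sum(g * np.linalg.solve(a - da, b))
      numeric_grad_a[i, j] = (plus - minus) / (2 * delta)
  return np.max(np.abs(analytic_grad_a - numeric_grad_a))

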
565@ops.RegisterGradient("MatrixSolveLs")
566def _MatrixSolveLsGrad(op, grad):
567  """Gradients for MatrixSolveLs."""
568
569  # TODO(rmlarsen): The implementation could be more efficient:
570  #   a) Output the Cholesky factorization from forward op instead of
571  #      recomputing it here.
572  #   b) Implement a symmetric rank-k update op instead of computing
573  #      x*z + transpose(x*z). This pattern occurs other places in TensorFlow.
574
575  def _Overdetermined(op, grad):
576    """Gradients for the overdetermined case of MatrixSolveLs.
577
578    This is the backprop for the solution to the normal equations of the first
579    kind:
580       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
581    which solve the least squares problem
582       min ||A * X - B||_F^2 + lambda ||X||_F^2.
583    """
584    a = op.inputs[0]
585    b = op.inputs[1]
586    x = op.outputs[0]
587    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
588    # pylint: disable=protected-access
589    chol = linalg_ops._RegularizedGramianCholesky(
590        a, l2_regularizer=l2_regularizer, first_kind=True)
591    # pylint: enable=protected-access
592    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
593    z = linalg_ops.cholesky_solve(chol, grad)
594    xzt = math_ops.matmul(x, z, adjoint_b=True)
595    zx_sym = xzt + array_ops.matrix_transpose(xzt)
596    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
597    grad_b = math_ops.matmul(a, z)
598    return (grad_a, grad_b, None)
599
600  def _Underdetermined(op, grad):
601    """Gradients for the underdetermined case of MatrixSolveLs.
602
603    This is the backprop for the solution to the normal equations of the second
604    kind:
605      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
606    that (for lambda=0) solve the least squares problem
607      min ||X||_F subject to A*X = B.
608    """
609    a = op.inputs[0]
610    b = op.inputs[1]
611    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
612    # pylint: disable=protected-access
613    chol = linalg_ops._RegularizedGramianCholesky(
614        a, l2_regularizer=l2_regularizer, first_kind=False)
615    # pylint: enable=protected-access
616    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
617    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
618    tmp = linalg_ops.cholesky_solve(chol, b)
619    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
620    a1 = -math_ops.matmul(grad_b, a1)
621    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
622    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
623    grad_a = a1 + a2
624    return (grad_a, grad_b, None)
625
626  fast = op.get_attr("fast")
627  if fast is False:
628    raise ValueError("Gradient not defined for fast=False")
629  matrix_shape = op.inputs[0].get_shape()[-2:]
630  if matrix_shape.is_fully_defined():
631    if matrix_shape[-2] >= matrix_shape[-1]:
632      return _Overdetermined(op, grad)
633    else:
634      return _Underdetermined(op, grad)
635  else:
636    # We have to defer determining the shape to runtime and use
637    # conditional execution of the appropriate graph.
638    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
639    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
640                                 lambda: _Overdetermined(op, grad),
641                                 lambda: _Underdetermined(op, grad))
642
643
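# Illustrative sketch, not part of TensorFlow: in the overdetermined case the
# forward op computes X = (A^T A + lambda I)^{-1} A^T B, the stationary point
# of ||A X - B||_F^2 + lambda ||X||_F^2. The hypothetical check below verifies
# with NumPy that the objective's gradient vanishes at that X; the name, sizes
# and regularizer are assumptions for illustration only.
def _RegularizedNormalEquationsCheck():
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  a = rng.standard_normal((6, 3))
  b = rng.standard_normal((6, 2))
  lam = 0.1
  x = np.linalg.solve(a.T @ a + lam * np.eye(3), a.T @ b)
  # Gradient of the objective w.r.t. X: 2 * (A^T (A X - B) + lambda X).
  objective_grad = 2.0 * (a.T @ (a @ x - b) + lam * x)
  assert np.max(np.abs(objective_grad)) < 1e-10
  return x

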
644@ops.RegisterGradient("BandedTriangularSolve")
645def _BandedTriangularSolveGrad(op, grad):
646  """Gradient for BandedTriangularSolve."""
647  a = op.inputs[0]
648  b = op.inputs[1]
649  num_bands = array_ops.shape(a)[-2]
650  adjoint_a = op.get_attr("adjoint")
651  lower_a = op.get_attr("lower")
652  c = op.outputs[0]
653  grad_b = linalg_ops.banded_triangular_solve(
654      a, grad, lower=lower_a, adjoint=not adjoint_a)
655  if adjoint_a:
656    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
657  else:
658    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
659  if lower_a:
660    grad_a = array_ops.matrix_diag_part(
661        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
662  else:
663    grad_a = array_ops.matrix_diag_part(
664        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
665  # If the static batch shapes are equal, we don't need to unbroadcast.
666  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
667      a.shape[:-2] == b.shape[:-2]):
668    return grad_a, grad_b
669  a_shape = array_ops.shape(a)
670  b_shape = array_ops.shape(b)
671  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
672  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
673  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
674  return grad_a, grad_b
675
676
677@ops.RegisterGradient("MatrixTriangularSolve")
678def _MatrixTriangularSolveGrad(op, grad):
679  """Gradient for MatrixTriangularSolve."""
680  a = op.inputs[0]
681  b = op.inputs[1]
682  adjoint_a = op.get_attr("adjoint")
683  lower_a = op.get_attr("lower")
684  c = op.outputs[0]
685  grad_b = linalg_ops.matrix_triangular_solve(
686      a, grad, lower=lower_a, adjoint=not adjoint_a)
687  if adjoint_a:
688    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
689  else:
690    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
691  if lower_a:
692    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
693  else:
694    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
695  # If the static batch shapes are equal, we don't need to unbroadcast.
696  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
697      a.shape[:-2] == b.shape[:-2]):
698    return grad_a, grad_b
699  a_shape = array_ops.shape(a)
700  b_shape = array_ops.shape(b)
701  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
702  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
703  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
704  return grad_a, grad_b
705
706
# To avoid NaNs in cases with degenerate eigenvalues or degenerate/zero
# singular values in the calculations of f and s_inv_mat, we introduce a
# Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)


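# Illustrative sketch, not part of TensorFlow: with the default epsilon,
# x / (x**2 + epsilon) behaves like 1 / x whenever |x| >> sqrt(epsilon) and
# smoothly goes to 0 as x -> 0, instead of blowing up. The hypothetical demo
# below shows this with NumPy; the name and sample points are assumptions for
# illustration only.
def _SafeReciprocalBroadeningDemo(epsilon=1e-20):
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  x = np.array([1.0, 1e-3, 1e-6, 0.0])
  safe = x / (x * x + epsilon)
  # For well-separated values the broadened reciprocal matches 1/x closely;
  # at exactly zero it returns 0 rather than inf.
  assert np.allclose(safe[:3], 1.0 / x[:3])
  assert safe[3] == 0.0
  return safe

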
714@ops.RegisterGradient("Eig")
715def _EigGrad(op, grad_e, grad_v):
716  """Gradient for Eig.
717
718  Based on eq. 4.77 from paper by
719  Christoph Boeddeker et al.
720  https://arxiv.org/abs/1701.00392
721  See also
722  "Computation of eigenvalue and eigenvector derivatives
723  for a general complex-valued eigensystem" by Nico van der Aa.
724  As for now only distinct eigenvalue case is considered.
725  """
726  e = op.outputs[0]
727  compute_v = op.get_attr("compute_v")
728  # a = op.inputs[0], which satisfies
729  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
730  with ops.control_dependencies([grad_e, grad_v]):
731    if compute_v:
732      v = op.outputs[1]
733      vt = _linalg.adjoint(v)
734      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
735      # Notice that because of the term involving f, the gradient becomes
736      # infinite (or NaN in practice) when eigenvalues are not unique.
737      # Mathematically this should not be surprising, since for (k-fold)
738      # degenerate eigenvalues, the corresponding eigenvectors are only defined
739      # up to arbitrary rotation in a (k-dimensional) subspace.
740      f = array_ops.matrix_set_diag(
741          _SafeReciprocal(
742              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
743          array_ops.zeros_like(e))
744      f = math_ops.conj(f)
745      vgv = math_ops.matmul(vt, grad_v)
746      mid = array_ops.matrix_diag(grad_e)
747      diag_grad_part = array_ops.matrix_diag(
748          array_ops.matrix_diag_part(
749              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
750      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
751      # vt is formally invertible as long as the original matrix is
752      # diagonalizable. However, in practice, vt may
753      # be ill-conditioned when matrix original matrix is close to
754      # non-diagonalizable one
755      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
756    else:
757      _, v = linalg_ops.eig(op.inputs[0])
758      vt = _linalg.adjoint(v)
759      # vt is formally invertible as long as the original matrix is
760      # diagonalizable. However, in practice, vt may
761      # be ill-conditioned when matrix original matrix is close to
762      # non-diagonalizable one
763      grad_a = linalg_ops.matrix_solve(
764          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
765    return math_ops.cast(grad_a, op.inputs[0].dtype)
766
767
768@ops.RegisterGradient("SelfAdjointEigV2")
769def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
770  """Gradient for SelfAdjointEigV2."""
771  e = op.outputs[0]
772  compute_v = op.get_attr("compute_v")
773  # a = op.inputs[0], which satisfies
774  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
775  with ops.control_dependencies([grad_e, grad_v]):
776    if compute_v:
777      v = op.outputs[1]
778      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
779      # Notice that because of the term involving f, the gradient becomes
780      # infinite (or NaN in practice) when eigenvalues are not unique.
781      # Mathematically this should not be surprising, since for (k-fold)
782      # degenerate eigenvalues, the corresponding eigenvectors are only defined
783      # up to arbitrary rotation in a (k-dimensional) subspace.
784      f = array_ops.matrix_set_diag(
785          _SafeReciprocal(
786              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
787          array_ops.zeros_like(e))
788      grad_a = math_ops.matmul(
789          v,
790          math_ops.matmul(
791              array_ops.matrix_diag(grad_e) +
792              f * math_ops.matmul(v, grad_v, adjoint_a=True),
793              v,
794              adjoint_b=True))
795    else:
796      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
797      grad_a = math_ops.matmul(v,
798                               math_ops.matmul(
799                                   array_ops.matrix_diag(grad_e),
800                                   v,
801                                   adjoint_b=True))
802    # The forward op only depends on the lower triangular part of a, so here we
803    # symmetrize and take the lower triangle
804    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
805    grad_a = array_ops.matrix_set_diag(grad_a,
806                                       0.5 * array_ops.matrix_diag_part(grad_a))
807    return grad_a
808
809
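# Illustrative sketch, not part of TensorFlow: when only the eigenvalue
# gradient grad_e is nonzero, the core of the formula above is
# grad_a = V diag(grad_e) V^H (before the final symmetrization for the
# lower-triangular input convention), which reflects the first-order
# perturbation result d(lambda_i) = v_i^H dA v_i. The hypothetical helper
# below checks that result with NumPy finite differences for a real symmetric
# matrix; the name, sizes and step size are assumptions for illustration only.
def _SymmetricEigenvaluePerturbationCheck(delta=1e-6):
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  m = rng.standard_normal((4, 4))
  a = 0.5 * (m + m.T)
  da = rng.standard_normal((4, 4))
  da = 0.5 * (da + da.T)
  # Assumes distinct eigenvalues, so the ascending sort order from eigh is
  # stable under the small perturbation.
  e, v = np.linalg.eigh(a)
  analytic = np.array([v[:, i] @ da @ v[:, i] for i in range(4)])
  numeric = (np.linalg.eigh(a + delta * da)[0] - e) / delta
  return np.max(np.abs(analytic - numeric))

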
810@ops.RegisterGradient("Svd")
811def _SvdGrad(op, grad_s, grad_u, grad_v):
812  """Gradient for the singular value decomposition."""
813
814  # The derivation for the compute_uv=False case, and most of
815  # the derivation for the full_matrices=True case, are in
816  # Giles' paper (see reference at top of file).  A derivation for
817  # the full_matrices=False case is available at
818  # https://j-towns.github.io/papers/svd-derivative.pdf
819  # The derivation for complex valued SVD can be found in
820  # https://re-ra.xyz/misc/complexsvd.pdf or
821  # https://giggleliu.github.io/2019/04/02/einsumbp.html
822  a = op.inputs[0]
823  a_shape = a.get_shape().with_rank_at_least(2)
824  grad_s = math_ops.cast(grad_s, a.dtype)
825  grad_s_mat = array_ops.matrix_diag(grad_s)
826
827  if not op.get_attr("compute_uv"):
828    s, u, v = linalg_ops.svd(a, compute_uv=True)
829    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
830    grad_a.set_shape(a_shape)
831    return grad_a
832
833  full_matrices = op.get_attr("full_matrices")
834
835  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
836  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
837  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
838  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
839  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
840      grad_v_shape[:-2])
841  a_shape = batch_shape.concatenate([m, n])
842
843  m = a_shape.dims[-2].value
844  n = a_shape.dims[-1].value
845  # TODO(rmlarsen): Make this work with placeholders.
846  if m is None or n is None:
847    raise NotImplementedError(
848        "SVD gradient has not been implemented for input with unknown "
849        "inner matrix shape.")
850
851  s = op.outputs[0]
852  u = op.outputs[1]
853  v = op.outputs[2]
854  s = math_ops.cast(s, a.dtype)
855
856  use_adjoint = False
857  if m > n:
858    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
859    # Hermitian transpose of the gradient at the end.
860    use_adjoint = True
861    m, n = n, m
862    u, v = v, u
863    grad_u, grad_v = grad_v, grad_u
864
865  with ops.control_dependencies([grad_s, grad_u, grad_v]):
866    if full_matrices and abs(m - n) > 1:
867      raise NotImplementedError(
868          "svd gradient is not implemented for abs(m - n) > 1 "
869          "when full_matrices is True")
870    s_mat = array_ops.matrix_diag(s)
871    s2 = math_ops.square(s)
872
873    # NOTICE: Because of the term involving f, the gradient becomes
874    # infinite (or NaN in practice) when singular values are not unique.
875    # Mathematically this should not be surprising, since for (k-fold)
876    # degenerate singular values, the corresponding singular vectors are
877    # only defined up a (k-dimensional) subspace. In practice, this can
878    # lead to numerical instability when singular values are close but not
879    # exactly equal.
880
881    s_shape = array_ops.shape(s)
882    f = array_ops.matrix_set_diag(
883        _SafeReciprocal(
884            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
885        array_ops.zeros_like(s))
886    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))
887
888    v1 = v[..., :, :m]
889    grad_v1 = grad_v[..., :, :m]
890
891    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
892    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)
893
894    f_u = f * u_gu
895    f_v = f * v_gv
896
897    term1_nouv = (
898        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
899        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))
900
901    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))
902
903    if m == n:
904      grad_a_before_transpose = term1
905    else:
906      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
907      gv1t_v1 = math_ops.matmul(gv1t, v1)
908      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
909
910      if full_matrices:
911        v2 = v[..., :, m:n]
912        grad_v2 = grad_v[..., :, m:n]
913
914        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
915        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)
916
917      u_s_inv = math_ops.matmul(u, s_inv_mat)
918      term2 = math_ops.matmul(u_s_inv, term2_nous)
919
920      grad_a_before_transpose = term1 + term2
921
922    if a.dtype.is_complex:
923      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
924      l = eye * v_gv
925      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
926      term3 = 1 / 2. * math_ops.matmul(
927          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
928
929      grad_a_before_transpose += term3
930
931    if use_adjoint:
932      grad_a = array_ops.matrix_transpose(
933          grad_a_before_transpose, conjugate=True)
934    else:
935      grad_a = grad_a_before_transpose
936
937    grad_a.set_shape(a_shape)
938    return grad_a
939
940
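# Illustrative sketch, not part of TensorFlow: the grad_s_mat term above
# reflects the first-order perturbation result d(s_i) = u_i^H dA v_i for the
# singular values. The hypothetical helper below checks that result with NumPy
# finite differences; the name, sizes and step size are assumptions for
# illustration only.
def _SingularValuePerturbationCheck(delta=1e-6):
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  a = rng.standard_normal((5, 3))
  da = rng.standard_normal((5, 3))
  u, s, vt = np.linalg.svd(a, full_matrices=False)
  analytic = np.array([u[:, i] @ da @ vt[i, :] for i in range(3)])
  numeric = (np.linalg.svd(a + delta * da, compute_uv=False) - s) / delta
  return np.max(np.abs(analytic - numeric))

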
def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)


957@ops.RegisterGradient("TridiagonalMatMul")
958def _TridiagonalMatMulGrad(op, grad):
959  """Gradient for TridiagonalMatMul."""
960  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
961  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
962  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
963  rhs_conj = math_ops.conj(op.inputs[3])
964
965  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
966  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
967  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
968  rhs_grad = _RightShift(superdiag_conj * grad) + \
969      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)
970
971  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
972  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
973  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)
974
975  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad
976
977
978@ops.RegisterGradient("TridiagonalSolve")
979def _TridiagonalSolveGrad(op, grad):
980  """Gradient for TridiagonalSolveGrad."""
981  diags = op.inputs[0]
982  x = op.outputs[0]
983  partial_pivoting = op.get_attr("partial_pivoting")
984
985  # Transposing the matrix within tridiagonal_solve kernel by interchanging
986  # superdiagonal and subdiagonal wouldn't work on GPU due to mismatch with
987  # paddings required by cusparse*gtsv routines.
988  # So constructing the transposed matrix in Python.
989  diags_transposed = _TransposeTridiagonalMatrix(diags)
990
991  grad_rhs = linalg_ops.tridiagonal_solve(diags_transposed, grad,
992                                          partial_pivoting=partial_pivoting)
993  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)
994  return grad_diags, grad_rhs
995
996
def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For a fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)


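# Illustrative sketch, not part of TensorFlow: in the compact format used by
# linalg_ops.tridiagonal_solve, row 0 holds the superdiagonal (padded with a
# zero on the right), row 1 the main diagonal, and row 2 the subdiagonal
# (padded with a zero on the left). The hypothetical check below builds the
# dense matrix with NumPy and confirms that transposing swaps the shifted
# super- and subdiagonals exactly as _TransposeTridiagonalMatrix does; the
# name and sample values are assumptions for illustration only.
def _CompactTridiagonalTransposeCheck():
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  superdiag = np.array([1., 2., 3., 0.])
  maindiag = np.array([4., 5., 6., 7.])
  subdiag = np.array([0., 8., 9., 10.])
  dense = (np.diag(maindiag) + np.diag(superdiag[:-1], k=1) +
           np.diag(subdiag[1:], k=-1))
  dense_t = dense.T
  # Transposed compact form: the new superdiagonal is the old subdiagonal
  # shifted left, and the new subdiagonal is the old superdiagonal shifted
  # right.
  new_superdiag = np.concatenate([subdiag[1:], [0.]])
  new_subdiag = np.concatenate([[0.], superdiag[:-1]])
  assert np.allclose(np.diag(dense_t, k=1), new_superdiag[:-1])
  assert np.allclose(np.diag(dense_t, k=-1), new_subdiag[1:])
  return np.stack([new_superdiag, maindiag, new_subdiag])

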
def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)

  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)

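
# Illustrative sketch, not part of TensorFlow: the hypothetical check below
# compares _MatmulExtractingThreeDiagonals-style reductions against explicitly
# forming the product x @ y_tr^T with NumPy and reading off its three
# diagonals; the name and sizes are assumptions for illustration only.
def _MatmulExtractingThreeDiagonalsCheck():
  import numpy as np  # Local import; NumPy already ships with TensorFlow.
  rng = np.random.default_rng(0)
  x = rng.standard_normal((4, 6))
  y_tr = rng.standard_normal((4, 6))  # Second operand, already transposed.
  product = x @ y_tr.T
  zeros = np.zeros((1, 6))
  diag = np.sum(x * y_tr, axis=-1)
  superdiag = np.sum(
      x * np.concatenate([y_tr[1:, :], zeros], axis=0), axis=-1)
  subdiag = np.sum(
      x * np.concatenate([zeros, y_tr[:-1, :]], axis=0), axis=-1)
  assert np.allclose(diag, np.diag(product))
  assert np.allclose(superdiag[:-1], np.diag(product, k=1))
  assert np.allclose(subdiag[1:], np.diag(product, k=-1))
  return np.stack([superdiag, diag, subdiag])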