# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Gradients for operators defined in linalg_ops.py.
16
17Useful reference for derivative formulas is (Mike Giles, 2008).
18
19Ionescu et al. (2015) provide a detailed derivation of formulas for
20backpropagating through spectral layers (SVD and Eig).
21
22References:
23  An extended collection of matrix derivative results for
24  forward and reverse mode automatic differentiation:
25    [Mike Giles, 2008]
26    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
27    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
28  Matrix Backpropagation for Deep Networks with Structured Layers
29    [Ionescu et al., 2015]
30    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
31    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
32  Training Deep Networks with Structured Layers by Matrix Backpropagation:
33    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
34    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
35"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.matmul(  # pylint: disable=invalid-unary-operand-type
      ainv,
      math_ops.matmul(grad, ainv, adjoint_b=True),
      adjoint_a=True)
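

# Illustrative sketch (not part of the original module): the expression above
# is the matrix analogue of d/dx(1/x) = -1/x**2, i.e. for C = A^{-1} the
# vector-Jacobian product is grad_A = -A^{-H} grad_C A^{-H}. The hypothetical
# helper below writes out the same product without relying on `op.outputs`,
# recomputing the inverse from `a`; it is a sketch, not used by TensorFlow.
def _example_matrix_inverse_vjp(a, grad_c):
  """Reference VJP for C = inv(A); assumes `a` is square and invertible."""
  ainv = linalg_ops.matrix_inverse(a)
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad_c, ainv, adjoint_b=True), adjoint_a=True)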


@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or if
    the ellipsis is not present), and the negative index if it occurs after the
    ellipsis. E.g. index of `b` in `ab...cd` is `1`, but that of `c` is `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None

  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd` returns `(2, -2)`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.

    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and
    # 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced
    # subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. Given the equation "aabbcd->ca" we'd
    # first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = (
        array_ops.concat([
            array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
            array_ops.shape(output_grad)
        ],
                         axis=0))
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz. "bdca->aabbcd")
    # since the output axis labels now appear in the input subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done separately
    on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the one which is
        not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    #   where the equation involves only Tensor contractions, generalized traces
    #   and transposes, the input gradients are given by the vector-Jacobian
    #   products (VJPs):
    #
    #     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    #   where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    #   x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    #   traces ("ii->i"), the linear mapping's Jacobian at input x is given
    #   by the function itself. We can verify that the linear maps given by the
    #   VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    #   where the latter represents 'un-tracing', i.e. filling the diagonal
    #   with the input vector while the off-diagonal entries are zeros.
    #        Furthermore, recall that matrix multiplication, which is
    #   represented by the equation "ab,bc->ac", has its VJPs given by the
    #   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example
    #   https://math.stackexchange.com/a/2755680). Combined with transposes and
    #   traces we can rewrite Tensor contractions as regular matrix
    #   multiplication. Since each of these operations has its VJP described
    #   by an einsum of the required pattern, the result follows.
    #
    # Accordingly, einsum operations, except for those with reductions (e.g.
    # "abc,cd->ad"), have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain the VJP for "ac,cd->ad" and then we compute the
    # VJP of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
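    # Worked example (illustrative, not from the original source): for plain
    # matrix multiplication z = einsum("ab,bc->ac", x, y) with x of shape
    # [2, 3] and y of shape [3, 4], the claim above gives
    #   grad_wrt_x = einsum("bc,ac->ab", y, grad_wrt_z)  # shape [2, 3]
    #   grad_wrt_y = einsum("ab,ac->bc", x, grad_wrt_z)  # shape [3, 4]
    # i.e. the familiar matmul gradients grad_wrt_z @ y^T and x^T @ grad_wrt_z.
    #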
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscript. E.g. the set {'b'} for
    # the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt the
    # input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels; i.e. axis labels appearing in the input but not in
    # the output subscripts.
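    # For example (illustrative): for a pure transpose z = einsum("ab->ba", x)
    # the gradient wrt x is simply einsum("ba->ab", grad), whereas "ab->a" has
    # the reduced label 'b' and is handled by _GetGradReduced below.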
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, there is no need to unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts 'ab...c'
  # and a shape of rank 10, the range [2:-1] denotes the broadcasted axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv
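

# Illustrative sketch (not part of the original module): by Jacobi's formula,
# d det(A) = det(A) * tr(A^{-1} dA), so the VJP above is grad * det(A) * A^{-H}
# broadcast over the batch dimensions. The hypothetical helper below spells out
# the same computation while recomputing det(A) instead of reusing the op's
# cached output; it is a sketch, not used by TensorFlow.
def _example_determinant_vjp(a, grad_c):
  """Reference VJP for c = det(A); assumes `a` is square and invertible."""
  c = linalg_ops.matrix_determinant(a)
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_c * c, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv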


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)
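
  # For intuition (illustrative example, not from the original source): with a
  # batch of 2 x 2 identity matrices I and a batch of 2 x 2 blocks B,
  # _KroneckerProduct(I, B) is the batch of 4 x 4 block-diagonal matrices
  # diag(B, B), matching the usual definition
  # kron(A, B)[i*q + k, j*q + l] = A[i, j] * B[k, l] with q = 2.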

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv
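

# Worked equation (illustrative): since d log det(A) = tr(A^{-1} dA), the VJP
# above is simply grad_b * A^{-H}; the scalar analogue is d/dx log|x| = 1/x.
# The incoming gradient for the `sign` output (op.outputs[0]) is ignored,
# which is why the second argument of the registered function above is `_`.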


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5
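

# Sanity check (illustrative): for a 1 x 1 input a = [[x]] the Cholesky factor
# is l = [[sqrt(x)]], and the formula above collapses to
# grad_a = grad_l / (2 * sqrt(x)), matching the scalar derivative
# d sqrt(x)/dx = 1 / (2 * sqrt(x)).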


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""

  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
  # "QR and LQ Decomposition Matrix Backpropagation Algorithms for Square,
  # Wide, and Deep, Real and Complex, Matrices and Their Software
  # Implementation".
  q, r = op.outputs
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrix orders num_rows >= num_cols and full_matrices=False.
    """
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    ret = grad_a + grad_b

    if q.dtype.is_complex:
      # Need to add a correction to the gradient formula for the complex case.
      m = rdr - _linalg.adjoint(qdq)
      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
      ret = ret + _TriangularSolve(
          math_ops.matmul(q, _linalg.adjoint(correction)), r)

    return ret

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1].value

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case
  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  return (grad_a, grad_b)
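

# Worked equation (illustrative): for C = A^{-1} B, the differential is
# dC = -A^{-1} dA C + A^{-1} dB, so the VJPs are grad_B = A^{-H} grad_C
# (another solve with the `adjoint` flag flipped) and grad_A = -grad_B C^{H},
# which is what the non-adjoint branch above computes; the adjoint_a branch is
# the corresponding mirrored expression.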


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs in other places in
  #      TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
       min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)  # pylint: disable=invalid-unary-operand-type
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid NaN in cases with degenerate eigenvalues or degenerate/zero
# singular values in the calculations of f and s_inv_mat, we introduce
# Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)
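

# Behaviour note (illustrative): _SafeReciprocal(x) = x / (x**2 + epsilon)
# agrees with 1/x when |x| >> sqrt(epsilon) and smoothly tends to 0 instead of
# blowing up as x -> 0, which is what keeps `f` and `s_inv_mat` finite for
# (nearly) degenerate spectra.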


@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from the paper by Christoph Boeddeker et al.,
  https://arxiv.org/abs/1701.00392.
  See also "Computation of eigenvalue and eigenvector derivatives
  for a general complex-valued eigensystem" by Nico van der Aa.
  For now, only the case of distinct eigenvalues is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
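

# Note (illustrative): in the compute_v=False branch above, the eigenvectors
# are recomputed and the result reduces to grad_a = v @ diag(grad_e) @ v^H
# (before symmetrization), which is the first-order perturbation identity
# d(e_i) = v_i^H dA v_i for a self-adjoint matrix, written as a VJP.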


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file).  A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)
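

# Example (illustrative): if x has rows [r0, r1, r2] along the next-to-last
# axis, then _LeftShift(x) has rows [r1, r2, 0] and _RightShift(x) has rows
# [0, r0, r1]; these are used below to line up the right-hand side with the
# superdiagonal and subdiagonal entries of a tridiagonal matrix.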


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = (_RightShift(superdiag_conj * grad) + maindiag_conj * grad +
              _LeftShift(subdiag_conj * grad))

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")
  perturb_singular = op.get_attr("perturb_singular")

  # Transposing the matrix within the tridiagonal_solve kernel by interchanging
  # the superdiagonal and subdiagonal wouldn't work on GPU due to a mismatch
  # with the paddings required by the cusparse*gtsv routines, so we construct
  # the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(
      diags_transposed,
      grad,
      partial_pivoting=partial_pivoting,
      perturb_singular=perturb_singular)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)  # pylint: disable=invalid-unary-operand-type
  return grad_diags, grad_rhs

def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For a fully defined tensor shape we can concat with a tensor of zeros,
    # which is faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)


def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)

  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)