# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is (Mike Giles, 2008).

Ionescu et al. (2015) provide a detailed derivation of formulas for
backpropagating through spectral layers (SVD and Eig).

References:
  An extended collection of matrix derivative results for
  forward and reverse mode automatic differentiation:
    [Mike Giles, 2008]
    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
  Matrix Backpropagation for Deep Networks with Structured Layers
    [Ionescu et al., 2015]
    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
  Training Deep Networks with Structured Layers by Matrix Backpropagation:
    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)
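

# The formula above is the reverse-mode form of the identity
# d(A^{-1}) = -A^{-1} dA A^{-1}, i.e. grad_A = -A^{-H} grad_C A^{-H} where
# C = A^{-1}. A rough way to sanity-check it outside this module (a hedged
# sketch, not part of the original file; assumes a TF2 eager context with `tf`
# being the public TensorFlow package):
#
#   a = tf.constant([[2., 1.], [1., 3.]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.inv(a))
#   grad_a = tape.gradient(loss, a)
#   # grad_a should match -inv(a)^T @ ones((2, 2)) @ inv(a)^T.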


@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or if
    the ellipsis is not present), and the negative index if it occurs after the
    ellipsis. E.g. the index of `b` in `ab...cd` is `1`, but that of `c` is
    `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None

  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd`, returns `(2, -2)`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.
    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b'
    # and 'd' are reduced with input_shape [2,2,5,5,3,4].
    # Then obtain the reduced subscripts "bd", corresponding dimensions [5,4]
    # and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. Given the equation "aabbcd->ca" we'd
    # first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = array_ops.concat([
        array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
        array_ops.shape(output_grad)
    ], axis=0)
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz. "bdca->aabbcd")
    # since the output axis labels now appear in the input subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done
    separately on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the one which is
        not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    #   where the equation involves only Tensor contractions, generalized traces
    #   and transposes, the input gradients are given by the vector-jacobian
    #   products (VJPs):
    #
    #     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    #   where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    #   x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    #   traces ("ii->i"), the linear mapping's Jacobian at input x is given
    #   by the function itself. We can verify that the linear maps given by the
    #   VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    #   where the latter represents 'un-tracing': the diagonal is filled with
    #   the input vector and the off-diagonal entries are zeros.
    #   Furthermore, recall that matrix multiplication, which is
    #   represented by the equation "ab,bc->ac", has its VJPs given by the
    #   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example
    #   https://math.stackexchange.com/a/2755680). Combined with transposes and
    #   traces we can rewrite Tensor contractions as regular matrix
    #   multiplication. Since each of these operations has its VJP described
    #   by an einsum of the required pattern, the result follows.
    #
    # Accordingly, einsum operations except for those with reductions, e.g.
    # "abc,cd->ad" have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain the VJP for "ac,cd->ad" and then we compute the
    # VJP of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscript. E.g. the set {'b'} for
    # the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
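    # For example, for the matmul-like equation "ab,bc->ac" (with x as the
    # input operand) the call below is just einsum("ac,bc->ab", grad, y) and
    # reduced_label_set is empty, so the early return already gives the full
    # VJP; the extra reduction step only applies to equations such as
    # "abc,cd->ad" discussed above.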
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt the
    # input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels; i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, then we don't need to unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts
  # 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
  # axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y
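

# A rough end-to-end check of the Einsum gradient above (a hedged sketch, not
# part of the original file; assumes a TF2 eager context with `tf` being the
# public TensorFlow package):
#
#   x = tf.random.normal([2, 3, 4])
#   y = tf.random.normal([4, 5])
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = tf.einsum("...ij,jk->...ik", x, y)
#     loss = tf.reduce_sum(z)
#   gx, gy = tape.gradient(loss, [x, y])
#   # gx has shape [2, 3, 4]; gy has shape [4, 5], i.e. the broadcasted batch
#   # axes of the y gradient have been summed out as in the code above.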


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
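  # Background (a standard identity, stated here as a hedged aside): with
  # column-stacking vectorization, vec(A X B) = kron(B^T, A) vec(X). Kronecker
  # products therefore let the matrix Sylvester equation above be written as a
  # single batched linear system in vec(dR), which is what the matrix_solve
  # call below does; the matrix_transpose calls around the reshapes convert
  # between TF's row-major flattening and the column-stacking vec.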
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5
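

# A rough way to exercise the Cholesky gradient above (a hedged sketch, not
# part of the original file; assumes a TF2 eager context and a symmetric
# positive-definite input):
#
#   a = tf.constant([[4., 2.], [2., 3.]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.cholesky(a))
#   grad_a = tape.gradient(loss, a)
#   # grad_a is the (symmetrized) gradient wrt a, per the formula above.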


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""

  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
  # QR and LQ Decomposition Matrix Backpropagation Algorithms for Square, Wide,
  # and Deep, Real and Complex, Matrices and Their Software Implementation
  q, r = op.outputs
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrix orders num_rows >= num_cols
    and full_matrices is false.
    """
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    ret = grad_a + grad_b

    if q.dtype.is_complex:
      # Need to add a correction to the gradient formula for the complex case.
      m = rdr - _linalg.adjoint(qdq)
      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
      ret = ret + _TriangularSolve(
          math_ops.matmul(q, _linalg.adjoint(correction)), r)

    return ret

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1].value

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case.
  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)
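

# A rough way to exercise the QR gradient above (a hedged sketch, not part of
# the original file; assumes a TF2 eager context):
#
#   a = tf.random.normal([5, 3])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     q, r = tf.linalg.qr(a)  # full_matrices defaults to False
#     loss = tf.reduce_sum(q) + tf.reduce_sum(r)
#   grad_a = tape.gradient(loss, a)  # shape [5, 3]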


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from the forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs in other places in
  #      TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
      X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
      min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A^T * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid NaN in cases with degenerate eigenvalues or
# degenerate/zero singular values in calculations of
# f and s_inv_mat, we introduce a Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)


@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from paper by
  Christoph Boeddeker et al.
  https://arxiv.org/abs/1701.00392
  See also
  "Computation of eigenvalue and eigenvector derivatives
  for a general complex-valued eigensystem" by Nico van der Aa.
  For now, only the case of distinct eigenvalues is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to being non-diagonalizable.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to being non-diagonalizable.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
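

# A rough check of the self-adjoint eigendecomposition gradient above (a
# hedged sketch, not part of the original file; assumes a TF2 eager context
# and a symmetric input with distinct eigenvalues):
#
#   a = tf.constant([[2., 1.], [1., 3.]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     e, v = tf.linalg.eigh(a)
#     loss = tf.reduce_sum(e)
#   grad_a = tape.gradient(loss, a)
#   # Since sum(e) == trace(a), grad_a should be (close to) the identity.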


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to rotation within a (k-dimensional) subspace. In
    # practice, this can lead to numerical instability when singular values
    # are close but not exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
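

# A rough check of the SVD gradient above (a hedged sketch, not part of the
# original file; assumes a TF2 eager context and distinct, nonzero singular
# values):
#
#   a = tf.random.normal([4, 3])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     s = tf.linalg.svd(a, compute_uv=False)
#     loss = tf.reduce_sum(s)
#   grad_a = tape.gradient(loss, a)
#   # For this loss (the nuclear norm), grad_a should be close to u @ v^T,
#   # where u and v are the singular vectors of a.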


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = _RightShift(superdiag_conj * grad) + \
      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")

  # Transposing the matrix within the tridiagonal_solve kernel by interchanging
  # the superdiagonal and subdiagonal wouldn't work on GPU due to a mismatch
  # with the paddings required by the cusparse*gtsv routines, so we construct
  # the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(diags_transposed, grad,
                                          partial_pivoting=partial_pivoting)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)


def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(MK) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)
  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)
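

# A rough end-to-end check of the TridiagonalSolve gradient defined above (a
# hedged sketch, not part of the original file; assumes a TF2 eager context
# and the compact diagonals format, i.e. superdiag, maindiag and subdiag
# stacked into a [3, M] tensor with the unused padding entries set to zero):
#
#   diags = tf.constant([[1., 1., 1., 0.],
#                        [4., 4., 4., 4.],
#                        [0., 1., 1., 1.]])
#   rhs = tf.random.normal([4, 2])
#   with tf.GradientTape() as tape:
#     tape.watch([diags, rhs])
#     sol = tf.linalg.tridiagonal_solve(diags, rhs)
#     loss = tf.reduce_sum(sol)
#   grad_diags, grad_rhs = tape.gradient(loss, [diags, rhs])
#   # grad_diags has shape [3, 4] and grad_rhs has shape [4, 2].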