# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is (Mike Giles, 2008).

Ionescu et al. (2015) provide a detailed derivation of formulas for
backpropagating through spectral layers (SVD and Eig).

References:
  An extended collection of matrix derivative results for
    forward and reverse mode automatic differentiation:
    [Mike Giles, 2008]
    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
  Matrix Backpropagation for Deep Networks with Structured Layers
    [Ionescu et al., 2015]
    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
  Training Deep Networks with Structured Layers by Matrix Backpropagation:
    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.matmul(  # pylint: disable=invalid-unary-operand-type
      ainv,
      math_ops.matmul(grad, ainv, adjoint_b=True),
      adjoint_a=True)


@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or if
    the ellipsis is not present), and the negative index if it occurs after the
    ellipsis. E.g. the index of `b` in `ab...cd` is `1`, but that of `c` is
    `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`).
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None

  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd` returns `(2, -2)`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.
    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b'
    # and 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the
    # reduced subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. given the equation "aabbcd->ca",
    # we'd first obtain the VJP for "bdca->ca", and then the VJP for
    # "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = (
        array_ops.concat([
            array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
            array_ops.shape(output_grad)
        ],
                         axis=0))
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz. "bdca->aabbcd")
    # since the output axis labels now appear in the input subscripts.
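    # Worked shape trace for the running example (an illustrative aside):
    # output_grad has shape [3, 2] ("ca"); it is reshaped to [1, 1, 3, 2],
    # broadcast to [5, 4, 3, 2] ("bdca"), and the einsum below maps it to the
    # input gradient of shape [2, 2, 5, 5, 3, 4] ("aabbcd").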
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done
    separately on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the one which is
        not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    #   where the equation involves only Tensor contractions, generalized traces
    #   and transposes, the input gradients are given by the vector-jacobian
    #   products (VJPs):
    #
    #     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    #   where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    #   x and y, and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    #   traces ("ii->i"), the linear mapping's Jacobian at input x is given
    #   by the function itself. We can verify that the linear maps given by the
    #   VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    #   where the latter represents 'un-tracing', i.e. filling the diagonal
    #   with the input axis and setting the non-diagonal entries to zero.
    #   Furthermore, recall that matrix multiplication, which is
    #   represented by the equation "ab,bc->ac", has its VJPs given by the
    #   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example,
    #   https://math.stackexchange.com/a/2755680). Combined with transposes and
    #   traces we can rewrite Tensor contractions as regular matrix
    #   multiplication. Since each of these operations has its VJP described
    #   by einsums of the required pattern, the result follows.
    #
    # Accordingly, einsum operations without reductions have their VJPs defined
    # by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But for an equation with a reduction, such as "abc,cd->ad", this would
    # lead to the equation "ad,cd->abc", which is invalid because the reduced
    # axis label 'b' is present in the output but not in any of the inputs.
    # Therefore, we compute the VJP in two steps: first we obtain the VJP for
    # "ac,cd->ad" and then we compute the VJP of "abc->ac" or, equivalently,
    # reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscript. E.g. the set {'b'} for
    # the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
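    # Worked example (an illustrative aside): for "abc,cd->ad" with x of shape
    # [2, 3, 4] and y of shape [4, 5], output_grad has shape [2, 5]. The einsum
    # below yields the [2, 4] gradient wrt "ac", which _GetGradReduced then
    # broadcasts along the reduced 'b' axis into the [2, 3, 4] gradient wrt x.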
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt the
    # input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels, i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, then we don't need to unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts 'ab...c'
  # and a shape of rank 10, the range [2:-1] denotes the broadcasted axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""

  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071:
  # "QR and LQ Decomposition Matrix Backpropagation Algorithms for Square,
  # Wide, and Deep, Real and Complex, Matrices and Their Software
  # Implementation".
  q, r = op.outputs
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrices with num_rows >= num_cols when full_matrices is
    false.
    """
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    ret = grad_a + grad_b

    if q.dtype.is_complex:
      # We need to add a correction to the gradient formula for the complex
      # case.
      m = rdr - _linalg.adjoint(qdq)
      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
      ret = ret + _TriangularSolve(
          math_ops.matmul(q, _linalg.adjoint(correction)), r)

    return ret

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case.
  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  return (grad_a, grad_b)


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from the forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs other places in TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
      X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
      min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solve the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)  # pylint: disable=invalid-unary-operand-type
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid nan in cases with degenerate eigenvalues or
# degenerate/zero singular values in calculations of
# f and s_inv_mat, we introduce a Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)


@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from the paper by Christoph Boeddeker et al.,
  https://arxiv.org/abs/1701.00392.
  See also "Computation of eigenvalue and eigenvector derivatives for a
  general complex-valued eigensystem" by Nico van der Aa.
  For now, only the case of distinct eigenvalues is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may
      # be ill-conditioned when the original matrix is close to a
      # non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may
      # be ill-conditioned when the original matrix is close to a
      # non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file).
  # A derivation for the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to arbitrary rotation in a (k-dimensional) subspace.
    # In practice, this can lead to numerical instability when singular
    # values are close but not exactly equal.
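    # Worked illustration (a hedged aside, ignoring the tiny Lorentz
    # broadening in _SafeReciprocal): for s = [3., 1.] we have s2 = [9., 1.]
    # and f = [[0, -1/8], [1/8, 0]], since off the diagonal
    # f[i, j] ~ 1/(s2[j] - s2[i]).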
    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = _RightShift(superdiag_conj * grad) + \
      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad

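# Note on the compact diagonals format used below (a brief recap of the
# convention documented for linalg_ops.tridiagonal_solve, stated here for
# convenience): `diags` stacks [superdiag, maindiag, subdiag] along axis -2,
# where the last element of the superdiagonal and the first element of the
# subdiagonal are ignored.
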
@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")
  perturb_singular = op.get_attr("perturb_singular")

  # Transposing the matrix within the tridiagonal_solve kernel by interchanging
  # the superdiagonal and subdiagonal wouldn't work on GPU due to a mismatch
  # with the paddings required by the cusparse*gtsv routines, so we construct
  # the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(
      diags_transposed,
      grad,
      partial_pivoting=partial_pivoting,
      perturb_singular=perturb_singular)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)  # pylint: disable=invalid-unary-operand-type
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For a fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)


def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)
  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)
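
# Hedged usage sketch (illustrative only, not part of this module): the
# gradient functions registered above are dispatched automatically by
# TensorFlow's autodiff machinery, e.g.
#
#   import tensorflow as tf
#   a = tf.constant([[2., 1.], [1., 2.]], dtype=tf.float64)
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.inv(a))
#   da = tape.gradient(loss, a)  # uses _MatrixInverseGrad
#
# tf.test.compute_gradient can be used to compare such registered gradients
# against finite-difference estimates.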