# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal utilities for `LinearOperator` classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg


def assert_no_entries_with_modulus_zero(
    x, message=None, name="assert_no_entries_with_modulus_zero"):
  """Returns `Op` that asserts Tensor `x` has no entries with modulus zero.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to the failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with modulus zero.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    dtype = x.dtype.base_dtype
    should_be_nonzero = math_ops.abs(x)
    zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
    return check_ops.assert_less(zero, should_be_nonzero, message=message)


def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to the failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with non-zero imaginary part.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    dtype = x.dtype.base_dtype

    if dtype.is_floating:
      return control_flow_ops.no_op()

    zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)


def assert_compatible_matrix_dimensions(operator, x):
  """Assert that an argument to solve/matmul has proper domain dimension.

  If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
  `operator.matmul(x)` is defined only if `N = Q`. This `Op` returns an
  `Assert` that "fires" if this is not the case. Static checks are already
  done by the base class `LinearOperator`.

  Args:
    operator: `LinearOperator`.
    x: `Tensor`.

  Returns:
    `Assert` `Op`.
  """
  # Static checks are done in the base class. Only tensor asserts here.
  assert_same_dd = check_ops.assert_equal(
      array_ops.shape(x)[-2],
      operator.domain_dimension_tensor(),
      message=("Incompatible matrix dimensions.  "
               "Expected shape[-2] of argument to be the same as this "
               "operator's domain dimension."))

  return assert_same_dd


def assert_is_batch_matrix(tensor):
  """Static assert that `tensor` has rank `2` or higher."""
  sh = tensor.get_shape()
  if sh.ndims is not None and sh.ndims < 2:
    raise ValueError(
        "Expected [batch] matrix to have at least two dimensions. Found: "
        "%s" % tensor)


def shape_tensor(shape, name=None):
  """Convert `shape` to a `Tensor`, using `int32` for an empty list or tuple."""
  # Works just like random_ops._ShapeTensor.
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int32
  else:
    dtype = None
  return ops.convert_to_tensor(shape, dtype=dtype, name=name)
################################################################################
# Broadcasting versions of common linear algebra functions.
# TODO(b/77519145) Do this more efficiently in some special cases.
################################################################################


def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]  # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random_normal(shape=(2, 3, 1, 4, 4))
  y = tf.random_normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
    name: A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError: If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    #   x.shape =    [2, j, k]  (batch shape =    [2])
    #   y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.get_shape()[:-2])
    if bcast_batch_shape.is_fully_defined():
      # The [1, 1] at the end will broadcast with anything.
      bcast_shape = bcast_batch_shape.concatenate([1, 1])
      for i, mat in enumerate(batch_matrices):
        if mat.get_shape()[:-2] != bcast_batch_shape:
          batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)

    return batch_matrices


def _broadcast_to_shape(x, shape):
  # Adding a zeros tensor of the target shape broadcasts `x` up to `shape`.
  return x + array_ops.zeros(shape=shape, dtype=x.dtype)


def cholesky_solve_with_broadcast(chol, rhs, name=None):
  """Solve systems of linear equations, broadcasting batch dims of inputs."""
  with ops.name_scope(name, "CholeskySolveWithBroadcast", [chol, rhs]):
    chol, rhs = broadcast_matrix_batch_dims([chol, rhs])
    return linalg_ops.cholesky_solve(chol, rhs)
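# Illustrative sketch (not part of the original module): a batch of Cholesky
# factors paired with a single, unbatched right-hand side. The function name,
# shapes, and values are hypothetical, added only for illustration.
def _example_cholesky_solve_with_broadcast():
  """Solves a [2]-batch of 3 x 3 SPD systems against one rhs; illustrative."""
  # A [2]-batch of 3 x 3 symmetric positive-definite matrices: 2 * I.
  spd = 2. * linalg_ops.eye(3, batch_shape=[2])
  chol = linalg_ops.cholesky(spd)        # Shape [2, 3, 3].
  rhs = array_ops.ones(shape=[3, 1])     # Shape [3, 1], no batch dims.
  # rhs is broadcast (replicated) to shape [2, 3, 1] before the solve.
  return cholesky_solve_with_broadcast(chol, rhs)  # Shape [2, 3, 1].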
def matmul_with_broadcast(a,
                          b,
                          transpose_a=False,
                          transpose_b=False,
                          adjoint_a=False,
                          adjoint_b=False,
                          a_is_sparse=False,
                          b_is_sparse=False,
                          name=None):
  """Multiplies matrix `a` by matrix `b`, producing `a @ b`.

  Works identically to `tf.matmul`, but broadcasts batch dims
  of `a` and `b` if they are determined statically to be different, or if
  static shapes are not fully defined. Attempts are made to avoid unnecessary
  replication of data, but this is not always possible.

  The inputs must be matrices (or tensors of rank > 2, representing batches of
  matrices).

  Both matrices must be of the same type. The supported types are:
  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.

  Either matrix can be transposed or adjointed (conjugated and transposed) on
  the fly by setting one of the corresponding flags to `True`. These are
  `False` by default.

  If one or both of the matrices contain a lot of zeros, a more efficient
  multiplication algorithm can be used by setting the corresponding
  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
  This optimization is only available for plain matrices (rank-2 tensors) with
  datatypes `bfloat16` or `float32`.

  For example:

  ```python
  # A 2-batch of 3x4 matrices
  a = tf.random_normal(shape=(2, 3, 4))

  # A single 4x5 matrix
  b = tf.random_normal(shape=(4, 5))

  result = matmul_with_broadcast(a, b)

  result.shape
  ==> (2, 3, 5)

  result[0, ...]
  ==> tf.matmul(a[0, ...], b)

  result[1, ...]
  ==> tf.matmul(a[1, ...], b)
  ```

  Args:
    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
      `complex128` and rank > 1.
    b: `Tensor` with same type as `a` having compatible matrix dimensions and
      broadcastable batch dimensions.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated and transposed before
      multiplication.
    adjoint_b: If `True`, `b` is conjugated and transposed before
      multiplication.
    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
    name: Name for the operation (optional).

  Returns:
    A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
    the product of the corresponding matrices in `a` and `b`, e.g. if all
    transpose or adjoint attributes are `False`:

    The leading shape of `output` is the result of broadcasting the leading
    dimensions of `a` and `b`.

    `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
    for all indices i, j.

    Note: This is matrix product, not element-wise product.

  Raises:
    ValueError: If `transpose_a` and `adjoint_a`, or `transpose_b` and
      `adjoint_b` are both set to `True`.
  """
  with ops.name_scope(name, "MatMulWithBroadcast", [a, b]):
    a = ops.convert_to_tensor(a, name="a")
    b = ops.convert_to_tensor(b, name="b", dtype=a.dtype)

    # If either a or b has extra dims, we can reshape to get rid of them.
    a, b, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        a,
        b,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        adjoint_a=adjoint_a,
        adjoint_b=adjoint_b)

    # This will broadcast by brute force if we still need to.
    a, b = broadcast_matrix_batch_dims([a, b])

    a_times_b = math_ops.matmul(
        a,
        b,
        transpose_a=transpose_a and still_need_to_transpose,
        transpose_b=transpose_b and still_need_to_transpose,
        adjoint_a=adjoint_a and still_need_to_transpose,
        adjoint_b=adjoint_b and still_need_to_transpose,
        a_is_sparse=a_is_sparse,
        b_is_sparse=b_is_sparse)

    return reshape_inv(a_times_b)


def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve systems of linear equations, broadcasting batch dims of inputs."""
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_solve(
        matrix, rhs, adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)


def matrix_triangular_solve_with_broadcast(matrix,
                                           rhs,
                                           lower=True,
                                           adjoint=False,
                                           name=None):
  """Solves triangular systems of linear equations by back substitution.

  Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
  of `matrix` and `rhs` (by replicating) if they are determined statically to
  be different, or if static shapes are not fully defined. Thus, this may
  result in an inefficient replication of data.

  Args:
    matrix: A `Tensor`. Must be one of the following types:
      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
      Shape is `[..., M, K]`.
    lower: An optional `bool`. Defaults to `True`. Indicates whether the
      innermost matrices in `matrix` are lower or upper triangular.
    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to
      solve with `matrix` or its (block-wise) adjoint.
    name: A name for the operation (optional).

  Returns:
    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
  """
  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # lower indicates whether the matrix is lower triangular. If we have
    # manually taken the adjoint inside _reshape_for_efficiency, it is now
    # upper triangular.
    if not still_need_to_transpose and adjoint:
      lower = not lower

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_triangular_solve(
        matrix,
        rhs,
        lower=lower,
        adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)
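# Illustrative sketch (not part of the original module): a [2]-batch of
# matrices solved against a single right-hand side. The function name,
# shapes, and values are hypothetical, added only for illustration.
def _example_matrix_solve_with_broadcast():
  """Solves A x = rhs for a [2]-batch of A with one rhs; illustrative."""
  matrix = 3. * linalg_ops.eye(2, batch_shape=[2])  # Shape [2, 2, 2].
  rhs = array_ops.ones(shape=[2, 1])                # Shape [2, 1], no batch.
  # rhs is broadcast (replicated) to shape [2, 2, 1] before the solve.
  return matrix_solve_with_broadcast(matrix, rhs)   # Shape [2, 2, 1].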
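# Illustrative sketch (not part of the original module): one lower-triangular
# matrix solved against a [2]-batch of right-hand sides. The function name,
# shapes, and values are hypothetical, added only for illustration.
def _example_matrix_triangular_solve_with_broadcast():
  """Solves L x = rhs for a [2]-batch of rhs with one L; illustrative."""
  # One 2 x 2 lower-triangular matrix, no batch dims.
  matrix = ops.convert_to_tensor([[2., 0.], [1., 3.]])
  # A [2]-batch of 2 x 1 right-hand sides.
  rhs = array_ops.ones(shape=[2, 2, 1])
  # The mismatched batch dims are reconciled internally; in this case the
  # batch of rhs is folded into extra columns rather than replicating matrix.
  return matrix_triangular_solve_with_broadcast(
      matrix, rhs, lower=True)  # Shape [2, 2, 1].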
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map. For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough. Why?
  # Assume adjoint/transpose = False. Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory. But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a. This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code. Why? To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
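# Illustrative sketch (not part of the original module): the shape contract of
# `_reshape_for_efficiency`. With a.shape = C + [m, n] and
# b.shape = S + C + [n, r], the helper may fold the extra dims S of b into the
# last dim of b so a single, un-replicated matmul/solve can be used, and
# `reshape_inv` maps the result back to shape S + C + [m, r]. The function
# name, shapes, and values below are hypothetical, added only for illustration.
def _example_reshape_for_efficiency():
  """Round-trips a matmul through the reshape helper; illustrative only."""
  a = array_ops.ones(shape=[3, 4])     # C = [],  [m, n] = [3, 4].
  b = array_ops.ones(shape=[5, 4, 1])  # S = [5], [n, r] = [4, 1].
  a2, b2, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(a, b)
  # Here a2 is a unchanged, b2 has shape [4, 5] (the S dim folded into the
  # columns), and still_need_to_transpose is False.
  del still_need_to_transpose
  y = math_ops.matmul(a2, b2)          # Shape [3, 5].
  return reshape_inv(y)                # Shape [5, 3, 1] = S + [m, r].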