# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Construct the Kronecker product of one or more `LinearOperators`."""

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorKronecker"]


def _prefer_static_shape(x):
  """Return the shape of `x`, statically if it is fully defined.

  Args:
    x: Object exposing a `shape` attribute with `is_fully_defined()` (used here
      with `Tensor`s).

  Returns:
    `x.shape` when it is fully defined, otherwise the result of
    `array_ops.shape(x)`.
  """
  if x.shape.is_fully_defined():
    return x.shape
  return array_ops.shape(x)


def _prefer_static_concat_shape(first_shape, second_shape_int_list):
  """Concatenate a shape with a list of integers as statically as possible.

  Args:
    first_shape: `TensorShape` or `Tensor` instance. If a `TensorShape`,
      `first_shape.is_fully_defined()` must return `True`.
    second_shape_int_list: `list` of scalar integer `Tensor`s.

  Returns:
    `Tensor` representing concatenating `first_shape` and
      `second_shape_int_list` as statically as possible.
  """
  second_shape_int_list_static = [
      tensor_util.constant_value(s) for s in second_shape_int_list]
  if (isinstance(first_shape, tensor_shape.TensorShape) and
      all(s is not None for s in second_shape_int_list_static)):
    return first_shape.concatenate(second_shape_int_list_static)
  return array_ops.concat([first_shape, second_shape_int_list], axis=0)


@tf_export("linalg.LinearOperatorKronecker")
@linear_operator.make_composite_tensor
class LinearOperatorKronecker(linear_operator.LinearOperator):
  """Kronecker product of one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` representing the Kronecker product:
  `op1 x op2 x .. opJ` (we omit parentheses as the Kronecker product is
  associative).

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the composed operator
  will have shape equal to `broadcast_batch_shape + [prod M_j, prod N_j]`,
  where the product is over all operators.

  ```python
  # Create a 4 x 4 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [2., 1.]])
  operator = LinearOperatorKronecker([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 0., 2., 0.],
       [2., 1., 4., 2.],
       [3., 0., 4., 0.],
       [6., 3., 8., 4.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [4, 2] Tensor
  operator.matmul(x)
  ==> Shape [4, 2] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 20 x 30 operators.
  operator_large = LinearOperatorKronecker([operator_45, operator_56])

  # Create a shape [2, 3, 30, 2] Tensor (30 = 5 * 6 is the domain dimension).
  x = tf.random.normal(shape=[2, 3, 30, 2])
  operator_large.matmul(x)
  ==> Shape [2, 3, 20, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorKronecker` on any operation is equal to
  the sum of the individual operators' operations.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorKronecker`.

    `LinearOperatorKronecker` is initialized with a list of operators
    `[op_1,...,op_J]`.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape, representing the Kronecker
        factors.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_x_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if a hint passed as `False`
        contradicts a property implied by all factors having that property.
    """
    # Snapshot of the constructor arguments as passed by the caller; forwarded
    # to the base class below.
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(f"Argument `operators` must be a list of >=1 operators. "
                       f"Received: {operators}.")
    self._operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            f"Expected every operation in argument `operators` to have the "
            f"same dtype. Received {list(name_type)}.")

    # Auto-set and check hints.
    # A Kronecker product is invertible, if and only if all factors are
    # invertible.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            f"The Kronecker product of non-singular operators is always "
            f"non-singular. Expected argument `is_non_singular` to be True. "
            f"Received: {is_non_singular}.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            f"The Kronecker product of self-adjoint operators is always "
            f"self-adjoint. Expected argument `is_self_adjoint` to be True. "
            f"Received: {is_self_adjoint}.")
      is_self_adjoint = True

    # The eigenvalues of a Kronecker product are equal to the products of eigen
    # values of the corresponding factors.
    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            f"The Kronecker product of positive-definite operators is always "
            f"positive-definite. Expected argument `is_positive_definite` to "
            f"be True. Received: {is_positive_definite}.")
      is_positive_definite = True

    # Default name: factor names joined with "_x_".
    if name is None:
      name = operators[0].name
      for operator in operators[1:]:
        name += "_x_" + operator.name
    with ops.name_scope(name):
      super(LinearOperatorKronecker, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of Kronecker factors this operator was built from."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + [prod range, prod domain]."""
    # Get final matrix shape.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension

    range_dimension = self.operators[0].range_dimension
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension

    matrix_shape = tensor_shape.TensorShape([
        range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic counterpart of `_shape`, built from the `*_tensor()` methods."""
    domain_dimension = self.operators[0].domain_dimension_tensor()
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension_tensor()

    range_dimension = self.operators[0].range_dimension_tensor()
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension_tensor()

    matrix_shape = [range_dimension, domain_dimension]

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _solve_matmul_internal(
      self,
      x,
      solve_matmul_fn,
      adjoint=False,
      adjoint_arg=False):
    """Shared implementation of `_matmul` and `_solve`.

    Applies `solve_matmul_fn` once per Kronecker factor (last factor first),
    reshaping the intermediate result between applications.

    Args:
      x: `Tensor` argument (the right-hand side for solves).
      solve_matmul_fn: Callable `(operator, tensor, adjoint, adjoint_arg)` that
        applies a single factor, e.g. via `operator.matmul` or
        `operator.solve`.
      adjoint: Python `bool`; forwarded to `solve_matmul_fn` for every factor.
      adjoint_arg: Python `bool`; if `True`, `x` is treated as its hermitian
        transpose.

    Returns:
      `Tensor` reshaped to `[..., row_dim, col_dim]` where `row_dim` is the
      domain dimension if `adjoint` else the range dimension of this operator.
    """
    # We heavily rely on Roth's column Lemma [1]:
    # (A x B) * vec X = vec BXA^T
    # where vec stacks all the columns of the matrix under each other.
    # In our case, we use a variant of the lemma that is row-major
    # friendly: (A x B) * vec' X = vec' AXB^T
    # Where vec' reshapes a matrix into a vector. We can repeatedly apply this
    # for a collection of kronecker products.
    # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
    # use the above to compute multiplications, solves with any composition of
    # transposes.
    output = x

    # The loop below works on the transpose of the (possibly adjointed)
    # argument. When `adjoint_arg` is set, the transpose of `x^H` is `conj(x)`,
    # so only a conjugation (a no-op for real dtypes) is required.
    if adjoint_arg:
      if self.dtype.is_complex:
        output = math_ops.conj(output)
    else:
      output = linalg.transpose(output)

    for o in reversed(self.operators):
      # Statically compute the reshape.
      if adjoint:
        operator_dimension = o.range_dimension_tensor()
      else:
        operator_dimension = o.domain_dimension_tensor()
      output_shape = _prefer_static_shape(output)

      # Use a Python-int reshape only when the operator dimension AND both
      # trailing dimensions of `output` are statically known. `output_shape`
      # may be a dynamic `Tensor` (e.g. unknown batch dims), so the static path
      # must read from `output.shape` only; otherwise fall back to computing
      # `dim` dynamically so it is always defined for this iteration.
      static_operator_dimension = tensor_util.constant_value(
          operator_dimension)
      static_rows = tensor_shape.dimension_value(output.shape[-2])
      static_cols = tensor_shape.dimension_value(output.shape[-1])
      if (static_operator_dimension is not None and
          static_rows is not None and static_cols is not None):
        operator_dimension = static_operator_dimension
        dim = int(static_rows * static_cols // operator_dimension)
      else:
        dim = math_ops.cast(
            output_shape[-2] * output_shape[-1] // operator_dimension,
            dtype=dtypes.int32)

      output_shape = _prefer_static_concat_shape(
          output_shape[:-2], [dim, operator_dimension])
      output = array_ops.reshape(output, shape=output_shape)

      # Conjugate because we are trying to compute A @ B^T, but
      # `LinearOperator` only supports `adjoint_arg`.
      if self.dtype.is_complex:
        output = math_ops.conj(output)

      output = solve_matmul_fn(
          o, output, adjoint=adjoint, adjoint_arg=True)

    # Number of columns of the (possibly adjointed) argument.
    if adjoint_arg:
      col_dim = _prefer_static_shape(x)[-2]
    else:
      col_dim = _prefer_static_shape(x)[-1]

    if adjoint:
      row_dim = self.domain_dimension_tensor()
    else:
      row_dim = self.range_dimension_tensor()

    matrix_shape = [row_dim, col_dim]

    output = array_ops.reshape(
        output,
        _prefer_static_concat_shape(
            _prefer_static_shape(output)[:-2], matrix_shape))

    # Attach as much static shape information as possible to the result.
    if x.shape.is_fully_defined():
      if adjoint_arg:
        column_dim = x.shape[-2]
      else:
        column_dim = x.shape[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          x.shape[:-2], self.batch_shape)
      if adjoint:
        matrix_dimensions = [self.domain_dimension, column_dim]
      else:
        matrix_dimensions = [self.range_dimension, column_dim]

      output.set_shape(broadcast_batch_shape.concatenate(
          matrix_dimensions))

    return output

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiply by `x`, delegating to each factor's `matmul`."""
    def matmul_fn(o, x, adjoint, adjoint_arg):
      return o.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=x,
        solve_matmul_fn=matmul_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve against `rhs`, delegating to each factor's `solve`."""
    def solve_fn(o, rhs, adjoint, adjoint_arg):
      return o.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=rhs,
        solve_matmul_fn=solve_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _determinant(self):
    """Determinant as `product(|X_i| ** (T / dim(X_i)))` over the factors."""
    # Note that we have |X1 x X2| = |X1| ** n * |X2| ** m, where X1 is an m x m
    # matrix, and X2 is an n x n matrix. We can iteratively apply this property
    # to get the determinant of |X1 x X2 x X3 ...|. If T is the product of the
    # domain dimension of all operators, then we have:
    # |X1 x X2 x X3 ...| =
    #   |X1| ** (T / m) * |X2 x X3 ... | ** m =
    #   |X1| ** (T / m) * |X2| ** (m * (T / m) / n) * ... =
    #   |X1| ** (T / m) * |X2| ** (T / n) * | X3 x X4... | ** (m * n)
    #   And by doing induction we have product(|X_i| ** (T / dim(X_i))).
    total = self.domain_dimension_tensor()
    determinant = 1.
    for operator in self.operators:
      determinant = determinant * operator.determinant() ** math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return determinant

  def _log_abs_determinant(self):
    """Log-abs-determinant as a weighted sum of the factors' values."""
    # This will be sum((total / dim(x_i)) * log |X_i|)
    total = self.domain_dimension_tensor()
    log_abs_det = 0.
    for operator in self.operators:
      log_abs_det += operator.log_abs_determinant() * math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return log_abs_det

  def _trace(self):
    """Trace as the product of the factors' traces."""
    # tr(A x B) = tr(A) * tr(B)
    trace = 1.
    for operator in self.operators:
      trace = trace * operator.trace()
    return trace

  def _diag_part(self):
    """Diagonal built from the outer product of the factors' diagonals."""
    # NOTE(review): the outer product of factor diagonals matches the diagonal
    # of the Kronecker product for square factors; behavior for non-square
    # factors has not been verified here -- confirm before relying on it.
    diag_part = self.operators[0].diag_part()
    for operator in self.operators[1:]:
      # [..., D1, 1] * [..., 1, D2] -> [..., D1, D2], then flatten the last two
      # axes into one.
      diag_part = diag_part[..., :, array_ops.newaxis]
      op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
      diag_part = diag_part * op_diag_part
      diag_part = array_ops.reshape(
          diag_part,
          shape=array_ops.concat(
              [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
    # The static diagonal length is the smaller of the two matrix dimensions.
    if self.range_dimension > self.domain_dimension:
      diag_dimension = self.domain_dimension
    else:
      diag_dimension = self.range_dimension
    diag_part.set_shape(
        self.batch_shape.concatenate(diag_dimension))
    return diag_part

  def _to_dense(self):
    """Dense Kronecker product, built by broadcasting the dense factors."""
    product = self.operators[0].to_dense()
    for operator in self.operators[1:]:
      # Product has shape [B, R1, 1, C1, 1].
      product = product[
          ..., :, array_ops.newaxis, :, array_ops.newaxis]
      # Operator has shape [B, 1, R2, 1, C2].
      op_to_mul = operator.to_dense()[
          ..., array_ops.newaxis, :, array_ops.newaxis, :]
      # This is now [B, R1, R2, C1, C2].
      product = product * op_to_mul
      # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
      product_shape = _prefer_static_shape(product)
      shape = _prefer_static_concat_shape(
          product_shape[:-4],
          [product_shape[-4] * product_shape[-3],
           product_shape[-2] * product_shape[-1]])

      product = array_ops.reshape(product, shape=shape)
    product.set_shape(self.shape)
    return product

  def _eigvals(self):
    """Eigenvalues as the Kronecker product of the factors' eigenvalues."""
    # This will be the kronecker product of all the eigenvalues.
    # Note: It doesn't matter which kronecker product it is, since every
    # kronecker product of the same matrices are similar.
    eigvals = [operator.eigvals() for operator in self.operators]
    # Now compute the kronecker product
    product = eigvals[0]
    for eigval in eigvals[1:]:
      # Product has shape [B, R1, 1].
      product = product[..., array_ops.newaxis]
      # Eigval has shape [B, 1, R2]. Produces shape [B, R1, R2].
      product = product * eigval[..., array_ops.newaxis, :]
      # Reshape to [B, R1 * R2]
      product = array_ops.reshape(
          product,
          shape=array_ops.concat([array_ops.shape(product)[:-2], [-1]], axis=0))
    product.set_shape(self.shape[:-1])
    return product

  def _assert_non_singular(self):
    """Group the factors' non-singularity asserts.

    Returns:
      An op grouping `assert_non_singular()` of every factor.

    Raises:
      errors.InvalidArgumentError: If any factor lacks a truthy `is_square`
        hint.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "invertible. Expected hint `is_square` to be True for every operator "
          "in argument `operators`.")

  def _assert_self_adjoint(self):
    """Group the factors' self-adjointness asserts.

    Returns:
      An op grouping `assert_self_adjoint()` of every factor.

    Raises:
      errors.InvalidArgumentError: If any factor lacks a truthy `is_square`
        hint.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_self_adjoint() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      # The message previously said "invertible" (copied from
      # `_assert_non_singular`); this check is about self-adjointness.
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "self-adjoint. Expected hint `is_square` to be True for every "
          "operator in argument `operators`.")

  @property
  def _composite_tensor_fields(self):
    """Names of the constructor arguments exposed as composite fields.

    NOTE(review): presumably consumed by the machinery applied through
    `linear_operator.make_composite_tensor` -- confirm against that decorator.
    """
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    """Maps `"operators"` to a list with one `0` entry per factor."""
    return {"operators": [0] * len(self.operators)}