# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a zero matrix."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorZeros",
]


@tf_export("linalg.LinearOperatorZeros")
@linear_operator.make_composite_tensor
class LinearOperatorZeros(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] zero matrix.

  This operator acts like a [batch] zero matrix `A` with shape
  `[B1,...,Bb, N, M]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x M` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorZeros` is initialized with `num_rows`, and optionally
  `num_columns`, `batch_shape`, and `dtype` arguments.  If `num_columns` is
  `None`, then this operator will be initialized as a square matrix.  If
  `batch_shape` is `None`, this operator places no constraint on the batch
  shape of its arguments, so no broadcasting is needed.  If `batch_shape` is
  provided, broadcasting may occur, which will require making copies.

  ```python
  # Create a 2 x 2 zero matrix.
  operator = LinearOperatorZeros(num_rows=2, dtype=tf.float32)

  operator.to_dense()
  ==> [[0., 0.]
       [0., 0.]]

  operator.shape
  ==> [2, 2]

  operator.determinant()
  ==> 0.

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor, same as tf.zeros_like(x).

  # Create a 2-batch of 2x2 zero matrices
  operator = LinearOperatorZeros(num_rows=2, batch_shape=[2])
  operator.to_dense()
  ==> [[[0., 0.]
        [0., 0.]],
       [[0., 0.]
        [0., 0.]]]

  # Here, even though the operator has a batch shape, the output has the same
  # shape as x, so no broadcast is necessary.  The operator is able to detect
  # this because both x and the operator have statically defined shape.
  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, same as tf.zeros_like(x)

  # Here the operator and x have different batch_shape, and are broadcast.
  # This requires a copy, since the output is a different size than the input.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, equal to tf.zeros_like([x, x])
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` if

  ```
  operator.shape = [B1,...,Bb] + [N, M],  with b >= 0
  x.shape =   [C1,...,Cc] + [M, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               num_rows,
               num_columns=None,
               batch_shape=None,
               dtype=None,
               is_non_singular=False,
               is_self_adjoint=True,
               is_positive_definite=False,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorZeros"):
    r"""Initialize a `LinearOperatorZeros`.

    The `LinearOperatorZeros` is initialized with arguments defining `dtype`
    and shape.

    This operator is able to broadcast the leading (batch) dimensions, which
    sometimes requires copying data.  If `batch_shape` is `None`, the operator
    can take arguments of any batch shape without copying.  See examples.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding zero matrix.
      num_columns:  Scalar non-negative integer `Tensor`.  Number of columns in
        the corresponding zero matrix. If `None`, defaults to the value of
        `num_rows`.
      batch_shape:  Optional `1-D` integer `Tensor`.  The shape of the leading
        dimensions.  If `None`, this operator has no leading dimensions.
      dtype:  Data type of the matrix that this operator represents.  Defaults
        to `float32` when `None`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `num_columns` is determined statically to be non-scalar,
        or negative.
      ValueError:  If `batch_shape` is determined statically to not be 1-D, or
        negative.
      ValueError:  If `is_non_singular` or `is_positive_definite` is `True`, or
        if `is_self_adjoint` is `False` while `is_square` is `True`.
      ValueError:  If `is_square` is `True` but `num_rows` and `num_columns`
        are statically known to differ.
      TypeError:  If `num_rows`, `num_columns` or `batch_shape` does not have
        an integer dtype.
    """
    parameters = dict(
        num_rows=num_rows,
        num_columns=num_columns,
        batch_shape=batch_shape,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name
    )

    dtype = dtype or dtypes.float32
    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name):
      dtype = dtypes.as_dtype(dtype)
      # Reject hints that contradict what a zero matrix is.
      if not is_self_adjoint and is_square:
        raise ValueError("A zero operator is always self adjoint.")
      if is_non_singular:
        raise ValueError("A zero operator is always singular.")
      if is_positive_definite:
        raise ValueError("A zero operator is always not positive-definite.")

      super(LinearOperatorZeros, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
      linear_operator_util.assert_not_ref_type(num_columns, "num_columns")
      linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)

      if num_columns is None:
        num_columns = num_rows

      self._num_columns = linear_operator_util.shape_tensor(
          num_columns, name="num_columns")
      self._num_columns_static = tensor_util.constant_value(self._num_columns)

      self._check_domain_range_possibly_add_asserts()

      if (self._num_rows_static is not None and
          self._num_columns_static is not None):
        if is_square and self._num_rows_static != self._num_columns_static:
          raise ValueError(
              "LinearOperatorZeros initialized as is_square=True, but got "
              "num_rows({}) != num_columns({})".format(
                  self._num_rows_static,
                  self._num_columns_static))

      if batch_shape is None:
        self._batch_shape_arg = None
      else:
        self._batch_shape_arg = linear_operator_util.shape_tensor(
            batch_shape, name="batch_shape_arg")
        self._batch_shape_static = tensor_util.constant_value(
            self._batch_shape_arg)
        self._check_batch_shape_possibly_add_asserts()

  def _shape(self):
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_columns_static))
    if self._batch_shape_arg is None:
      return matrix_shape

    batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    matrix_shape = array_ops.stack((self._num_rows, self._num_columns), axis=0)
    if self._batch_shape_arg is None:
      return matrix_shape

    return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)

  def _assert_non_singular(self):
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Zero operators are always "
        "non-invertible.")

  def _assert_positive_definite(self):
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Zero operators are always "
        "non-positive definite.")

  def _assert_self_adjoint(self):
    return control_flow_ops.no_op("assert_self_adjoint")

  def _possibly_broadcast_batch_shape(self, x):
    """Return 'x', possibly after broadcasting the leading dimensions."""
    # If we have no batch shape, our batch shape broadcasts with everything!
    if self._batch_shape_arg is None:
      return x

    # Static attempt:
    #   If we determine that no broadcast is necessary, pass x through.
    #   If we need a broadcast, add to an array of zeros.
    #
    # special_shape is the shape that, when broadcast with x's shape, will give
    # the correct broadcast_shape.  Only the batch dimensions of x are to be
    # broadcast; its trailing two (matrix) dimensions are sized by the caller,
    # so the final two dimensions of special_shape are 1's.
    special_shape = self.batch_shape.concatenate([1, 1])
    bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
    if special_shape.is_fully_defined():
      # If broadcasting against special_shape leaves x's static shape
      # unchanged, x already carries the full batch shape.
      if bshape == x.shape:
        return x
      # Use the built in broadcasting of addition.
      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
      return x + zeros

    # Dynamic broadcast:
    #   Always add to an array of zeros, rather than using a "cond", since a
    #   cond would require copying data from GPU --> CPU.
    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Returns zeros shaped like `op(A) @ op(x)`, broadcast over batch dims."""
    if self._assert_proper_shapes:
      if adjoint_arg:
        # Materialize x^H so the dimension check sees the real right-hand
        # side.  x is now already adjointed, so it must NOT be treated as
        # adjointed again when the output shape is computed below.
        x = linalg.adjoint(x)
        adjoint_arg = False
      if adjoint:
        # A^H has shape [..., M, N], so x must have `range_dimension` rows.
        aps = check_ops.assert_equal(
            self.range_dimension_tensor(),
            array_ops.shape(x)[-2],
            message=("Incompatible matrix dimensions.  "
                     "shape[-2] of argument to be the same as the range "
                     "dimension of this operator."))
      else:
        aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    if self.is_square:
      # Note that adjoint has no effect since this matrix is self-adjoint.
      if adjoint_arg:
        output_shape = array_ops.concat([
            array_ops.shape(x)[:-2],
            [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0)
      else:
        output_shape = array_ops.shape(x)

      return self._possibly_broadcast_batch_shape(
          array_ops.zeros(shape=output_shape, dtype=x.dtype))

    x_shape = array_ops.shape(x)
    # Rows of op(A): N for A, M for A^H.  Columns of op(x): last dim of op(x).
    n = self._num_columns if adjoint else self._num_rows
    m = x_shape[-2] if adjoint_arg else x_shape[-1]

    output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

    zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
    return self._possibly_broadcast_batch_shape(zeros)

  def _zeros_with_batch_shape(self):
    """Returns a zero `Tensor` of shape `batch_shape` and this operator's dtype.

    Uses the static batch shape when it is fully defined, otherwise the
    dynamic `batch_shape_tensor()`.
    """
    if self.batch_shape.is_fully_defined():
      return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype)
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _determinant(self):
    # The determinant of every zero matrix in the batch is 0.
    return self._zeros_with_batch_shape()

  def _trace(self):
    # The trace of every zero matrix in the batch is 0.
    return self._zeros_with_batch_shape()

  def _diag_part(self):
    return self._zeros_diag()

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `0 + mat`.

    Since this operator is zero, the result equals `mat`, broadcast against
    this operator's batch shape when one was given.

    Args:
      mat:  `Tensor` (or value convertible to one) with same `dtype` and shape
        broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Conversion is required: _possibly_broadcast_batch_shape reads
      # `mat.shape` as a TensorShape.
      mat = ops.convert_to_tensor(mat, name="mat")
      return self._possibly_broadcast_batch_shape(mat)

  def _check_domain_range_possibly_add_asserts(self):
    """Static check of init args `num_rows`/`num_columns`, possibly add asserts.

    Raises:
      TypeError: If `num_rows` or `num_columns` is not of integer dtype.
      ValueError: If `num_rows` or `num_columns` is statically known to be
        non-scalar or negative.
    """
    # Possibly add asserts.
    if self._assert_proper_shapes:
      self._num_rows = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_rows,
              0,
              message="Argument num_rows must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_rows,
              message="Argument num_rows must be non-negative."),
      ], self._num_rows)
      self._num_columns = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_columns,
              0,
              message="Argument num_columns must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_columns,
              message="Argument num_columns must be non-negative."),
      ], self._num_columns)

    # Static checks.
    if not self._num_rows.dtype.is_integer:
      raise TypeError("Argument num_rows must be integer type.  Found:"
                      " %s" % self._num_rows)

    if not self._num_columns.dtype.is_integer:
      raise TypeError("Argument num_columns must be integer type.  Found:"
                      " %s" % self._num_columns)

    num_rows_static = self._num_rows_static
    num_columns_static = self._num_columns_static

    if num_rows_static is not None:
      if num_rows_static.ndim != 0:
        raise ValueError("Argument num_rows must be a 0-D Tensor.  Found:"
                         " %s" % num_rows_static)

      if num_rows_static < 0:
        raise ValueError("Argument num_rows must be non-negative.  Found:"
                         " %s" % num_rows_static)
    if num_columns_static is not None:
      if num_columns_static.ndim != 0:
        raise ValueError("Argument num_columns must be a 0-D Tensor.  Found:"
                         " %s" % num_columns_static)

      if num_columns_static < 0:
        raise ValueError("Argument num_columns must be non-negative.  Found:"
                         " %s" % num_columns_static)

  def _check_batch_shape_possibly_add_asserts(self):
    """Static check of init arg `batch_shape`, possibly add asserts."""
    if self._batch_shape_arg is None:
      return

    # Possibly add asserts
    if self._assert_proper_shapes:
      self._batch_shape_arg = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._batch_shape_arg,
              1,
              message="Argument batch_shape must be a 1-D Tensor."),
          check_ops.assert_non_negative(
              self._batch_shape_arg,
              message="Argument batch_shape must be non-negative."),
      ], self._batch_shape_arg)

    # Static checks
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type.  Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
      return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
      raise ValueError("Argument batch_shape must be a 1-D Tensor.  Found:"
                       " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
      raise ValueError("Argument batch_shape must be non-negative.  Found:"
                       "%s" % self._batch_shape_static)

  def _min_matrix_dim(self):
    """Minimum of domain/range dimension, if statically available, else None."""
    domain_dim = tensor_shape.dimension_value(self.domain_dimension)
    range_dim = tensor_shape.dimension_value(self.range_dimension)
    if domain_dim is None or range_dim is None:
      return None
    return min(domain_dim, range_dim)

  def _min_matrix_dim_tensor(self):
    """Minimum of domain/range dimension, as a tensor."""
    return math_ops.reduce_min(self.shape_tensor()[-2:])

  def _zeros_diag(self):
    """Returns the diagonal of this operator as all zeros."""
    if self.shape.is_fully_defined():
      d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
    else:
      d_shape = array_ops.concat(
          [self.batch_shape_tensor(),
           [self._min_matrix_dim_tensor()]], axis=0)

    return array_ops.zeros(shape=d_shape, dtype=self.dtype)

  def _eigvals(self):
    return self._zeros_diag()

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows", "num_columns", "batch_shape")

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "num_columns", "batch_shape", "dtype",
            "assert_proper_shapes")

  def __getitem__(self, slices):
    # Slice the batch shape and return a new LinearOperatorZeros.
    # Use a proxy tensor of ones with this operator's batch shape, slice it,
    # and use the sliced proxy's shape as the new batch shape.
    # batch_shape_tensor() is used (rather than the raw batch_shape argument)
    # so that an operator constructed with batch_shape=None, whose batch
    # shape is [], can also be sliced.
    new_batch_shape = array_ops.shape(
        array_ops.ones(self.batch_shape_tensor())[slices])
    parameters = dict(self.parameters, batch_shape=new_batch_shape)
    return LinearOperatorZeros(**parameters)