# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a permutation matrix."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorPermutation",]


@tf_export("linalg.LinearOperatorPermutation")
@linear_operator.make_composite_tensor
class LinearOperatorPermutation(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of permutation matrices.

  This operator acts like a [batch] of permutations with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorPermutation` is initialized with a (batch) vector.

  A permutation is defined by an integer vector `v` whose values are unique
  and are in the range `[0, ..., N - 1]`. Applying the permutation on an input
  matrix has the following meaning: the value of `v` at index `i`
  says to move the `v[i]`-th row of the input matrix to the `i`-th row.
  Because all values are unique, this will result in a permutation of the
  rows of the input matrix. Note that the permutation vector `v` has the same
  semantics as `tf.transpose`.

  ```python
  # Create a 3 x 3 permutation matrix that swaps the last two columns.
  vec = [0, 2, 1]
  operator = LinearOperatorPermutation(vec)

  operator.to_dense()
  ==> [[1., 0., 0.]
       [0., 0., 1.]
       [0., 1., 0.]]

  operator.shape
  ==> [3, 3]

  # This will be zero.
  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               perm,
               dtype=dtypes.float32,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorPermutation"):
    r"""Initialize a `LinearOperatorPermutation`.

    Args:
      perm:  Shape `[B1,...,Bb, N]` Integer `Tensor` with `b >= 0` and
        `N >= 0`. An integer vector that represents the permutation to apply.
        Note that this argument is the same as in `tf.transpose`. However,
        this permutation is applied on the rows, while the permutation in
        `tf.transpose` is applied on the dimensions of the `Tensor`. `perm`
        is required to have unique entries from `{0, 1, ... N-1}`.
      dtype: The `dtype` of arguments to this operator. Default: `float32`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular. A
        permutation matrix is always non-singular, so passing `False` raises
        a `ValueError`.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose. This is NOT set automatically: a permutation matrix is
        self-adjoint only when the permutation is its own inverse. When
        `True`, `matmul(..., adjoint=True)` applies `perm` directly instead
        of its inverse.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is autoset to true; passing `False` raises a `ValueError`.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `perm` does not have an integer dtype.
      ValueError:  If `is_non_singular` or `is_square` is `False`, if `perm`
        has fewer than 1 dimension, or if `perm` is statically known and its
        last axis is not a permutation of `{0, 1, ... N-1}`.
    """
    parameters = dict(
        perm=perm,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[perm]):
      self._perm = linear_operator_util.convert_nonref_to_tensor(
          perm, name="perm")
      self._check_perm(self._perm)

      # Check and auto-set hints.
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always non-singular. "
                         f"Expected argument `is_non_singular` to be True. "
                         f"Received: {is_non_singular}.")

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always square. "
                         f"Expected argument `is_square` to be True. "
                         f"Received: {is_square}.")
      is_square = True

      super(LinearOperatorPermutation, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_perm(self, perm):
    """Static checks on `perm`: rank >= 1, integer dtype, valid permutation.

    The uniqueness check only runs when the value of `perm` is statically
    known (`tensor_util.constant_value` returns non-`None`).
    """
    if (perm.shape.ndims is not None and perm.shape.ndims < 1):
      raise ValueError(f"Argument `perm` must have at least 1 dimension. "
                       f"Received: {perm}.")
    if not perm.dtype.is_integer:
      raise TypeError(f"Argument `perm` must be integer dtype. "
                      f"Received: {perm}.")
    # Check that the permutation satisfies the uniqueness constraint: sorting
    # each row must yield exactly 0, 1, ..., N-1.
    static_perm = tensor_util.constant_value(perm)
    if static_perm is not None:
      sorted_perm = np.sort(static_perm, axis=-1)
      if np.any(sorted_perm != np.arange(0, static_perm.shape[-1])):
        raise ValueError(
            f"Argument `perm` must be a vector of unique integers from "
            f"0 to {static_perm.shape[-1] - 1}.")

  def _shape(self):
    # Static shape: perm shape `[..., N]` becomes `[..., N, N]`.
    perm_shape = self._perm.shape
    return perm_shape.concatenate(perm_shape[-1:])

  def _shape_tensor(self):
    # Dynamic counterpart of `_shape`.
    perm_shape = array_ops.shape(self._perm)
    k = perm_shape[-1]
    return array_ops.concat((perm_shape, [k]), 0)

  def _assert_non_singular(self):
    # Permutation matrices are always invertible; nothing to check.
    return control_flow_ops.no_op("assert_non_singular")

  def _domain_dimension_tensor(self, perm=None):
    # `N`, the size of the last axis of `perm`.
    perm = perm if perm is not None else self.perm
    return array_ops.shape(perm)[-1]

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
    if adjoint and not self.is_self_adjoint:
      # The adjoint of a permutation matrix is its inverse, i.e. the inverse
      # permutation.
      # TODO(srvasude): invert_permutation doesn't work on batches so we use
      # argsort.
      perm = sort_ops.argsort(perm, axis=-1)
    x = linalg.adjoint(x) if adjoint_arg else x

    # We need to broadcast x and the permutation since tf.gather doesn't
    # broadcast.
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(x)[:-1], array_ops.shape(perm))
    k = array_ops.shape(x)[-1]
    broadcast_x_shape = array_ops.concat([broadcast_shape, [k]], axis=-1)
    x = array_ops.broadcast_to(x, broadcast_x_shape)
    perm = array_ops.broadcast_to(perm, broadcast_shape)

    # Flatten all batch dimensions into one so a single batched gather can be
    # used. The batch size is computed explicitly instead of being inferred
    # with `-1`, because `reshape` cannot infer a `-1` dimension when the
    # tensor is empty (e.g. `N == 0`, zero columns in `x`, or a zero-sized
    # batch dimension).
    m = array_ops.shape(x)[-2]
    batch_size = math_ops.reduce_prod(broadcast_shape[:-1])
    x = array_ops.reshape(x, array_ops.stack([batch_size, m, k]))
    perm = array_ops.reshape(perm, array_ops.stack([batch_size, m]))

    # Row i of the result is row perm[i] of x.
    y = array_ops.gather(x, perm, axis=-2, batch_dims=1)
    return array_ops.reshape(y, broadcast_x_shape)

  # TODO(srvasude): Permutation parity is equivalent to the determinant.

  def _log_abs_determinant(self):
    # Permutation matrices have determinant +/- 1.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # The inverse of a permutation matrix is the transpose matrix.
    # Apply a matmul and flip the adjoint bit.
    return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)

  def _to_dense(self):
    # Entry (i, j) is 1 exactly when j == perm[i].
    perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
    return math_ops.cast(math_ops.equal(
        math_ops.range(0, self._domain_dimension_tensor(perm)),
        perm[..., array_ops.newaxis]), self.dtype)

  def _diag_part(self):
    # Diagonal entry i is 1 exactly when perm[i] == i (a fixed point).
    perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
    return math_ops.cast(math_ops.equal(
        math_ops.range(0, self._domain_dimension_tensor(perm)),
        perm), self.dtype)

  def _cond(self):
    # Permutation matrices are orthogonal (all singular values equal 1), so
    # their condition number is 1.
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def perm(self):
    """The (batch) permutation vector this operator was built from."""
    return self._perm

  @property
  def _composite_tensor_fields(self):
    return ("perm", "dtype")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"perm": 1}