# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a Toeplitz matrix."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_circulant
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorToeplitz",]


@tf_export("linalg.LinearOperatorToeplitz")
@linear_operator.make_composite_tensor
class LinearOperatorToeplitz(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of Toeplitz matrices.

  This operator acts like a [batch] Toeplitz matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  #### Description in terms of Toeplitz matrices

  Toeplitz means that `A` has constant diagonals. Hence, `A` can be generated
  with two vectors. One represents the first column of the matrix, and the
  other represents the first row.

  Below is a 4 x 4 example:

  ```
  A = |a b c d|
      |e a b c|
      |f e a b|
      |g f e a|
  ```

  #### Example of a Toeplitz operator.

  ```python
  # Create a 3 x 3 Toeplitz operator.
  col = [1., 2., 3.]
  row = [1., 4., -9.]
  operator = LinearOperatorToeplitz(col, row)

  operator.to_dense()
  ==> [[1., 4., -9.],
       [2., 1., 4.],
       [3., 2., 1.]]

  operator.shape
  ==> [3, 3]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               col,
               row,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorToeplitz"):
    r"""Initialize a `LinearOperatorToeplitz`.

    Args:
      col: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The first column of the operator. Allowed dtypes: `float16`, `float32`,
          `float64`, `complex64`, `complex128`. Note that the first entry of
          `col` is assumed to be the same as the first entry of `row`.
      row: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The first row of the operator. Allowed dtypes: `float16`, `float32`,
          `float64`, `complex64`, `complex128`. Note that the first entry of
          `row` is assumed to be the same as the first entry of `col`.
        The batch shapes of `row` and `col` need not be identical, but must
        broadcast; the operator's batch shape is their broadcast.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  A Toeplitz operator is self-adjoint exactly when `row`
        equals the complex conjugate of `col` (i.e. `row == col` for real
        dtypes).  This flag is not set automatically.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError: If `is_square` is `False`, if `row` or `col` is statically
        known to have rank `< 1`, or if their statically known last
        dimensions differ.
    """
    parameters = dict(
        col=col,
        row=row,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[row, col]):
      self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
      self._col = linear_operator_util.convert_nonref_to_tensor(col, name="col")
      self._check_row_col(self._row, self._col)

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("Only square Toeplitz operators currently supported.")
      is_square = True

      super(LinearOperatorToeplitz, self).__init__(
          dtype=self._row.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_row_col(self, row, col):
    """Static check of row and column.

    Raises:
      ValueError: If either argument has a statically known rank `< 1`, or if
        both last dimensions are statically known and differ (the operator
        would not be square).
    """
    for name, tensor in [["row", row], ["col", col]]:
      if tensor.shape.ndims is not None and tensor.shape.ndims < 1:
        raise ValueError("Argument {} must have at least 1 dimension.  "
                         "Found: {}".format(name, tensor))

    if row.shape[-1] is not None and col.shape[-1] is not None:
      if row.shape[-1] != col.shape[-1]:
        raise ValueError(
            "Expected square matrix, got row and col with mismatched "
            "dimensions.")

  def _broadcast_row_col(self):
    """Returns `(row, col)` as `Tensor`s broadcast to their common shape.

    `row` and `col` may carry different (but broadcastable) batch shapes.
    Ops that combine them along the last axis, such as `concat`, require the
    leading dimensions to match exactly, so they are broadcast first.
    """
    row = ops.convert_to_tensor_v2_with_dispatch(self.row)
    col = ops.convert_to_tensor_v2_with_dispatch(self.col)
    total_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(row), array_ops.shape(col))
    return (array_ops.broadcast_to(row, total_shape),
            array_ops.broadcast_to(col, total_shape))

  def _shape(self):
    # If the broadcast shape of `row` and `col` is [5, 3], we return [5, 3, 3].
    v_shape = array_ops.broadcast_static_shape(
        self.row.shape, self.col.shape)
    return v_shape.concatenate(v_shape[-1:])

  def _shape_tensor(self, row=None, col=None):
    # `row` / `col` may be passed in to reuse tensors that were already
    # converted by the caller; otherwise the stored ones are used.
    row = self.row if row is None else row
    col = self.col if col is None else col
    v_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(row),
        array_ops.shape(col))
    k = v_shape[-1]
    return array_ops.concat((v_shape, [k]), 0)

  def _assert_self_adjoint(self):
    # A Toeplitz matrix T satisfies T == T^H exactly when its first row is
    # the complex conjugate of its first column.  `conj` is the identity for
    # real dtypes, so this reduces to `row == col` in the real case.
    return check_ops.assert_equal(
        self.row,
        math_ops.conj(self.col),
        message=("row and col are not the same, and "
                 "so this operator is not self-adjoint."))

  # TODO(srvasude): Add efficient solver and determinant calculations to this
  # class (based on Levinson recursion.)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Given a Toeplitz matrix, we can embed it in a Circulant matrix to perform
    # efficient matrix multiplications. Given a Toeplitz matrix with first row
    # [t_0, t_1, ..., t_{n-1}] and first column [t_0, t_{-1}, ..., t_{-(n-1)}],
    # let C be the circulant matrix with first column [t_0, t_{-1}, ...,
    # t_{-(n-1)}, 0, t_{n-1}, ..., t_1]. Also adjoin to our input vector `x`
    # `n` zeros, to make it a vector of length `2n` (call it y). It can be shown
    # that if we take the first n entries of `Cy`, this is equal to the Toeplitz
    # multiplication. See:
    # http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf
    # for more details.
    # The adjoint case also works: the top-left n x n block of C^H is T^H, and
    # the zero padding of y discards the remaining blocks.
    x = linalg.adjoint(x) if adjoint_arg else x
    expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
    # `row` and `col` are concatenated below, so their batch shapes must match
    # exactly rather than merely broadcast.
    row, col = self._broadcast_row_col()
    circulant_col = array_ops.concat(
        [col,
         array_ops.zeros_like(col[..., 0:1]),
         array_ops.reverse(row[..., 1:], axis=[-1])], axis=-1)
    circulant = linear_operator_circulant.LinearOperatorCirculant(
        fft_ops.fft(_to_complex(circulant_col)),
        input_output_dtype=row.dtype)
    result = circulant.matmul(expanded_x, adjoint=adjoint, adjoint_arg=False)

    shape = self._shape_tensor(row=row, col=col)
    return math_ops.cast(
        result[..., :self._domain_dimension_tensor(shape=shape), :],
        self.dtype)

  def _trace(self):
    # Every diagonal entry equals `col[..., 0]` (assumed equal to `row[..., 0]`).
    # That slice only carries `col`'s batch shape, so broadcast it to the
    # operator's full batch shape, which `row` may also contribute to.
    diag_entry = array_ops.broadcast_to(
        self.col[..., 0], self.batch_shape_tensor())
    return math_ops.cast(
        self.domain_dimension_tensor(), dtype=self.dtype) * diag_entry

  def _diag_part(self):
    diag_entry = self.col[..., 0:1]
    # `shape_tensor()[:-1]` is `batch_shape + [N]`; multiplying by ones of that
    # shape tiles the entry along the diagonal and broadcasts it to the full
    # batch shape.
    return diag_entry * array_ops.ones(self.shape_tensor()[:-1], self.dtype)

  def _to_dense(self):
    row, col = self._broadcast_row_col()
    n = array_ops.shape(row)[-1]
    # We concatenate the column in reverse order to the row (dropping the
    # duplicated first entry of `row`).
    # This gives us 2*n - 1 elements: [c_{n-1}, ..., c_0, r_1, ..., r_{n-1}].
    elements = array_ops.concat(
        [array_ops.reverse(col, axis=[-1]), row[..., 1:]], axis=-1)
    # Given the above vector, the i-th row of the Toeplitz matrix is the
    # length-n window of that vector starting at index n - 1 - i (hence the
    # first row is just the row vector provided, and the last row is the
    # reversed column vector).
    # We construct these set of indices below.
    indices = math_ops.mod(
        # Position within the row. This corresponds to the column index `j`.
        math_ops.range(0, n) +
        # Start of the window for row `i`, i.e. `n - 1 - i`.
        math_ops.range(n - 1, -1, -1)[..., array_ops.newaxis],
        # Mod out by the total number of elements to ensure the index is
        # non-negative (for tf.gather) and < 2 * n - 1.
        2 * n - 1)
    return array_ops.gather(elements, indices, axis=-1)

  @property
  def col(self):
    """The first column of the operator, shape `[B1,...,Bb, N]`."""
    return self._col

  @property
  def row(self):
    """The first row of the operator, shape `[B1,...,Bb, N]`."""
    return self._row

  @property
  def _composite_tensor_fields(self):
    return ("col", "row")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"col": 1, "row": 1}


def _to_complex(x):
  """Casts `x` to `complex128` if it is 64-bit precision, else `complex64`."""
  dtype = dtypes.complex64
  if x.dtype in [dtypes.float64, dtypes.complex128]:
    dtype = dtypes.complex128
  return math_ops.cast(x, dtype)