# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a Householder transformation."""

from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorHouseholder",]


@tf_export("linalg.LinearOperatorHouseholder")
@linear_operator.make_composite_tensor
class LinearOperatorHouseholder(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of Householder transformations.

  This operator acts like a [batch] of householder reflections with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorHouseholder` is initialized with a (batch) vector.

  A Householder reflection is defined via a vector `v`, and reflects points
  in `R^n` about the hyperplane orthogonal to `v` and through the origin.

  ```python
  # Create a 2 x 2 householder transform.
  vec = [1 / np.sqrt(2), 1. / np.sqrt(2)]
  operator = LinearOperatorHouseholder(vec)

  operator.to_dense()
  ==> [[0.,  -1.]
       [-1., -0.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               reflection_axis,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorHouseholder"):
    r"""Initialize a `LinearOperatorHouseholder`.

    Args:
      reflection_axis:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` and
        `N >= 0`.  The vector defining the hyperplane to reflect about.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This is autoset to true.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
        This is autoset to false.
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is autoset to true.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  `is_self_adjoint` is not `True`, `is_positive_definite` is
        not `False` or `is_square` is not `True`.
      ValueError:  `reflection_axis` has statically known rank less than 1.
    """
    # Capture the constructor arguments exactly as passed, before any hint is
    # auto-set below; they are forwarded to the base class as `parameters`.
    parameters = dict(
        reflection_axis=reflection_axis,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[reflection_axis]):
      self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
          reflection_axis, name="reflection_axis")
      self._check_reflection_axis(self._reflection_axis)

      # Check and auto-set hints.  An explicit hint that contradicts a
      # property every Householder operator has is rejected; `None` is
      # replaced by the known value.
      if is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always self adjoint.")
      else:
        is_self_adjoint = True

      if is_positive_definite is True:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "A Householder operator is always non-positive definite.")
      else:
        is_positive_definite = False

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always square.")
      is_square = True

      super(LinearOperatorHouseholder, self).__init__(
          dtype=self._reflection_axis.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_reflection_axis(self, reflection_axis):
    """Static check of reflection_axis.

    Args:
      reflection_axis: The converted `reflection_axis` tensor.

    Raises:
      ValueError: If the static rank is known and is less than 1.  An unknown
        static rank (`ndims is None`) is accepted without a check.
    """
    if (reflection_axis.shape.ndims is not None and
        reflection_axis.shape.ndims < 1):
      raise ValueError(
          "Argument reflection_axis must have at least 1 dimension.  "
          "Found: %s" % reflection_axis)

  def _shape(self):
    """Static operator shape: the axis shape with its last dim repeated."""
    # If d_shape = [5, 3], we return [5, 3, 3].
    d_shape = self._reflection_axis.shape
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    """Dynamic counterpart of `_shape`, returned as a 1-D shape `Tensor`."""
    d_shape = array_ops.shape(self._reflection_axis)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  def _assert_non_singular(self):
    """Returns a no-op; no runtime check is performed."""
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    """Always raises: the operator has a -1 eigenvalue.

    Raises:
      errors.InvalidArgumentError: Unconditionally.
    """
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Householder operators are always "
        "non-positive definite.")

  def _assert_self_adjoint(self):
    """Returns a no-op; no runtime check is performed."""
    return control_flow_ops.no_op("assert_self_adjoint")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Reflects the columns of `x` about the hyperplane orthogonal to `v`.

    Args:
      x: `Tensor` to multiply.
      adjoint: Unused; the operator is its own adjoint.
      adjoint_arg: If `True`, `x` is replaced by its adjoint before
        multiplying.

    Returns:
      `x - 2 * v * (v^H x)` where `v` is the l2-normalized reflection axis.
    """
    # Given a vector `v`, we would like to reflect `x` about the hyperplane
    # orthogonal to `v` going through the origin.  We first project `x` to `v`
    # to get v * dot(v, x) / dot(v, v).  After we project, we can reflect the
    # projection about the hyperplane by flipping sign to get
    # -v * dot(v, x) / dot(v, v).  Finally, we can add back the component
    # that is orthogonal to v. This is invariant under reflection, since the
    # whole hyperplane is invariant. This component is equal to x - v * dot(v,
    # x) / dot(v, v), giving the formula x - 2 * v * dot(v, x) / dot(v, v)
    # for the reflection.

    # Note that because this is a reflection, it lies in O(n) (for real vector
    # spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
    reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis)
    x = linalg.adjoint(x) if adjoint_arg else x
    normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
    # Treat the normalized axis as a column vector of shape [..., N, 1].
    mat = normalized_axis[..., array_ops.newaxis]
    x_dot_normalized_v = math_ops.matmul(mat, x, adjoint_a=True)

    return x - 2 * mat * x_dot_normalized_v

  def _trace(self):
    """Returns `N - 2` for every batch member, cast to `self.dtype`."""
    # We have (n - 1) +1 eigenvalues and a single -1 eigenvalue.
    shape = self.shape_tensor()
    return math_ops.cast(
        self._domain_dimension_tensor(shape=shape) - 2,
        self.dtype) * array_ops.ones(
            shape=self._batch_shape_tensor(shape=shape), dtype=self.dtype)

  def _determinant(self):
    """Returns `-1` for every batch member."""
    # For householder transformations, the determinant is -1.
    return -array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)  # pylint: disable=invalid-unary-operand-type

  def _log_abs_determinant(self):
    """Returns `0` for every batch member."""
    # Orthogonal matrix -> log|Q| = 0.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves `A x = rhs` by applying the operator to `rhs`."""
    # A Householder reflection is involutory (applying it twice gives the
    # identity), so it is its own inverse.  Thus we can just apply a matmul.
    return self._matmul(rhs, adjoint, adjoint_arg)

  def _to_dense(self):
    """Materializes the dense matrix `I - 2 * v * v^H` for normalized `v`."""
    reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis)
    normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
    mat = normalized_axis[..., array_ops.newaxis]
    matrix = -2 * math_ops.matmul(mat, mat, adjoint_b=True)
    # Add the identity by bumping only the diagonal of -2 * v * v^H.
    return array_ops.matrix_set_diag(
        matrix, 1. + array_ops.matrix_diag_part(matrix))

  def _diag_part(self):
    """Returns the diagonal `1 - 2 * v * conj(v)` for normalized `v`."""
    reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis)
    normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
    return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)

  def _eigvals(self):
    """Returns eigenvalues of shape `[B1,...,Bb, N]`: one -1, then N-1 ones."""
    # We have (n - 1) +1 eigenvalues and a single -1 eigenvalue.
    result_shape = array_ops.shape(self.reflection_axis)
    n = result_shape[-1]
    ones_shape = array_ops.concat([result_shape[:-1], [n - 1]], axis=-1)
    neg_shape = array_ops.concat([result_shape[:-1], [1]], axis=-1)
    eigvals = array_ops.ones(shape=ones_shape, dtype=self.dtype)
    # The -1 eigenvalue is placed first along the last axis.
    eigvals = array_ops.concat(
        [-array_ops.ones(shape=neg_shape, dtype=self.dtype), eigvals], axis=-1)  # pylint: disable=invalid-unary-operand-type
    return eigvals

  def _cond(self):
    """Returns a condition number of `1` for every batch member."""
    # Householder matrices are reflections, hence orthogonal (unitary in the
    # complex case); all singular values are 1, so the condition number is 1.
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def reflection_axis(self):
    """The (batch) vector this operator was initialized with, as a tensor."""
    return self._reflection_axis

  @property
  def _composite_tensor_fields(self):
    """Names of the constructor arguments listed as composite-tensor fields."""
    return ("reflection_axis",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    """Maps `reflection_axis` to 1.

    NOTE(review): presumably the number of trailing parameter dimensions that
    belong to the matrix rather than the batch -- confirm against the base
    class.
    """
    return {"reflection_axis": 1}