• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a Householder transformation."""
16
17from tensorflow.python.framework import errors
18from tensorflow.python.framework import ops
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import control_flow_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops import nn
23from tensorflow.python.ops.linalg import linalg_impl as linalg
24from tensorflow.python.ops.linalg import linear_operator
25from tensorflow.python.ops.linalg import linear_operator_util
26from tensorflow.python.util.tf_export import tf_export
27
# Names exported by `from <module> import *`.
__all__ = ["LinearOperatorHouseholder"]
29
30
@tf_export("linalg.LinearOperatorHouseholder")
@linear_operator.make_composite_tensor
class LinearOperatorHouseholder(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of Householder transformations.

  This operator acts like a [batch] of Householder reflections with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorHouseholder` is initialized with a (batch) vector.

  A Householder reflection is defined via a vector `v`: it reflects points
  in `R^n` (or `C^n` for complex dtypes) about the hyperplane orthogonal to
  `v` and through the origin.  With `u = v / ||v||`, it acts as the matrix
  `I - 2 u u^H`.

  ```python
  # Create a 2 x 2 Householder transform.
  vec = [1 / np.sqrt(2), 1. / np.sqrt(2)]
  operator = LinearOperatorHouseholder(vec)

  operator.to_dense()
  ==> [[0.,  -1.]
       [-1., -0.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               reflection_axis,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorHouseholder"):
    r"""Initialize a `LinearOperatorHouseholder`.

    Args:
      reflection_axis:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The vector defining the hyperplane to reflect about.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This is autoset to true.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
        This is autoset to false.
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is autoset to true.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  `is_self_adjoint` is `False`, `is_positive_definite` is
        `True`, `is_square` is `False`, or `reflection_axis` is statically
        known to have rank 0.
    """
    parameters = dict(
        reflection_axis=reflection_axis,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[reflection_axis]):
      self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
          reflection_axis, name="reflection_axis")
      self._check_reflection_axis(self._reflection_axis)

      # Check and auto-set hints.  Only an explicit contradiction raises; a
      # `None` hint is silently replaced by the structurally known value.
      if is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always self adjoint.")
      is_self_adjoint = True

      if is_positive_definite is True:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "A Householder operator is always non-positive definite.")
      is_positive_definite = False

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always square.")
      is_square = True

      super(LinearOperatorHouseholder, self).__init__(
          dtype=self._reflection_axis.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_reflection_axis(self, reflection_axis):
    """Static check of reflection_axis.

    Raises:
      ValueError: if the static rank of `reflection_axis` is known and < 1.
    """
    if (reflection_axis.shape.ndims is not None and
        reflection_axis.shape.ndims < 1):
      raise ValueError(
          "Argument reflection_axis must have at least 1 dimension.  "
          "Found: %s" % reflection_axis)

  def _normalized_reflection_axis(self):
    """Returns `reflection_axis` scaled to unit L2 norm along its last axis."""
    reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis)
    return nn.l2_normalize(reflection_axis, axis=-1)

  def _shape(self):
    # If d_shape = [5, 3], we return [5, 3, 3].
    d_shape = self._reflection_axis.shape
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    d_shape = array_ops.shape(self._reflection_axis)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  def _assert_non_singular(self):
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Householder operators are always "
        "non-positive definite.")

  def _assert_self_adjoint(self):
    return control_flow_ops.no_op("assert_self_adjoint")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Given a vector `v`, we would like to reflect `x` about the hyperplane
    # orthogonal to `v` going through the origin.  We first project `x` to `v`
    # to get v * dot(v, x) / dot(v, v).  After we project, we can reflect the
    # projection about the hyperplane by flipping sign to get
    # -v * dot(v, x) / dot(v, v).  Finally, we can add back the component
    # that is orthogonal to v. This is invariant under reflection, since the
    # whole hyperplane is invariant. This component is equal to x - v * dot(v,
    # x) / dot(v, v), giving the formula x - 2 * v * dot(v, x) / dot(v, v)
    # for the reflection.

    # Note that because this is a reflection, it lies in O(n) (for real vector
    # spaces) or U(n) (for complex vector spaces), and is Hermitian, so it is
    # its own adjoint; `adjoint` therefore does not change the result.
    x = linalg.adjoint(x) if adjoint_arg else x
    # Unit axis as a column vector: shape [..., N, 1].
    unit_column = self._normalized_reflection_axis()[..., array_ops.newaxis]
    # u^H x has shape [..., 1, R]; broadcasting with u gives [..., N, R].
    projection_coeffs = math_ops.matmul(unit_column, x, adjoint_a=True)
    return x - 2 * unit_column * projection_coeffs

  def _trace(self):
    # We have (n - 1) +1 eigenvalues and a single -1 eigenvalue, so the
    # trace is n - 2 for every batch member.
    shape = self.shape_tensor()
    return math_ops.cast(
        self._domain_dimension_tensor(shape=shape) - 2,
        self.dtype) * array_ops.ones(
            shape=self._batch_shape_tensor(shape=shape), dtype=self.dtype)

  def _determinant(self):
    # For Householder transformations, the determinant is -1.
    return -array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)  # pylint: disable=invalid-unary-operand-type

  def _log_abs_determinant(self):
    # Orthogonal (unitary) matrix -> |det Q| = 1 -> log|det Q| = 0.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # A Householder reflection is an involution (H @ H = I), so H^{-1} = H
    # and solving H x = rhs reduces to x = H @ rhs, i.e. a matmul.
    return self._matmul(rhs, adjoint, adjoint_arg)

  def _to_dense(self):
    # Materialize I - 2 u u^H by building -2 u u^H and adding 1 on the diagonal.
    unit_column = self._normalized_reflection_axis()[..., array_ops.newaxis]
    matrix = -2 * math_ops.matmul(unit_column, unit_column, adjoint_b=True)
    return array_ops.matrix_set_diag(
        matrix, 1. + array_ops.matrix_diag_part(matrix))

  def _diag_part(self):
    # Diagonal of I - 2 u u^H is 1 - 2 |u_i|^2.
    unit_axis = self._normalized_reflection_axis()
    return 1. - 2 * unit_axis * math_ops.conj(unit_axis)

  def _eigvals(self):
    # We have (n - 1) +1 eigenvalues and a single -1 eigenvalue; the -1 is
    # placed first along the last axis.
    result_shape = array_ops.shape(self.reflection_axis)
    n = result_shape[-1]
    ones_shape = array_ops.concat([result_shape[:-1], [n - 1]], axis=-1)
    neg_shape = array_ops.concat([result_shape[:-1], [1]], axis=-1)
    eigvals = array_ops.ones(shape=ones_shape, dtype=self.dtype)
    eigvals = array_ops.concat(
        [-array_ops.ones(shape=neg_shape, dtype=self.dtype), eigvals], axis=-1)  # pylint: disable=invalid-unary-operand-type
    return eigvals

  def _cond(self):
    # Householder matrices are reflections, hence orthogonal (unitary): all
    # singular values are 1, so the condition number is 1.
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def reflection_axis(self):
    """The (unnormalized) vector defining the reflection hyperplane."""
    return self._reflection_axis

  @property
  def _composite_tensor_fields(self):
    return ("reflection_axis",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"reflection_axis": 1}