• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a permutation matrix."""
16
17import numpy as np
18
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import sort_ops
26from tensorflow.python.ops.linalg import linalg_impl as linalg
27from tensorflow.python.ops.linalg import linear_operator
28from tensorflow.python.ops.linalg import linear_operator_util
29from tensorflow.python.util.tf_export import tf_export
30
# Names exported by `from <this module> import *`.
__all__ = ["LinearOperatorPermutation",]
32
33
@tf_export("linalg.LinearOperatorPermutation")
@linear_operator.make_composite_tensor
class LinearOperatorPermutation(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of permutation matrices.

  This operator acts like a [batch] of permutations with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorPermutation` is initialized with a (batch) vector.

  A permutation is defined by an integer vector `v` of length `N` whose
  values are unique and are in the range `[0, ..., N-1]`. Applying the
  permutation on an input matrix has the following meaning: the value of `v`
  at index `i` says to move the `v[i]`-th row of the input matrix to the
  `i`-th row. Because all values are unique, this will result in a
  permutation of the rows of the input matrix. Note that the permutation
  vector `v` has the same semantics as the `perm` argument of `tf.transpose`.

  ```python
  # Create a 3 x 3 permutation matrix that swaps the last two rows of the
  # matrix it is applied to (i.e. the identity with its last two columns
  # swapped).
  vec = [0, 2, 1]
  operator = LinearOperatorPermutation(vec)

  operator.to_dense()
  ==> [[1., 0., 0.]
       [0., 0., 1.]
       [0., 1., 0.]]

  operator.shape
  ==> [3, 3]

  # This will be zero.
  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               perm,
               dtype=dtypes.float32,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorPermutation"):
    r"""Initialize a `LinearOperatorPermutation`.

    Args:
      perm:  Shape `[B1,...,Bb, N]` Integer `Tensor` with `b >= 0`
        `N >= 0`. An integer vector that represents the permutation to apply.
        Note that this argument is same as `tf.transpose`. However, this
        permutation is applied on the rows, while the permutation in
        `tf.transpose` is applied on the dimensions of the `Tensor`. `perm`
        is required to have unique entries from `{0, 1, ... N-1}`.
      dtype: The `dtype` of arguments to this operator. Default: `float32`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
        This is autoset to true.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This hint is passed through unchanged.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
        This hint is passed through unchanged.
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is autoset to true.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  `is_non_singular` is `False`, `is_square` is `False`,
        `perm` is statically known to have fewer than 1 dimension, or `perm`
        is statically known and its last axis is not a permutation of
        `{0, 1, ... N-1}`.
      TypeError:  `perm` does not have an integer dtype.
    """
    # Record the constructor arguments exactly as received, before any hint
    # is auto-set below.
    parameters = dict(
        perm=perm,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[perm]):
      self._perm = linear_operator_util.convert_nonref_to_tensor(
          perm, name="perm")
      self._check_perm(self._perm)

      # Check and auto-set hints.
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always non-singular. "
                         f"Expected argument `is_non_singular` to be True. "
                         f"Received: {is_non_singular}.")
      # A permutation matrix always has determinant +/- 1, so the hint can be
      # promoted from `None` to `True` just like `is_square`.
      is_non_singular = True

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always square. "
                         f"Expected argument `is_square` to be True. "
                         f"Received: {is_square}.")
      is_square = True

      super(LinearOperatorPermutation, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_perm(self, perm):
    """Static check of perm.

    Args:
      perm: The permutation `Tensor` (or tensor-like object exposing `shape`
        and `dtype`) to validate.

    Raises:
      ValueError: `perm` has a statically known rank below 1, or its value is
        statically known and the last axis is not a permutation of
        `0, ..., N-1`.
      TypeError: `perm` does not have an integer dtype.
    """
    if (perm.shape.ndims is not None and perm.shape.ndims < 1):
      raise ValueError(f"Argument `perm` must have at least 1 dimension. "
                       f"Received: {perm}.")
    if not perm.dtype.is_integer:
      raise TypeError(f"Argument `perm` must be integer dtype. "
                      f"Received: {perm}.")
    # Check that the permutation satisfies the uniqueness constraint.
    # This is only possible when the value is known at construction time.
    static_perm = tensor_util.constant_value(perm)
    if static_perm is not None:
      # A valid permutation, once sorted along the last axis, is exactly
      # 0, 1, ..., N-1 for every batch member.
      sorted_perm = np.sort(static_perm, axis=-1)
      if np.any(sorted_perm != np.arange(0, static_perm.shape[-1])):
        raise ValueError(
            f"Argument `perm` must be a vector of unique integers from "
            f"0 to {static_perm.shape[-1] - 1}.")

  def _shape(self):
    """Static shape: the shape of `perm` with its last dimension repeated."""
    perm_shape = self._perm.shape
    return perm_shape.concatenate(perm_shape[-1:])

  def _shape_tensor(self):
    """Dynamic shape: the shape of `perm` with its last dimension repeated."""
    perm_shape = array_ops.shape(self._perm)
    k = perm_shape[-1]
    return array_ops.concat((perm_shape, [k]), 0)

  def _assert_non_singular(self):
    """Returns a no-op; there is nothing to check at runtime."""
    return control_flow_ops.no_op("assert_non_singular")

  def _domain_dimension_tensor(self, perm=None):
    """Returns the size of the last axis of `perm` (default: `self.perm`)."""
    perm = perm if perm is not None else self.perm
    return array_ops.shape(perm)[-1]

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Applies the permutation to the rows of `x` via a batched gather.

    Args:
      x: `Tensor` to permute.
      adjoint: If `True`, apply the adjoint (inverse) permutation instead.
      adjoint_arg: If `True`, apply the permutation to the adjoint of `x`.

    Returns:
      `Tensor` whose shape is the broadcast of `x`'s shape with the batch
      shape of `perm`.
    """
    perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
    if adjoint and not self.is_self_adjoint:
      # TODO(srvasude): invert_permutation doesn't work on batches so we use
      # argsort.
      perm = sort_ops.argsort(perm, axis=-1)
    x = linalg.adjoint(x) if adjoint_arg else x

    # We need to broadcast x and the permutation since tf.gather doesn't
    # broadcast.
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(x)[:-1], array_ops.shape(perm))
    k = array_ops.shape(x)[-1]
    broadcast_x_shape = array_ops.concat([broadcast_shape, [k]], axis=-1)
    x = array_ops.broadcast_to(x, broadcast_x_shape)
    perm = array_ops.broadcast_to(perm, broadcast_shape)

    # Collapse all batch dimensions into one so a single `batch_dims=1`
    # gather can pick rows for every batch member.
    m = array_ops.shape(x)[-2]
    x = array_ops.reshape(x, [-1, m, k])
    perm = array_ops.reshape(perm, [-1, m])

    y = array_ops.gather(x, perm, axis=-2, batch_dims=1)
    # Restore the broadcast batch shape.
    return array_ops.reshape(y, broadcast_x_shape)

  # TODO(srvasude): Permutation parity is equivalent to the determinant.

  def _log_abs_determinant(self):
    """Returns zeros of the batch shape, since `|det| == 1`."""
    # Permutation matrices have determinant +/- 1.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves the system by applying the transposed permutation to `rhs`."""
    # The inverse of a permutation matrix is the transpose matrix.
    # Apply a matmul and flip the adjoint bit.
    return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)

  def _to_dense(self):
    """Materializes the [batch] matrix with a one in column `perm[i]` of row `i`."""
    perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
    # Build the column indices in `perm`'s own dtype: the dimension tensor is
    # derived from `array_ops.shape` and the comparison below needs both
    # operands to share a dtype (e.g. when `perm` is int64).
    column_indices = math_ops.range(
        0, self._domain_dimension_tensor(perm), dtype=perm.dtype)
    return math_ops.cast(
        math_ops.equal(column_indices, perm[..., array_ops.newaxis]),
        self.dtype)

  def _diag_part(self):
    """Returns the diagonal: one where `perm[i] == i`, zero elsewhere."""
    perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
    # Match `perm`'s dtype so the element-wise comparison is well-typed.
    indices = math_ops.range(
        0, self._domain_dimension_tensor(perm), dtype=perm.dtype)
    return math_ops.cast(math_ops.equal(indices, perm), self.dtype)

  def _cond(self):
    """Returns ones of the batch shape as the condition number."""
    # Permutation matrices are orthogonal, so every singular value is 1 and
    # the condition number is 1.
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def perm(self):
    """The permutation tensor this operator was constructed from."""
    return self._perm

  @property
  def _composite_tensor_fields(self):
    return ("perm", "dtype")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"perm": 1}
271