• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a zero matrix."""
16
17import numpy as np
18
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import errors
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import check_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops.linalg import linalg_impl as linalg
29from tensorflow.python.ops.linalg import linear_operator
30from tensorflow.python.ops.linalg import linear_operator_util
31from tensorflow.python.util.tf_export import tf_export
32
# Public names exported by this module.
__all__ = ["LinearOperatorZeros"]
36
37
38@tf_export("linalg.LinearOperatorZeros")
39@linear_operator.make_composite_tensor
40class LinearOperatorZeros(linear_operator.LinearOperator):
41  """`LinearOperator` acting like a [batch] zero matrix.
42
43  This operator acts like a [batch] zero matrix `A` with shape
44  `[B1,...,Bb, N, M]` for some `b >= 0`.  The first `b` indices index a
45  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
46  an `N x M` matrix.  This matrix `A` is not materialized, but for
47  purposes of broadcasting this shape will be relevant.
48
  `LinearOperatorZeros` is initialized with `num_rows`, and optionally
  `num_columns`, `batch_shape`, and `dtype` arguments.  If `num_columns` is
  `None`, then this operator will be initialized as a square matrix. If
  `batch_shape` is `None`, this operator efficiently passes through all
  arguments.  If `batch_shape` is provided, broadcasting may occur, which will
  require making copies.
55
56  ```python
57  # Create a 2 x 2 zero matrix.
  operator = LinearOperatorZeros(num_rows=2, dtype=tf.float32)
59
60  operator.to_dense()
61  ==> [[0., 0.]
62       [0., 0.]]
63
64  operator.shape
65  ==> [2, 2]
66
67  operator.determinant()
68  ==> 0.
69
70  x = ... Shape [2, 4] Tensor
71  operator.matmul(x)
  ==> Shape [2, 4] Tensor, equal to tf.zeros_like(x)
73
74  # Create a 2-batch of 2x2 zero matrices
75  operator = LinearOperatorZeros(num_rows=2, batch_shape=[2])
76  operator.to_dense()
77  ==> [[[0., 0.]
78        [0., 0.]],
79       [[0., 0.]
80        [0., 0.]]]
81
  # Here, even though the operator has a batch shape, the output has the same
  # shape as the input, so no broadcast (and no extra copy) is needed.  The
  # operator is able to detect that no broadcast is necessary because both x
  # and the operator have statically defined shape.
86  x = ... Shape [2, 2, 3]
87  operator.matmul(x)
88  ==> Shape [2, 2, 3] Tensor, same as tf.zeros_like(x)
89
90  # Here the operator and x have different batch_shape, and are broadcast.
91  # This requires a copy, since the output is different size than the input.
92  x = ... Shape [1, 2, 3]
93  operator.matmul(x)
94  ==> Shape [2, 2, 3] Tensor, equal to tf.zeros_like([x, x])
95  ```
96
97  ### Shape compatibility
98
99  This operator acts on [batch] matrix with compatible shape.
100  `x` is a batch matrix with compatible shape for `matmul` and `solve` if
101
102  ```
103  operator.shape = [B1,...,Bb] + [N, M],  with b >= 0
104  x.shape =   [C1,...,Cc] + [M, R],
105  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
106  ```
107
108  #### Matrix property hints
109
110  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
111  for `X = non_singular, self_adjoint, positive_definite, square`.
112  These have the following meaning:
113
114  * If `is_X == True`, callers should expect the operator to have the
115    property `X`.  This is a promise that should be fulfilled, but is *not* a
116    runtime assert.  For example, finite floating point precision may result
117    in these promises being violated.
118  * If `is_X == False`, callers should expect the operator to not have `X`.
119  * If `is_X == None` (the default), callers should have no expectation either
120    way.
121  """
122
123  def __init__(self,
124               num_rows,
125               num_columns=None,
126               batch_shape=None,
127               dtype=None,
128               is_non_singular=False,
129               is_self_adjoint=True,
130               is_positive_definite=False,
131               is_square=True,
132               assert_proper_shapes=False,
133               name="LinearOperatorZeros"):
134    r"""Initialize a `LinearOperatorZeros`.
135
136    The `LinearOperatorZeros` is initialized with arguments defining `dtype`
137    and shape.
138
139    This operator is able to broadcast the leading (batch) dimensions, which
140    sometimes requires copying data.  If `batch_shape` is `None`, the operator
141    can take arguments of any batch shape without copying.  See examples.
142
143    Args:
144      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
145        corresponding zero matrix.
146      num_columns:  Scalar non-negative integer `Tensor`.  Number of columns in
147        the corresponding zero matrix. If `None`, defaults to the value of
148        `num_rows`.
149      batch_shape:  Optional `1-D` integer `Tensor`.  The shape of the leading
150        dimensions.  If `None`, this operator has no leading dimensions.
151      dtype:  Data type of the matrix that this operator represents.
152      is_non_singular:  Expect that this operator is non-singular.
153      is_self_adjoint:  Expect that this operator is equal to its hermitian
154        transpose.
155      is_positive_definite:  Expect that this operator is positive definite,
156        meaning the quadratic form `x^H A x` has positive real part for all
157        nonzero `x`.  Note that we do not require the operator to be
158        self-adjoint to be positive-definite.  See:
159        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
160      is_square:  Expect that this operator acts like square [batch] matrices.
161      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
162        checks that initialization and method arguments have proper shape.
163        If `True`, and static checks are inconclusive, add asserts to the graph.
164      name: A name for this `LinearOperator`
165
166    Raises:
167      ValueError:  If `num_rows` is determined statically to be non-scalar, or
168        negative.
169      ValueError:  If `num_columns` is determined statically to be non-scalar,
170        or negative.
171      ValueError:  If `batch_shape` is determined statically to not be 1-D, or
172        negative.
173      ValueError:  If any of the following is not `True`:
174        `{is_self_adjoint, is_non_singular, is_positive_definite}`.
175    """
176    parameters = dict(
177        num_rows=num_rows,
178        num_columns=num_columns,
179        batch_shape=batch_shape,
180        dtype=dtype,
181        is_non_singular=is_non_singular,
182        is_self_adjoint=is_self_adjoint,
183        is_positive_definite=is_positive_definite,
184        is_square=is_square,
185        assert_proper_shapes=assert_proper_shapes,
186        name=name
187    )
188
189    dtype = dtype or dtypes.float32
190    self._assert_proper_shapes = assert_proper_shapes
191
192    with ops.name_scope(name):
193      dtype = dtypes.as_dtype(dtype)
194      if not is_self_adjoint and is_square:
195        raise ValueError("A zero operator is always self adjoint.")
196      if is_non_singular:
197        raise ValueError("A zero operator is always singular.")
198      if is_positive_definite:
199        raise ValueError("A zero operator is always not positive-definite.")
200
201      super(LinearOperatorZeros, self).__init__(
202          dtype=dtype,
203          is_non_singular=is_non_singular,
204          is_self_adjoint=is_self_adjoint,
205          is_positive_definite=is_positive_definite,
206          is_square=is_square,
207          parameters=parameters,
208          name=name)
209
210      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
211      linear_operator_util.assert_not_ref_type(num_columns, "num_columns")
212      linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")
213
214      self._num_rows = linear_operator_util.shape_tensor(
215          num_rows, name="num_rows")
216      self._num_rows_static = tensor_util.constant_value(self._num_rows)
217
218      if num_columns is None:
219        num_columns = num_rows
220
221      self._num_columns = linear_operator_util.shape_tensor(
222          num_columns, name="num_columns")
223      self._num_columns_static = tensor_util.constant_value(self._num_columns)
224
225      self._check_domain_range_possibly_add_asserts()
226
227      if (self._num_rows_static is not None and
228          self._num_columns_static is not None):
229        if is_square and self._num_rows_static != self._num_columns_static:
230          raise ValueError(
231              "LinearOperatorZeros initialized as is_square=True, but got "
232              "num_rows({}) != num_columns({})".format(
233                  self._num_rows_static,
234                  self._num_columns_static))
235
236      if batch_shape is None:
237        self._batch_shape_arg = None
238      else:
239        self._batch_shape_arg = linear_operator_util.shape_tensor(
240            batch_shape, name="batch_shape_arg")
241        self._batch_shape_static = tensor_util.constant_value(
242            self._batch_shape_arg)
243        self._check_batch_shape_possibly_add_asserts()
244
245  def _shape(self):
246    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
247                                             self._num_columns_static))
248    if self._batch_shape_arg is None:
249      return matrix_shape
250
251    batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
252    return batch_shape.concatenate(matrix_shape)
253
254  def _shape_tensor(self):
255    matrix_shape = array_ops.stack((self._num_rows, self._num_columns), axis=0)
256    if self._batch_shape_arg is None:
257      return matrix_shape
258
259    return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)
260
261  def _assert_non_singular(self):
262    raise errors.InvalidArgumentError(
263        node_def=None, op=None, message="Zero operators are always "
264        "non-invertible.")
265
266  def _assert_positive_definite(self):
267    raise errors.InvalidArgumentError(
268        node_def=None, op=None, message="Zero operators are always "
269        "non-positive definite.")
270
271  def _assert_self_adjoint(self):
272    return control_flow_ops.no_op("assert_self_adjoint")
273
274  def _possibly_broadcast_batch_shape(self, x):
275    """Return 'x', possibly after broadcasting the leading dimensions."""
276    # If we have no batch shape, our batch shape broadcasts with everything!
277    if self._batch_shape_arg is None:
278      return x
279
280    # Static attempt:
281    #   If we determine that no broadcast is necessary, pass x through
282    #   If we need a broadcast, add to an array of zeros.
283    #
284    # special_shape is the shape that, when broadcast with x's shape, will give
285    # the correct broadcast_shape.  Note that
286    #   We have already verified the second to last dimension of self.shape
287    #   matches x's shape in assert_compatible_matrix_dimensions.
288    #   Also, the final dimension of 'x' can have any shape.
289    #   Therefore, the final two dimensions of special_shape are 1's.
290    special_shape = self.batch_shape.concatenate([1, 1])
291    bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
292    if special_shape.is_fully_defined():
293      # bshape.is_fully_defined iff special_shape.is_fully_defined.
294      if bshape == x.shape:
295        return x
296      # Use the built in broadcasting of addition.
297      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
298      return x + zeros
299
300    # Dynamic broadcast:
301    #   Always add to an array of zeros, rather than using a "cond", since a
302    #   cond would require copying data from GPU --> CPU.
303    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
304    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
305    return x + zeros
306
307  def _matmul(self, x, adjoint=False, adjoint_arg=False):
308    if self._assert_proper_shapes:
309      x = linalg.adjoint(x) if adjoint_arg else x
310      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
311      x = control_flow_ops.with_dependencies([aps], x)
312    if self.is_square:
313      # Note that adjoint has no effect since this matrix is self-adjoint.
314      if adjoint_arg:
315        output_shape = array_ops.concat([
316            array_ops.shape(x)[:-2],
317            [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0)
318      else:
319        output_shape = array_ops.shape(x)
320
321      return self._possibly_broadcast_batch_shape(
322          array_ops.zeros(shape=output_shape, dtype=x.dtype))
323
324    x_shape = array_ops.shape(x)
325    n = self._num_columns if adjoint else self._num_rows
326    m = x_shape[-2] if adjoint_arg else x_shape[-1]
327
328    output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)
329
330    zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
331    return self._possibly_broadcast_batch_shape(zeros)
332
333  def _determinant(self):
334    if self.batch_shape.is_fully_defined():
335      return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype)
336    else:
337      return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
338
339  def _trace(self):
340    # Get Tensor of all zeros of same shape as self.batch_shape.
341    if self.batch_shape.is_fully_defined():
342      return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype)
343    else:
344      return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
345
346  def _diag_part(self):
347    return self._zeros_diag()
348
349  def add_to_tensor(self, mat, name="add_to_tensor"):
350    """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.
351
352    Args:
353      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
354      name:  A name to give this `Op`.
355
356    Returns:
357      A `Tensor` with broadcast shape and same `dtype` as `self`.
358    """
359    return self._possibly_broadcast_batch_shape(mat)
360
361  def _check_domain_range_possibly_add_asserts(self):
362    """Static check of init arg `num_rows`, possibly add asserts."""
363    # Possibly add asserts.
364    if self._assert_proper_shapes:
365      self._num_rows = control_flow_ops.with_dependencies([
366          check_ops.assert_rank(
367              self._num_rows,
368              0,
369              message="Argument num_rows must be a 0-D Tensor."),
370          check_ops.assert_non_negative(
371              self._num_rows,
372              message="Argument num_rows must be non-negative."),
373      ], self._num_rows)
374      self._num_columns = control_flow_ops.with_dependencies([
375          check_ops.assert_rank(
376              self._num_columns,
377              0,
378              message="Argument num_columns must be a 0-D Tensor."),
379          check_ops.assert_non_negative(
380              self._num_columns,
381              message="Argument num_columns must be non-negative."),
382      ], self._num_columns)
383
384    # Static checks.
385    if not self._num_rows.dtype.is_integer:
386      raise TypeError("Argument num_rows must be integer type.  Found:"
387                      " %s" % self._num_rows)
388
389    if not self._num_columns.dtype.is_integer:
390      raise TypeError("Argument num_columns must be integer type.  Found:"
391                      " %s" % self._num_columns)
392
393    num_rows_static = self._num_rows_static
394    num_columns_static = self._num_columns_static
395
396    if num_rows_static is not None:
397      if num_rows_static.ndim != 0:
398        raise ValueError("Argument num_rows must be a 0-D Tensor.  Found:"
399                         " %s" % num_rows_static)
400
401      if num_rows_static < 0:
402        raise ValueError("Argument num_rows must be non-negative.  Found:"
403                         " %s" % num_rows_static)
404    if num_columns_static is not None:
405      if num_columns_static.ndim != 0:
406        raise ValueError("Argument num_columns must be a 0-D Tensor.  Found:"
407                         " %s" % num_columns_static)
408
409      if num_columns_static < 0:
410        raise ValueError("Argument num_columns must be non-negative.  Found:"
411                         " %s" % num_columns_static)
412
413  def _check_batch_shape_possibly_add_asserts(self):
414    """Static check of init arg `batch_shape`, possibly add asserts."""
415    if self._batch_shape_arg is None:
416      return
417
418    # Possibly add asserts
419    if self._assert_proper_shapes:
420      self._batch_shape_arg = control_flow_ops.with_dependencies([
421          check_ops.assert_rank(
422              self._batch_shape_arg,
423              1,
424              message="Argument batch_shape must be a 1-D Tensor."),
425          check_ops.assert_non_negative(
426              self._batch_shape_arg,
427              message="Argument batch_shape must be non-negative."),
428      ], self._batch_shape_arg)
429
430    # Static checks
431    if not self._batch_shape_arg.dtype.is_integer:
432      raise TypeError("Argument batch_shape must be integer type.  Found:"
433                      " %s" % self._batch_shape_arg)
434
435    if self._batch_shape_static is None:
436      return  # Cannot do any other static checks.
437
438    if self._batch_shape_static.ndim != 1:
439      raise ValueError("Argument batch_shape must be a 1-D Tensor.  Found:"
440                       " %s" % self._batch_shape_static)
441
442    if np.any(self._batch_shape_static < 0):
443      raise ValueError("Argument batch_shape must be non-negative.  Found:"
444                       "%s" % self._batch_shape_static)
445
446  def _min_matrix_dim(self):
447    """Minimum of domain/range dimension, if statically available, else None."""
448    domain_dim = self.domain_dimension.value
449    range_dim = self.range_dimension.value
450    if domain_dim is None or range_dim is None:
451      return None
452    return min(domain_dim, range_dim)
453
454  def _min_matrix_dim_tensor(self):
455    """Minimum of domain/range dimension, as a tensor."""
456    return math_ops.reduce_min(self.shape_tensor()[-2:])
457
458  def _zeros_diag(self):
459    """Returns the diagonal of this operator as all zeros."""
460    if self.shape.is_fully_defined():
461      d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
462    else:
463      d_shape = array_ops.concat(
464          [self.batch_shape_tensor(),
465           [self._min_matrix_dim_tensor()]], axis=0)
466
467    return array_ops.zeros(shape=d_shape, dtype=self.dtype)
468
469  def _eigvals(self):
470    return self._zeros_diag()
471
472  @property
473  def _composite_tensor_prefer_static_fields(self):
474    return ("num_rows", "num_columns", "batch_shape")
475
476  @property
477  def _composite_tensor_fields(self):
478    return ("num_rows", "num_columns", "batch_shape", "dtype",
479            "assert_proper_shapes")
480
481  def __getitem__(self, slices):
482    # Slice the batch shape and return a new LinearOperatorIdentity.
483    # Use a proxy shape and slice it. Use this as the new batch shape
484    new_batch_shape = array_ops.shape(
485        array_ops.ones(self._batch_shape_arg)[slices])
486    parameters = dict(self.parameters, batch_shape=new_batch_shape)
487    return LinearOperatorZeros(**parameters)
488
489