• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FillTriangular bijector."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.distributions import bijector
30from tensorflow.python.ops.distributions import util as dist_util
31from tensorflow.python.util import deprecation
32
33
34__all__ = [
35    "FillTriangular",
36]
37
38
class FillTriangular(bijector.Bijector):
  """Transforms vectors to triangular.

  Triangular matrix elements are filled in a clockwise spiral.

  Given input with shape `batch_shape + [d]`, produces output with
  shape `batch_shape + [n, n]`, where
   `n = (-1 + sqrt(1 + 8 * d))/2`.
  This follows by solving the quadratic equation
   `d = 1 + 2 + ... + n = n * (n + 1)/2`.

  #### Example

  ```python
  b = tfb.FillTriangular(upper=False)
  b.forward([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  b = tfb.FillTriangular(upper=True)
  b.forward([1, 2, 3, 4, 5, 6])
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]

  ```
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               upper=False,
               validate_args=False,
               name="fill_triangular"):
    """Instantiates the `FillTriangular` bijector.

    Args:
      upper: Python `bool` representing whether output matrix should be upper
        triangular (`True`) or lower triangular (`False`, default).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
    """
    self._upper = upper
    # Forward consumes vectors (event rank 1); inverse consumes matrices
    # (event rank 2).
    super(FillTriangular, self).__init__(
        forward_min_event_ndims=1,
        inverse_min_event_ndims=2,
        validate_args=validate_args,
        name=name)

  def _forward(self, x):
    """Packs the trailing vector dimension of `x` into a triangular matrix."""
    return dist_util.fill_triangular(x, upper=self._upper)

  def _inverse(self, y):
    """Unpacks the trailing triangular matrix of `y` back into a vector."""
    return dist_util.fill_triangular_inverse(y, upper=self._upper)

  def _forward_log_det_jacobian(self, x):
    """Returns zeros shaped like `x` with its last (event) dimension dropped."""
    # The forward map only places vector entries into matrix positions (see
    # the class docstring), so the log-determinant is reported as zero.
    return array_ops.zeros_like(x[..., 0])

  def _inverse_log_det_jacobian(self, y):
    """Returns zeros shaped like `y` with its last two dimensions dropped."""
    return array_ops.zeros_like(y[..., 0, 0])

  def _forward_event_shape(self, input_shape):
    """Static shape `batch_shape + [d]` -> `batch_shape + [n, n]`.

    Both matrix dimensions are left unknown (`None`) when `d` is not
    statically known.
    """
    batch_shape, d = (input_shape[:-1],
                      tensor_shape.dimension_value(input_shape[-1]))
    if d is None:
      n = None
    else:
      n = vector_size_to_square_matrix_size(d, self.validate_args)
    return batch_shape.concatenate([n, n])

  def _inverse_event_shape(self, output_shape):
    """Static shape `batch_shape + [n, n]` -> `batch_shape + [n*(n+1)/2]`.

    Raises:
      ValueError: if both trailing dimensions are statically known and differ.
    """
    batch_shape, n1, n2 = (output_shape[:-2],
                           tensor_shape.dimension_value(output_shape[-2]),
                           tensor_shape.dimension_value(output_shape[-1]))
    if n1 is None or n2 is None:
      m = None
    elif n1 != n2:
      raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2))
    else:
      # Floor division: this module enables `from __future__ import division`,
      # so `/` would yield a float, which is not a valid static dimension.
      # n * (n + 1) is always even, so no precision is lost.
      m = n1 * (n1 + 1) // 2
    return batch_shape.concatenate([m])

  def _forward_event_shape_tensor(self, input_shape_tensor):
    """Dynamic-shape counterpart of `_forward_event_shape`."""
    batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1]
    n = vector_size_to_square_matrix_size(d, self.validate_args)
    return array_ops.concat([batch_shape, [n, n]], axis=0)

  def _inverse_event_shape_tensor(self, output_shape_tensor):
    """Dynamic-shape counterpart of `_inverse_event_shape`."""
    batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1]
    if self.validate_args:
      is_square_matrix = check_ops.assert_equal(
          n, output_shape_tensor[-2], message="Matrix must be square.")
      # Route `n` through an identity op so the squareness assertion is
      # guaranteed to run before anything that consumes `n`.
      with ops.control_dependencies([is_square_matrix]):
        n = array_ops.identity(n)
    # Integer floor division keeps the computation in the shape's integer
    # dtype instead of going through a float true-division and back.
    d = math_ops.cast(n * (n + 1) // 2, output_shape_tensor.dtype)
    return array_ops.concat([batch_shape, [d]], axis=0)
143
144
@deprecation.deprecated(
    "2018-10-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.contrib.distributions`.",
    warn_once=True)
def vector_size_to_square_matrix_size(d, validate_args, name=None):
  """Convert a vector size to a matrix size.

  Solves `d = n * (n + 1) / 2` for `n`, i.e. `n = (-1 + sqrt(1 + 8 * d)) / 2`.

  Args:
    d: Vector length. Python / NumPy numeric values are evaluated immediately;
      any other value is handled with TensorFlow ops.
    validate_args: Python `bool`. When `True`, the TensorFlow path adds an
      assertion that `n` is a whole number. The immediate (Python / NumPy)
      path always performs this check.
    name: Optional name scope for the ops created on the TensorFlow path.

  Returns:
    A Python `int` on the immediate path; otherwise the result of casting
    `n` to `d.dtype`.

  Raises:
    ValueError: on the immediate path, if `d` is not a triangular number.
  """
  immediate_types = (float, int, np.generic, np.ndarray)
  if isinstance(d, immediate_types):
    root = (-1 + np.sqrt(1 + 8 * d)) / 2.
    truncated = int(root)
    # A triangular `d` yields a whole-number root; anything else loses its
    # fractional part under truncation.
    if float(truncated) != root:
      raise ValueError("Vector length is not a triangular number.")
    return truncated

  with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]):
    d_as_float = math_ops.cast(d, dtypes.float32)
    root = (-1. + math_ops.sqrt(1 + 8. * d_as_float)) / 2.
    if validate_args:
      # Round-trip through int32: equality with the original means the root
      # has no fractional part.
      round_trip = math_ops.cast(
          math_ops.cast(root, dtypes.int32), dtypes.float32)
      is_triangular = check_ops.assert_equal(
          round_trip,
          root,
          message="Vector length is not a triangular number")
      with ops.control_dependencies([is_triangular]):
        root = array_ops.identity(root)
    return math_ops.cast(root, d.dtype)
172