1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""FillTriangular bijector.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import check_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops.distributions import bijector 30from tensorflow.python.ops.distributions import util as dist_util 31from tensorflow.python.util import deprecation 32 33 34__all__ = [ 35 "FillTriangular", 36] 37 38 39class FillTriangular(bijector.Bijector): 40 """Transforms vectors to triangular. 41 42 Triangular matrix elements are filled in a clockwise spiral. 43 44 Given input with shape `batch_shape + [d]`, produces output with 45 shape `batch_shape + [n, n]`, where 46 `n = (-1 + sqrt(1 + 8 * d))/2`. 47 This follows by solving the quadratic equation 48 `d = 1 + 2 + ... + n = n * (n + 1)/2`. 49 50 #### Example 51 52 ```python 53 b = tfb.FillTriangular(upper=False) 54 b.forward([1, 2, 3, 4, 5, 6]) 55 # ==> [[4, 0, 0], 56 # [6, 5, 0], 57 # [3, 2, 1]] 58 59 b = tfb.FillTriangular(upper=True) 60 b.forward([1, 2, 3, 4, 5, 6]) 61 # ==> [[1, 2, 3], 62 # [0, 5, 6], 63 # [0, 0, 4]] 64 65 ``` 66 """ 67 68 @deprecation.deprecated( 69 "2018-10-01", 70 "The TensorFlow Distributions library has moved to " 71 "TensorFlow Probability " 72 "(https://github.com/tensorflow/probability). You " 73 "should update all references to use `tfp.distributions` " 74 "instead of `tf.contrib.distributions`.", 75 warn_once=True) 76 def __init__(self, 77 upper=False, 78 validate_args=False, 79 name="fill_triangular"): 80 """Instantiates the `FillTriangular` bijector. 81 82 Args: 83 upper: Python `bool` representing whether output matrix should be upper 84 triangular (`True`) or lower triangular (`False`, default). 85 validate_args: Python `bool` indicating whether arguments should be 86 checked for correctness. 87 name: Python `str` name given to ops managed by this object. 88 """ 89 self._upper = upper 90 super(FillTriangular, self).__init__( 91 forward_min_event_ndims=1, 92 inverse_min_event_ndims=2, 93 validate_args=validate_args, 94 name=name) 95 96 def _forward(self, x): 97 return dist_util.fill_triangular(x, upper=self._upper) 98 99 def _inverse(self, y): 100 return dist_util.fill_triangular_inverse(y, upper=self._upper) 101 102 def _forward_log_det_jacobian(self, x): 103 return array_ops.zeros_like(x[..., 0]) 104 105 def _inverse_log_det_jacobian(self, y): 106 return array_ops.zeros_like(y[..., 0, 0]) 107 108 def _forward_event_shape(self, input_shape): 109 batch_shape, d = (input_shape[:-1], 110 tensor_shape.dimension_value(input_shape[-1])) 111 if d is None: 112 n = None 113 else: 114 n = vector_size_to_square_matrix_size(d, self.validate_args) 115 return batch_shape.concatenate([n, n]) 116 117 def _inverse_event_shape(self, output_shape): 118 batch_shape, n1, n2 = (output_shape[:-2], 119 tensor_shape.dimension_value(output_shape[-2]), 120 tensor_shape.dimension_value(output_shape[-1])) 121 if n1 is None or n2 is None: 122 m = None 123 elif n1 != n2: 124 raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2)) 125 else: 126 m = n1 * (n1 + 1) / 2 127 return batch_shape.concatenate([m]) 128 129 def _forward_event_shape_tensor(self, input_shape_tensor): 130 batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] 131 n = vector_size_to_square_matrix_size(d, self.validate_args) 132 return array_ops.concat([batch_shape, [n, n]], axis=0) 133 134 def _inverse_event_shape_tensor(self, output_shape_tensor): 135 batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] 136 if self.validate_args: 137 is_square_matrix = check_ops.assert_equal( 138 n, output_shape_tensor[-2], message="Matrix must be square.") 139 with ops.control_dependencies([is_square_matrix]): 140 n = array_ops.identity(n) 141 d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) 142 return array_ops.concat([batch_shape, [d]], axis=0) 143 144 145@deprecation.deprecated( 146 "2018-10-01", 147 "The TensorFlow Distributions library has moved to " 148 "TensorFlow Probability " 149 "(https://github.com/tensorflow/probability). You " 150 "should update all references to use `tfp.distributions` " 151 "instead of `tf.contrib.distributions`.", 152 warn_once=True) 153def vector_size_to_square_matrix_size(d, validate_args, name=None): 154 """Convert a vector size to a matrix size.""" 155 if isinstance(d, (float, int, np.generic, np.ndarray)): 156 n = (-1 + np.sqrt(1 + 8 * d)) / 2. 157 if float(int(n)) != n: 158 raise ValueError("Vector length is not a triangular number.") 159 return int(n) 160 else: 161 with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: 162 n = (-1. + math_ops.sqrt(1 + 8. * math_ops.cast(d, dtypes.float32))) / 2. 163 if validate_args: 164 with ops.control_dependencies([ 165 check_ops.assert_equal( 166 math_ops.cast(math_ops.cast(n, dtypes.int32), dtypes.float32), 167 n, 168 message="Vector length is not a triangular number") 169 ]): 170 n = array_ops.identity(n) 171 return math_ops.cast(n, d.dtype) 172