# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Composes one or more `LinearOperators`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorComposition"]


@tf_export("linalg.LinearOperatorComposition")
@linear_operator.make_composite_tensor
class LinearOperatorComposition(linear_operator.LinearOperator):
  """Composes one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` with action defined by:

  ```
  op_composed(x) := op1(op2(...(opJ(x))...))
  ```

  If `opj` acts like [batch] matrix `Aj`, then `op_composed` acts like the
  [batch] matrix formed with the multiplication `A1 A2...AJ`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then we must have
  `N_j = M_{j+1}`, in which case the composed operator has shape equal to
  `broadcast_batch_shape + [M_1, N_J]`, where `broadcast_batch_shape` is the
  mutual broadcast of `batch_shape_j`, `j = 1,...,J`, assuming the intermediate
  batch shapes broadcast.  Even if the composed shape is well defined, the
  composed operator's methods may fail due to lack of broadcasting ability in
  the defining operators' methods.

  ```python
  # Create a 2 x 2 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorComposition([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 4 x 6 operators.
  operator_46 = LinearOperatorComposition([operator_45, operator_56])

  # Create a shape [2, 3, 6, 2] batch of matrices.
  x = tf.random.normal(shape=[2, 3, 6, 2])
  operator_46.matmul(x)
  ==> Shape [2, 3, 4, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorComposition` on any operation is equal to
  the sum of the individual operators' operations.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorComposition`.

    `LinearOperatorComposition` is initialized with a list of operators
    `[op_1,...,op_J]`.  For the `matmul` method to be well defined, the
    composition `op_i.matmul(op_{i+1}(x))` must be defined.  Other methods have
    similar constraints.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_o_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if `is_non_singular` is `False`
        while every operator is hinted as non-singular.
    """
    # Snapshot the constructor arguments exactly as the caller supplied them,
    # before any of them is normalized or auto-set below.
    parameters = {
        "operators": operators,
        "is_non_singular": is_non_singular,
        "is_self_adjoint": is_self_adjoint,
        "is_positive_definite": is_positive_definite,
        "is_square": is_square,
        "name": name,
    }

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Validate dtype: every operator must agree with the first one.
    dtype = operators[0].dtype
    if any(op.dtype != dtype for op in operators):
      raise TypeError(
          "Expected all operators to have the same dtype.  Found %s"
          % "   ".join(str((o.name, o.dtype)) for o in operators))

    # Auto-set and check hints.  When every factor is hinted non-singular the
    # composition is too, so an explicit `False` hint is contradictory.
    if all(op.is_non_singular for op in operators):
      if is_non_singular is False:
        raise ValueError(
            "The composition of non-singular operators is always non-singular.")
      is_non_singular = True

    # Initialization.
    graph_parents = [
        parent for op in operators for parent in op.graph_parents]

    if name is None:
      name = "_o_".join(op.name for op in operators)
    with ops.name_scope(name, values=graph_parents):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
    # TODO(b/143910018) Remove graph_parents in V3.
    self._set_graph_parents(graph_parents)

  @property
  def operators(self):
    """The list of operators stored by the constructor, in composition order."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + `[first range, last domain]`."""
    operators = self.operators

    # Check that each operator's domain dimension is compatible with the
    # range dimension of the operator that follows it.
    for left, right in zip(operators[:-1], operators[1:]):
      left.domain_dimension.assert_is_compatible_with(right.range_dimension)

    # Get final matrix shape.
    matrix_shape = tensor_shape.TensorShape([
        operators[0].range_dimension,
        operators[-1].domain_dimension,
    ])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = operators[0].batch_shape
    for later_op in operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, later_op.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Shape as an `int32` `Tensor`, using the static shape when it is known."""
    # Avoid messy broadcasting if possible.
    static_shape = self.shape
    if static_shape.is_fully_defined():
      return ops.convert_to_tensor(
          static_shape.as_list(), dtype=dtypes.int32, name="shape")

    first_op = self.operators[0]
    last_op = self.operators[-1]

    # Don't check the matrix dimensions.  That would add unnecessary Asserts to
    # the graph.  Things will fail at runtime naturally if shapes are
    # incompatible.
    matrix_shape = array_ops.stack([
        first_op.range_dimension_tensor(),
        last_op.domain_dimension_tensor()
    ])

    # Dummy Tensor of zeros.  Will never be materialized.  Summing zeros of
    # each batch shape yields a tensor whose shape is their mutual broadcast.
    zeros = array_ops.zeros(shape=first_op.batch_shape_tensor())
    for later_op in self.operators[1:]:
      zeros += array_ops.zeros(shape=later_op.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Apply the composed operator (or its adjoint) to `x`."""
    # If self.operators = [A, B], and not adjoint, then the application order
    # is [B, A], so that we return A.matmul(B.matmul(x)).  For the adjoint the
    # stored order [A, B] is used as is.
    application_order = self.operators if adjoint else self.operators[::-1]

    # Only the first application sees the raw argument, so only it receives
    # `adjoint_arg`.
    innermost, remaining = application_order[0], application_order[1:]
    product = innermost.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for op in remaining:
      product = op.matmul(product, adjoint=adjoint)
    return product

  def _determinant(self):
    """Product of the individual operators' determinants."""
    factors = iter(self.operators)
    det = next(factors).determinant()
    for op in factors:
      det *= op.determinant()
    return det

  def _log_abs_determinant(self):
    """Sum of the individual operators' log absolute determinants."""
    factors = iter(self.operators)
    log_abs_det = next(factors).log_abs_determinant()
    for op in factors:
      log_abs_det += op.log_abs_determinant()
    return log_abs_det

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve against the composed operator (or its adjoint) one factor at a time."""
    # TODO(langmore) Implement solve using solve_ls if some intermediate
    # operator maps to a high dimensional space.
    # In that case, an exact solve may still be possible.

    # If self.operators = [A, B], and not adjoint, then the solve order is
    # [A, B], so that we return B.solve(A.solve(x)).  For the adjoint the
    # order is reversed.
    solve_order = self.operators[::-1] if adjoint else self.operators

    # Only the first solve sees the raw right-hand side, so only it receives
    # `adjoint_arg`.
    outermost, remaining = solve_order[0], solve_order[1:]
    solution = outermost.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for op in remaining:
      solution = op.solve(solution, adjoint=adjoint)
    return solution

  @property
  def _composite_tensor_fields(self):
    """Names of the constructor arguments listed as composite-tensor fields."""
    return ("operators",)