# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Composes one or more `LinearOperators`."""

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorComposition"]


@tf_export("linalg.LinearOperatorComposition")
@linear_operator.make_composite_tensor
class LinearOperatorComposition(linear_operator.LinearOperator):
  """Composes one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` with action defined by:

  ```
  op_composed(x) := op1(op2(...(opJ(x))...))
  ```

  If `opj` acts like [batch] matrix `Aj`, then `op_composed` acts like the
  [batch] matrix formed with the multiplication `A1 A2...AJ`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then we must have
  `N_j = M_{j+1}`, in which case the composed operator has shape equal to
  `broadcast_batch_shape + [M_1, N_J]`, where `broadcast_batch_shape` is the
  mutual broadcast of `batch_shape_j`, `j = 1,...,J`, assuming the intermediate
  batch shapes broadcast.  Even if the composed shape is well defined, the
  composed operator's methods may fail due to lack of broadcasting ability in
  the defining operators' methods.

  ```python
  # Create a 2 x 2 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorComposition([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 4 x 6 operators.
  operator_46 = LinearOperatorComposition([operator_45, operator_56])

  # Create a shape [2, 3, 6, 2] tensor.
  x = tf.random.normal(shape=[2, 3, 6, 2])
  operator_46.matmul(x)
  ==> Shape [2, 3, 4, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorComposition` on any operation is equal to
  the sum of the individual operators' operations.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorComposition`.

    `LinearOperatorComposition` is initialized with a list of operators
    `[op_1,...,op_J]`.  For the `matmul` method to be well defined, the
    composition `op_i.matmul(op_{i+1}(x))` must be defined.  Other methods have
    similar constraints.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_o_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if every operator is hinted
        non-singular while `is_non_singular` is explicitly `False`.
    """
    # Snapshot the constructor arguments exactly as passed in; they are handed
    # to the base class as `parameters`.
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            "Expected all operators to have the same dtype. Found %s"
            % " ".join(name_type))

    # Auto-set and check hints.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition of non-singular operators is always non-singular.")
      is_non_singular = True

    # Initialization.

    if name is None:
      name = "_o_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of operators being composed, in the order they were given."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + `[M_1, N_J]`."""
    # Get final matrix shape.
    # Walk the chain, checking that each operator's domain dimension is
    # compatible with the range dimension of the operator that follows it.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension.assert_is_compatible_with(operator.range_dimension)
      domain_dimension = operator.domain_dimension

    matrix_shape = tensor_shape.TensorShape(
        [self.operators[0].range_dimension,
         self.operators[-1].domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape: broadcast batch shape followed by `[M_1, N_J]`."""
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    # Don't check the matrix dimensions.  That would add unnecessary Asserts to
    # the graph.  Things will fail at runtime naturally if shapes are
    # incompatible.
    matrix_shape = array_ops.stack([
        self.operators[0].range_dimension_tensor(),
        self.operators[-1].domain_dimension_tensor()
    ])

    # Broadcast the batch shape vectors directly.  This avoids building
    # batch-shaped zeros tensors just to read the shape of their sum.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Applies the component operators to `x` one after another."""
    # If self.operators = [A, B], and not adjoint, then
    # matmul_order_list = [B, A].
    # As a result, we return A.matmul(B.matmul(x))
    if adjoint:
      matmul_order_list = self.operators
    else:
      matmul_order_list = list(reversed(self.operators))

    # `adjoint_arg` only concerns the caller's `x`, so it is forwarded to the
    # first multiplication alone.
    result = matmul_order_list[0].matmul(
        x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in matmul_order_list[1:]:
      result = operator.matmul(result, adjoint=adjoint)
    return result

  def _determinant(self):
    """Product of the component operators' determinants."""
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    """Sum of the component operators' log-abs-determinants."""
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves against each component operator in turn."""
    # TODO(langmore) Implement solve using solve_ls if some intermediate
    # operator maps to a high dimensional space.
    # In that case, an exact solve may still be possible.

    # If self.operators = [A, B], and not adjoint, then
    # solve_order_list = [A, B].
    # As a result, we return B.solve(A.solve(x))
    if adjoint:
      solve_order_list = list(reversed(self.operators))
    else:
      solve_order_list = self.operators

    # `adjoint_arg` only concerns the caller's `rhs`, so it is forwarded to the
    # first solve alone.
    solution = solve_order_list[0].solve(
        rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in solve_order_list[1:]:
      solution = operator.solve(solution, adjoint=adjoint)
    return solution

  def _assert_non_singular(self):
    """Groups the components' asserts when every component is hinted square.

    Otherwise falls back to the base class implementation.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    return super(LinearOperatorComposition, self)._assert_non_singular()

  @property
  def _composite_tensor_fields(self):
    # Names only the `operators` constructor argument.
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # One zero entry per component operator.
    return {"operators": [0] * len(self.operators)}