# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.linalg_ops.matrix_solve_ls."""

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test as test_lib


def _AddTest(test, op_name, testcase_name, fn):
  """Attaches `fn` to class `test` as `test_<op_name>_<testcase_name>`.

  Raises:
    RuntimeError: if `test` already has an attribute with that name.
  """
  test_name = "_".join(["test", op_name, testcase_name])
  if hasattr(test, test_name):
    raise RuntimeError("Test %s defined more than once" % test_name)
  setattr(test, test_name, fn)


def _GenerateTestData(matrix_shape, num_rhs):
  """Builds benchmark inputs as non-trainable Variables.

  A single random float32 matrix of shape `matrix_shape[-2:]` (NumPy seed 1,
  entries uniform in [-1, 1]) and an all-ones right-hand side with `num_rhs`
  columns are tiled over the leading (batch) dimensions of `matrix_shape`.

  Args:
    matrix_shape: tuple; the last two entries are the matrix shape, any
      leading entries are batch dimensions.
    num_rhs: number of right-hand-side columns.

  Returns:
    A `(matrix, rhs)` pair of `variables.Variable`s.
  """
  batch_shape = matrix_shape[:-2]
  matrix_shape = matrix_shape[-2:]
  m = matrix_shape[-2]
  np.random.seed(1)
  matrix = np.random.uniform(
      low=-1.0, high=1.0,
      size=np.prod(matrix_shape)).reshape(matrix_shape).astype(np.float32)
  rhs = np.ones([m, num_rhs]).astype(np.float32)
  matrix = variables.Variable(
      np.tile(matrix, batch_shape + (1, 1)), trainable=False)
  rhs = variables.Variable(np.tile(rhs, batch_shape + (1, 1)), trainable=False)
  return matrix, rhs


def _SolveWithNumpy(matrix, rhs, l2_regularizer=0):
  """Unbatched NumPy reference for (regularized) least squares.

  With `l2_regularizer == 0` this is `np.linalg.lstsq(matrix, rhs)`. Otherwise
  it solves the regularized normal equations:
    rows >= cols:  x = (A^H A + l2 I)^-1 A^H b
    rows <  cols:  x = A^H (A A^H + l2 I)^-1 b

  Args:
    matrix: 2-D NumPy array A.
    rhs: 2-D NumPy array b with as many rows as `matrix`.
    l2_regularizer: non-negative regularization strength.

  Returns:
    The solution x as a NumPy array.
  """
  if l2_regularizer == 0:
    # rcond=None selects NumPy's current default singular-value cutoff and
    # avoids the FutureWarning emitted when rcond is omitted.
    np_ans, _, _, _ = np.linalg.lstsq(matrix, rhs, rcond=None)
    return np_ans
  rows, cols = matrix.shape[-2], matrix.shape[-1]
  matrix_h = np.conj(matrix.T)
  if rows >= cols:
    gramian = np.dot(matrix_h, matrix) + l2_regularizer * np.identity(cols)
    return np.linalg.solve(gramian, np.dot(matrix_h, rhs))
  gramian = np.dot(matrix, matrix_h) + l2_regularizer * np.identity(rows)
  return np.dot(matrix_h, np.linalg.solve(gramian, rhs))


@test_util.with_eager_op_as_function
class MatrixSolveLsOpTest(test_lib.TestCase):
  """Tests for linalg_ops.matrix_solve_ls.

  Parameterized test methods are attached in `__main__` via `_AddTest`.
  """

  def _evaluateWithFeed(self, fetches, feed_dict):
    """Evaluates `fetches`, using a graph session when `feed_dict` is given."""
    if feed_dict:
      with self.session() as sess:
        return sess.run(fetches, feed_dict=feed_dict)
    return self.evaluate(fetches)

  def _verifySolve(self,
                   x,
                   y,
                   dtype,
                   use_placeholder,
                   fast,
                   l2_regularizer,
                   batch_shape=()):
    """Checks matrix_solve_ls(x, y) against a NumPy reference solution.

    Args:
      x: 2-D NumPy matrix.
      y: 2-D NumPy right-hand side with as many rows as `x`.
      dtype: NumPy dtype to run the op in. For complex dtypes the inputs get
        an imaginary part equal to their real part.
      use_placeholder: feed the inputs through placeholders (graph mode only;
        the check is skipped in eager mode).
      fast: forwarded to matrix_solve_ls.
      l2_regularizer: forwarded to matrix_solve_ls. Non-zero values are only
        checked when `fast` is True.
      batch_shape: leading batch dimensions over which the single system is
        tiled.
    """
    if not fast and l2_regularizer != 0:
      # The slow path does not support regularization.
      return
    if use_placeholder and context.executing_eagerly():
      # Placeholders cannot be used in eager mode.
      return
    maxdim = np.max(x.shape)
    if dtype == np.float32 or dtype == np.complex64:
      tol = maxdim * 5e-4
    else:
      tol = maxdim * 5e-7
    a = x.astype(dtype)
    b = y.astype(dtype)
    if dtype in [np.complex64, np.complex128]:
      a.imag = a.real
      b.imag = b.real
    # numpy.linalg.lstsq does not support batching, so we solve a single
    # system and replicate the solution and residual norm.
    # The reference must be computed from the actual inputs `a` and `b`
    # (not `x` and `y`): for complex dtypes a = x * (1 + 1j), so
    # A^H A = 2 X^T X and a non-zero regularizer would otherwise be applied
    # with a different effective strength. It is computed in double precision
    # to keep reference error well below the tolerance.
    ref_dtype = np.complex128 if np.iscomplexobj(a) else np.float64
    np_ans = _SolveWithNumpy(
        a.astype(ref_dtype), b.astype(ref_dtype), l2_regularizer=l2_regularizer)
    np_r = np.dot(np.conj(a.T), b - np.dot(a, np_ans))
    np_r_norm = np.linalg.norm(np_r)
    if batch_shape != ():
      a = np.tile(a, batch_shape + (1, 1))
      b = np.tile(b, batch_shape + (1, 1))
      np_ans = np.tile(np_ans, batch_shape + (1, 1))
      np_r_norm = np.tile(np_r_norm, batch_shape)

    if use_placeholder:
      a_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
      b_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
      feed_dict = {a_ph: a, b_ph: b}
      tf_ans = linalg_ops.matrix_solve_ls(
          a_ph, b_ph, fast=fast, l2_regularizer=l2_regularizer)
    else:
      tf_ans = linalg_ops.matrix_solve_ls(
          a, b, fast=fast, l2_regularizer=l2_regularizer)
      feed_dict = None
      # The static shape is only known when the inputs are not placeholders.
      self.assertEqual(np_ans.shape, tf_ans.get_shape())

    tf_ans_val = self._evaluateWithFeed(tf_ans, feed_dict)
    self.assertEqual(np_ans.shape, tf_ans_val.shape)
    self.assertAllClose(np_ans, tf_ans_val, atol=2 * tol, rtol=2 * tol)

    if l2_regularizer == 0:
      # The least squares solution should satisfy A^H * (b - A*x) = 0.
      tf_r = b - math_ops.matmul(a, tf_ans)
      tf_r = math_ops.matmul(a, tf_r, adjoint_a=True)
      tf_r_norm = linalg_ops.norm(tf_r, ord="fro", axis=[-2, -1])
      tf_r_norm_val = self._evaluateWithFeed(tf_r_norm, feed_dict)
      self.assertAllClose(np_r_norm, tf_r_norm_val, atol=tol, rtol=tol)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testWrongDimensions(self):
    # The matrix and right-hand sides should have the same number of rows.
    with self.session():
      matrix = constant_op.constant([[1., 0.], [0., 1.]])
      rhs = constant_op.constant([[1., 0.]])
      with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
        linalg_ops.matrix_solve_ls(matrix, rhs)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testEmpty(self):
    # Zero-sized inputs must produce correctly shaped (empty) results.
    full = np.array([[1., 2.], [3., 4.], [5., 6.]])
    empty0 = np.empty([3, 0])
    empty1 = np.empty([0, 2])
    for fast in [True, False]:
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(empty0, empty0, fast=fast))
      self.assertEqual(tf_ans.shape, (0, 0))
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(empty0, full, fast=fast))
      self.assertEqual(tf_ans.shape, (0, 2))
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(full, empty0, fast=fast))
      self.assertEqual(tf_ans.shape, (2, 0))
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(empty1, empty1, fast=fast))
      self.assertEqual(tf_ans.shape, (2, 2))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testBatchResultSize(self):
    # 3x3x3 matrices, 3x3x1 right-hand sides.
    matrix = np.array([1., 0., 0., 0., 1., 0., 0., 0., 1.] * 3).reshape(3, 3, 3)  # pylint: disable=too-many-function-args
    rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1)  # pylint: disable=too-many-function-args
    answer = linalg_ops.matrix_solve(matrix, rhs)
    ls_answer = linalg_ops.matrix_solve_ls(matrix, rhs)
    self.assertEqual(ls_answer.get_shape(), [3, 3, 1])
    self.assertEqual(answer.get_shape(), [3, 3, 1])


def _GetSmallMatrixSolveLsOpTests(dtype, use_placeholder, fast, l2_regularizer):
  """Returns (Square, Overdetermined, Underdetermined) test functions.

  Each function takes a `MatrixSolveLsOpTest` instance as `self` and checks a
  small fixed system unbatched and tiled over batch shape (2, 3).
  """

  def _VerifyAllBatchShapes(self, matrix, rhs):
    # Runs the check without batching and with a (2, 3) batch.
    for batch_shape in (), (2, 3):
      self._verifySolve(
          matrix,
          rhs,
          dtype,
          use_placeholder,
          fast,
          l2_regularizer,
          batch_shape=batch_shape)

  def Square(self):
    # 2x2 matrices, 2x3 right-hand sides.
    matrix = np.array([[1., 2.], [3., 4.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
    _VerifyAllBatchShapes(self, matrix, rhs)

  def Overdetermined(self):
    # 3x2 matrices, 3x3 right-hand sides.
    matrix = np.array([[1., 2.], [3., 4.], [5., 6.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.], [1., 1., 0.]])
    _VerifyAllBatchShapes(self, matrix, rhs)

  def Underdetermined(self):
    # 2x3 matrices, 2x3 right-hand sides.
    matrix = np.array([[1., 2., 3], [4., 5., 6.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
    _VerifyAllBatchShapes(self, matrix, rhs)

  return (Square, Overdetermined, Underdetermined)


def _GetLargeMatrixSolveLsOpTests(dtype, use_placeholder, fast, l2_regularizer):
  """Returns (LargeBatchSquare, LargeBatchOverdetermined,
  LargeBatchUnderdetermined) test functions.

  Each function checks a seeded random float32 system with one right-hand side,
  tiled over batch shape (16, 8).
  """

  def _VerifyRandomSystem(self, matrix_shape):
    # Random matrix with entries uniform in [-1, 1] and an all-ones rhs.
    np.random.seed(1)
    num_rhs = 1
    matrix = np.random.uniform(
        low=-1.0, high=1.0,
        size=np.prod(matrix_shape)).reshape(matrix_shape).astype(np.float32)
    rhs = np.ones([matrix_shape[0], num_rhs]).astype(np.float32)
    self._verifySolve(
        matrix,
        rhs,
        dtype,
        use_placeholder,
        fast,
        l2_regularizer,
        batch_shape=(16, 8))

  def LargeBatchSquare(self):
    _VerifyRandomSystem(self, (127, 127))

  def LargeBatchOverdetermined(self):
    _VerifyRandomSystem(self, (127, 64))

  def LargeBatchUnderdetermined(self):
    _VerifyRandomSystem(self, (64, 127))

  return (LargeBatchSquare, LargeBatchOverdetermined, LargeBatchUnderdetermined)


class MatrixSolveLsBenchmark(test_lib.Benchmark):
  """Benchmarks regularized matrix_solve_ls on CPU and (if available) GPU."""

  matrix_shapes = [
      (4, 4),
      (8, 4),
      (4, 8),
      (10, 10),
      (10, 8),
      (8, 10),
      (16, 16),
      (16, 10),
      (10, 16),
      (101, 101),
      (101, 31),
      (31, 101),
      (256, 256),
      (256, 200),
      (200, 256),
      (1001, 1001),
      (1001, 501),
      (501, 1001),
      (1024, 1024),
      (1024, 128),
      (128, 1024),
      (2048, 2048),
      (2048, 64),
      (64, 2048),
      (513, 4, 4),
      (513, 4, 2),
      (513, 2, 4),
      (513, 16, 16),
      (513, 16, 10),
      (513, 10, 16),
      (513, 256, 256),
      (513, 256, 128),
      (513, 128, 256),
  ]

  def _benchmarkOnDevice(self, device_name, tag, matrix_shape, num_rhs,
                         regularizer):
    """Times one matrix_solve_ls problem size on `device_name`.

    `tag` ("cpu" or "gpu") is embedded in the reported benchmark name.
    """
    with ops.Graph().as_default(), \
        session.Session(config=benchmark.benchmark_config()) as sess, \
        ops.device(device_name):
      matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
      x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
      self.evaluate(variables.global_variables_initializer())
      self.run_op_benchmark(
          sess,
          control_flow_ops.group(x),
          min_iters=25,
          store_memory_usage=False,
          name=("matrix_solve_ls_{tag}_shape_{matrix_shape}_num_rhs_{num_rhs}"
               ).format(tag=tag, matrix_shape=matrix_shape, num_rhs=num_rhs))

  def benchmarkMatrixSolveLsOp(self):
    run_gpu_test = test_lib.is_gpu_available(True)
    regularizer = 1.0
    for matrix_shape in self.matrix_shapes:
      for num_rhs in 1, 2, matrix_shape[-1]:
        self._benchmarkOnDevice("/cpu:0", "cpu", matrix_shape, num_rhs,
                                regularizer)
        # Batched shapes with a leading dimension of 513 or more are only
        # benchmarked on CPU.
        if run_gpu_test and (len(matrix_shape) < 3 or matrix_shape[0] < 513):
          self._benchmarkOnDevice("/gpu:0", "gpu", matrix_shape, num_rhs,
                                  regularizer)


if __name__ == "__main__":
  dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128]
  for dtype_ in dtypes_to_test:
    for use_placeholder_ in [False, True]:
      for fast_ in [True, False]:
        # complex128 is only tested without regularization.
        l2_regularizers = [0] if dtype_ == np.complex128 else [0, 0.1]
        for l2_regularizer_ in l2_regularizers:
          for test_case in _GetSmallMatrixSolveLsOpTests(
              dtype_, use_placeholder_, fast_, l2_regularizer_):
            name = "%s_%s_placeholder_%s_fast_%s_regu_%s" % (
                test_case.__name__, dtype_.__name__, use_placeholder_, fast_,
                l2_regularizer_)
            _AddTest(MatrixSolveLsOpTest, "MatrixSolveLsOpTest", name,
                     test_case)
  for dtype_ in dtypes_to_test:
    for test_case in _GetLargeMatrixSolveLsOpTests(dtype_, False, True, 0.0):
      name = "%s_%s" % (test_case.__name__, dtype_.__name__)
      _AddTest(MatrixSolveLsOpTest, "MatrixSolveLsOpTest", name, test_case)

  test_lib.main()