# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for tensorflow.ops.gen_linalg_ops.matrix_logarithm."""
16
17import numpy as np
18
19from tensorflow.python.client import session
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import errors_impl
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import gen_linalg_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import stateless_random_ops
29from tensorflow.python.ops import variables
30from tensorflow.python.ops.linalg import linalg_impl
31from tensorflow.python.platform import benchmark
32from tensorflow.python.platform import test
33
34
class LogarithmOpTest(test.TestCase):
  """Tests gen_linalg_ops.matrix_logarithm via the expm(logm(A)) == A identity."""

  def _verifyLogarithm(self, x, np_type):
    """Checks that expm(logm(x)) reproduces x after casting x to np_type."""
    mat = x.astype(np_type)
    with test_util.use_gpu():
      # The logarithm is verified indirectly: exponentiating the result
      # should recover the original matrix.
      round_trip = linalg_impl.matrix_exponential(
          gen_linalg_ops.matrix_logarithm(mat))
      recovered = self.evaluate(round_trip)
      self.assertAllClose(mat, recovered, rtol=1e-4, atol=1e-3)

  def _verifyLogarithmComplex(self, x):
    """Runs the round-trip check in both complex precisions."""
    for np_type in (np.complex64, np.complex128):
      self._verifyLogarithm(x, np_type)

  def _makeBatch(self, matrix1, matrix2):
    """Stacks the two matrices and tiles them into a 4-D batch."""
    stacked = np.stack([matrix1, matrix2])
    return np.tile(stacked, [2, 3, 1, 1])

  def _verifyRandomBatches(self, np_type):
    """Round-trips random matrices over assorted batch shapes and sizes."""
    np.random.seed(42)
    for batch_dims in [(), (1,), (3,), (2, 2)]:
      for size in (8, 31, 32):
        shape = batch_dims + (size, size)
        flat = np.random.uniform(low=-1.0, high=1.0, size=np.prod(shape))
        self._verifyLogarithmComplex(flat.reshape(shape).astype(np_type))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testNonsymmetric(self):
    """Nonsymmetric 2x2 complex matrices, individually and batched."""
    matrix1 = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
    matrix2 = np.array([[1., 3.], [3., 5.]]).astype(np.complex64)
    matrix1 += 1j * matrix1
    matrix2 += 1j * matrix2
    for single in (matrix1, matrix2):
      self._verifyLogarithmComplex(single)
    # Batched version of the same inputs.
    self._verifyLogarithmComplex(self._makeBatch(matrix1, matrix2))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testSymmetricPositiveDefinite(self):
    """Symmetric positive-definite 2x2 matrices, individually and batched."""
    matrix1 = np.array([[2., 1.], [1., 2.]]).astype(np.complex64)
    matrix2 = np.array([[3., -1.], [-1., 3.]]).astype(np.complex64)
    matrix1 += 1j * matrix1
    matrix2 += 1j * matrix2
    for single in (matrix1, matrix2):
      self._verifyLogarithmComplex(single)
    # Batched version of the same inputs.
    self._verifyLogarithmComplex(self._makeBatch(matrix1, matrix2))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testNonSquareMatrix(self):
    """A non-square input must be rejected with an error."""
    rectangular = np.array([[1., 2., 3.], [3., 4., 5.]], dtype=np.complex64)
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      gen_linalg_ops.matrix_logarithm(rectangular)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testWrongDimensions(self):
    """Inputs of rank below two must be rejected with an error."""
    vector = constant_op.constant([1., 2.], dtype=dtypes.complex64)
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      gen_linalg_ops.matrix_logarithm(vector)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testEmpty(self):
    """Empty batches and zero-sized matrices should round-trip cleanly."""
    for empty_shape in ([0, 2, 2], [2, 0, 0]):
      self._verifyLogarithmComplex(np.empty(empty_shape, dtype=np.complex64))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testRandomSmallAndLargeComplex64(self):
    self._verifyRandomBatches(np.complex64)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testRandomSmallAndLargeComplex128(self):
    self._verifyRandomBatches(np.complex128)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testConcurrentExecutesWithoutError(self):
    """Two identical logm ops evaluated together must produce equal results."""
    matrix_shape = [5, 5]
    seed = [42, 24]
    # A stateless RNG with a fixed seed yields bit-identical inputs.
    first = math_ops.cast(
        stateless_random_ops.stateless_random_normal(matrix_shape, seed=seed),
        dtypes.complex64)
    second = math_ops.cast(
        stateless_random_ops.stateless_random_normal(matrix_shape, seed=seed),
        dtypes.complex64)
    self.assertAllEqual(first, second)
    results = self.evaluate([
        gen_linalg_ops.matrix_logarithm(first),
        gen_linalg_ops.matrix_logarithm(second),
    ])
    self.assertAllEqual(results[0], results[1])
142
143
class MatrixLogarithmBenchmark(test.Benchmark):
  """Benchmarks matrix_logarithm on a range of single and batched shapes."""

  # Trailing dimensions are the (square) matrix size; any leading
  # dimensions form the batch shape.
  shapes = [
      (4, 4),
      (10, 10),
      (16, 16),
      (101, 101),
      (256, 256),
      (1000, 1000),
      (1024, 1024),
      (2048, 2048),
      (513, 4, 4),
      (513, 16, 16),
      (513, 256, 256),
  ]

  def _GenerateMatrix(self, shape):
    """Builds a well-conditioned complex Variable of the given (batched) shape."""
    batch_shape = shape[:-2]
    rows, cols = shape[-2:]
    assert rows == cols
    n = rows
    # Identity plus a small constant perturbation keeps the matrix
    # comfortably inside the domain of the matrix logarithm.
    identity = np.diag(np.ones(n).astype(np.complex64))
    base = np.ones((n, n)).astype(np.complex64) / (2.0 * n) + identity
    return variables.Variable(np.tile(base, batch_shape + (1, 1)))

  def benchmarkMatrixLogarithmOp(self):
    """Times matrix_logarithm on CPU for every shape in `shapes`."""
    for shape in self.shapes:
      with ops.Graph().as_default():
        with session.Session(config=benchmark.benchmark_config()) as sess:
          with ops.device("/cpu:0"):
            matrix = self._GenerateMatrix(shape)
            logm = gen_linalg_ops.matrix_logarithm(matrix)
            self.evaluate(variables.global_variables_initializer())
            self.run_op_benchmark(
                sess,
                control_flow_ops.group(logm),
                min_iters=25,
                name="matrix_logarithm_cpu_{shape}".format(shape=shape))
182
183
# Run the test suite when executed as a script.
if __name__ == "__main__":
  test.main()
186