# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `math_ops.trace`, checked against `numpy.trace`."""

import numpy as np

from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class TraceTest(test.TestCase):
  """Compares `math_ops.trace` with `np.trace` over the last two axes."""

  def setUp(self):
    # Run the base TestCase's per-test setup; the previous override skipped it
    # entirely. Then pin NumPy's global seed so generated inputs are
    # reproducible across runs.
    super().setUp()
    np.random.seed(0)

  def compare(self, x):
    """Assert `math_ops.trace(x)` matches `np.trace` over axes (-2, -1).

    Args:
      x: NumPy array of rank >= 2; leading dimensions are treated as batch
        dimensions by the NumPy reference (axis1=-2, axis2=-1).
    """
    np_ans = np.trace(x, axis1=-2, axis2=-1)
    with self.cached_session():
      tf_ans = math_ops.trace(x).eval()
    self.assertAllClose(tf_ans, np_ans)

  @test_util.run_deprecated_v1
  def testTrace(self):
    """Checks square, non-square and batched inputs for several dtypes."""
    for dtype in [np.int32, np.float32, np.float64]:
      for shape in [[1, 1], [2, 2], [2, 3], [3, 2], [2, 3, 2], [2, 2, 2, 3]]:
        # Sample from [-10, 10) instead of [0, 1): casting [0, 1) samples to
        # int32 truncates every element to 0, which made the integer case a
        # trivial comparison of zero traces.
        x = np.random.uniform(low=-10.0, high=10.0, size=shape).astype(dtype)
        self.compare(x)


if __name__ == "__main__":
  test.main()