# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nn_impl.normalize, checked against a NumPy reference."""

import numpy as np

from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_impl
from tensorflow.python.platform import test as test_lib


def _AddTest(test, test_name, fn):
  """Attaches `fn` to the class `test` as method `test_<test_name>`.

  Raises:
    RuntimeError: if a method with the same name is already defined.
  """
  test_name = "_".join(["test", test_name])
  if hasattr(test, test_name):
    raise RuntimeError("Test %s defined more than once" % test_name)
  setattr(test, test_name, fn)


# pylint: disable=redefined-builtin
def _Normalize(x, ord, axis):
  """NumPy reference implementation of `nn_impl.normalize`.

  Args:
    x: NumPy array to normalize.
    ord: Norm order, forwarded to `np.linalg.norm`.
    axis: None (normalize by the norm of all elements), an int (vector norm
      along that axis) or a 2-element list/tuple (matrix norm over those
      axes).

  Returns:
    `x` divided by its norm, with the norm broadcast over the reduced axes.
  """
  if axis is None:
    # TensorFlow treats axis=None as a vector norm over every element, while
    # np.linalg.norm with axis=None would compute a matrix norm for 2-D input
    # (and reject higher ranks when `ord` is given), so flatten first.
    return x / np.linalg.norm(x.flatten(), ord)
  if isinstance(axis, (list, tuple)):
    # np.linalg.norm accepts only an int or a tuple for `axis`.
    axis = tuple(axis)
  # keepdims=True leaves size-1 dimensions exactly at the reduced axes, so the
  # norm broadcasts against `x` regardless of axis order or sign.
  norm = np.linalg.norm(x, ord, axis, keepdims=True)
  return x / norm


class NormalizeOpTest(test_lib.TestCase):
  """Container for generated tests; methods are attached via `_AddTest`."""
  pass


def _GetNormalizeOpTest(dtype_, shape_, ord_, axis_):
  """Builds a test method comparing `nn_impl.normalize` with `_Normalize`.

  Args:
    dtype_: NumPy dtype of the random input.
    shape_: Shape (list of ints) of the random input.
    ord_: Norm order passed to both implementations.
    axis_: Axis argument passed to both implementations.

  Returns:
    A test method suitable for attaching to `NormalizeOpTest`.
  """

  @test_util.run_in_graph_and_eager_modes
  def Test(self):
    is_matrix_norm = isinstance(axis_, (tuple, list)) and len(axis_) == 2
    # Non-integer real orders such as 0.5; string orders ("fro",
    # "euclidean") never qualify.
    is_fancy_p_norm = (not isinstance(ord_, str) and np.isreal(ord_) and
                       np.floor(ord_) != ord_)
    if ((not is_matrix_norm and ord_ == "fro") or
        (is_matrix_norm and is_fancy_p_norm)):
      self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
    # Use this test's own `shape_`: the module-level loop variable `shape`
    # holds only the last generated shape by the time tests execute.
    if ord_ == "euclidean" or (axis_ is None and len(shape_) > 2):
      self.skipTest("Not supported by numpy.linalg.norm")
    matrix = np.random.randn(*shape_).astype(dtype_)
    if dtype_ in (np.complex64, np.complex128):
      matrix += 1j * np.random.randn(*shape_).astype(dtype_)
    # normalize returns a pair; only the first element (the normalized
    # tensor) is compared against the reference.
    tf_np_n, _ = self.evaluate(nn_impl.normalize(matrix, ord_, axis_))
    np_n = _Normalize(matrix, ord_, axis_)
    self.assertAllClose(tf_np_n, np_n, rtol=1e-5, atol=1e-5)

  return Test


# pylint: disable=redefined-builtin
if __name__ == "__main__":
  for dtype in np.float32, np.float64, np.complex64, np.complex128:
    for rows in 2, 5:
      for cols in 2, 5:
        for batch in [], [2], [2, 3]:
          shape = batch + [rows, cols]
          for ord in "euclidean", "fro", 0.5, 1, 2, np.inf:
            for axis in [
                None, (-2, -1), (-1, -2), -len(shape), 0,
                len(shape) - 1
            ]:
              name = "%s_%s_ord_%s_axis_%s" % (dtype.__name__, "_".join(
                  map(str, shape)), ord, axis)
              _AddTest(NormalizeOpTest, "Normalize_" + name,
                       _GetNormalizeOpTest(dtype, shape, ord, axis))

  test_lib.main()