• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for tensorflow.python.ops.nn_impl.normalize."""
16
17import numpy as np
18
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import nn_impl
21from tensorflow.python.platform import test as test_lib
22
23
24def _AddTest(test, test_name, fn):
25  test_name = "_".join(["test", test_name])
26  if hasattr(test, test_name):
27    raise RuntimeError("Test %s defined more than once" % test_name)
28  setattr(test, test_name, fn)
29
30
31# pylint: disable=redefined-builtin
32def _Normalize(x, ord, axis):
33  if isinstance(axis, (list, tuple)):
34    norm = np.linalg.norm(x, ord, tuple(axis))
35    if axis[0] < axis[1]:
36      # This prevents axis to be inserted in-between
37      # e.g. when (-2, -1)
38      for d in reversed(axis):
39        norm = np.expand_dims(norm, d)
40    else:
41      for d in axis:
42        norm = np.expand_dims(norm, d)
43    return x / norm
44  elif axis is None:
45    # Tensorflow handles None differently
46    norm = np.linalg.norm(x.flatten(), ord, axis)
47    return x / norm
48  else:
49    norm = np.apply_along_axis(np.linalg.norm, axis, x, ord)
50    return x / np.expand_dims(norm, axis)
51
52
class NormalizeOpTest(test_lib.TestCase):
  """Holder for the generated normalize test methods.

  The class body defines no tests itself. Test methods are attached to it
  with `_AddTest` by the registration loop under `if __name__ == "__main__":`.
  """
55
56
def _GetNormalizeOpTest(dtype_, shape_, ord_, axis_):
  """Builds a test method comparing `nn_impl.normalize` with `_Normalize`.

  Args:
    dtype_: NumPy dtype of the random input. For complex64/complex128 a random
      imaginary part is added too.
    shape_: shape of the random input.
    ord_: norm order passed to both `nn_impl.normalize` and `_Normalize`.
    axis_: axis argument passed to both; a 2-element tuple/list is treated as
      a matrix norm.

  Returns:
    A `Test(self)` function, decorated to run in graph and eager modes, meant
    to be attached to a TestCase.
  """

  @test_util.run_in_graph_and_eager_modes
  def Test(self):
    is_matrix_norm = isinstance(axis_, (tuple, list)) and len(axis_) == 2
    # True for a real, non-integer order such as 0.5.
    is_fancy_p_norm = np.isreal(ord_) and np.floor(ord_) != ord_
    if ((not is_matrix_norm and ord_ == "fro") or
        (is_matrix_norm and is_fancy_p_norm)):
      self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
    # Check the rank of *this* case's input (shape_). A module-level `shape`
    # would hold whichever shape the registration loop generated last.
    if ord_ == "euclidean" or (axis_ is None and len(shape_) > 2):
      self.skipTest("Not supported by numpy.linalg.norm")
    matrix = np.random.randn(*shape_).astype(dtype_)
    if dtype_ in (np.complex64, np.complex128):
      matrix += 1j * np.random.randn(*shape_).astype(dtype_)
    # normalize yields a pair; only its first element is compared with the
    # NumPy reference.
    tf_np_n, _ = self.evaluate(nn_impl.normalize(matrix, ord_, axis_))
    np_n = _Normalize(matrix, ord_, axis_)
    self.assertAllClose(tf_np_n, np_n, rtol=1e-5, atol=1e-5)

  return Test
77
78
if __name__ == "__main__":
  # Register one generated test per combination of dtype, matrix size, batch
  # prefix, norm order and axis specification.
  dtypes = (np.float32, np.float64, np.complex64, np.complex128)
  norm_orders = ("euclidean", "fro", 0.5, 1, 2, np.inf)
  for dtype in dtypes:
    for rows in (2, 5):
      for cols in (2, 5):
        for batch in ([], [2], [2, 3]):
          shape = batch + [rows, cols]
          rank = len(shape)
          shape_str = "_".join(str(d) for d in shape)
          # Whole tensor, the trailing matrix in both axis orders, the first
          # dimension spelled negatively (-rank) and positively (0), and the
          # last dimension (rank - 1).
          axes = (None, (-2, -1), (-1, -2), -rank, 0, rank - 1)
          for norm_order in norm_orders:
            for axis in axes:
              name = "%s_%s_ord_%s_axis_%s" % (dtype.__name__, shape_str,
                                               norm_order, axis)
              _AddTest(NormalizeOpTest, "Normalize_" + name,
                       _GetNormalizeOpTest(dtype, shape, norm_order, axis))

  test_lib.main()
97