• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.eager import backprop
19from tensorflow.python.eager import context
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import errors
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import gradients_impl
25from tensorflow.python.ops import nn_ops
26import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
27from tensorflow.python.platform import test
28
29
class NthElementTest(test.TestCase):
  """Tests for `nn_ops.nth_element` and its registered gradient."""

  def _validateNthElement(self, inputs, dtype, n, reverse, expected_values):
    """Runs `nth_element` on `inputs` and checks both shape and values.

    Args:
      inputs: Array-like input, converted to a tensor of `dtype`.
      dtype: TensorFlow dtype used for the input tensor.
      n: Index along the last dimension of the element to select.
      reverse: Forwarded to `nth_element`. In the examples below, False
        selects the n-th smallest value and True the n-th largest.
      expected_values: Array-like expected result. Its shape must equal the
        static shape of the op's output.
    """
    np_expected_values = np.array(expected_values)
    with self.cached_session(use_gpu=False):
      inputs_op = ops.convert_to_tensor(inputs, dtype=dtype)
      values_op = nn_ops.nth_element(inputs_op, n, reverse=reverse)
      values = self.evaluate(values_op)

      self.assertShapeEqual(np_expected_values, values_op)
      self.assertAllClose(np_expected_values, values)

  def _validateAgainstSorted(self, inputs, sorted_inputs, n, dtype,
                             reverse_dtype):
    """Checks `nth_element` in both orders against a NumPy-sorted reference.

    Args:
      inputs: NumPy array fed to the op.
      sorted_inputs: `np.sort(inputs)`, i.e. sorted along the last axis.
      n: Index along the last axis to select.
      dtype: dtype used for the ascending (`reverse=False`) check.
      reverse_dtype: dtype used for the descending (`reverse=True`) check.
    """
    self._validateNthElement(
        inputs, dtype, n, False, sorted_inputs[..., n])
    # Flipping the last axis turns the ascending sort into a descending one.
    self._validateNthElement(
        inputs, reverse_dtype, n, True, sorted_inputs[..., ::-1][..., n])

  def testExample1(self):
    """1-D input: the result is a scalar."""
    inputs = [2.2, 4.4, 1.1, 5.5, 3.3]
    self._validateNthElement(inputs, dtypes.float32, 1, False, 2.2)
    self._validateNthElement(inputs, dtypes.float32, 1, True, 4.4)

  def testExample2(self):
    """2-D input: selection happens independently for each row."""
    inputs = [[2.2, 4.4, 1.1], [5.5, 3.3, 6.6]]
    self._validateNthElement(inputs, dtypes.float64, 2, False, [4.4, 6.6])
    self._validateNthElement(inputs, dtypes.float64, 2, True, [1.1, 3.3])

  def testExample3(self):
    """3-D integer input. n=0 gives the per-row min (or max when reversed)."""
    inputs = [[[2, 4, 1], [5, -3, 6]],
              [[7, 9, -8], [9, 0, 4]]]
    self._validateNthElement(inputs, dtypes.int32, 0, False,
                             [[1, -3], [-8, 0]])
    self._validateNthElement(inputs, dtypes.int64, 0, True,
                             [[4, 6], [9, 9]])

  def _testFloatLargeInput(self, input_shape):
    """Checks a random float input of `input_shape` at a random `n`."""
    inputs = np.random.random_sample(input_shape)
    n = np.random.randint(input_shape[-1])
    # np.sort sorts along the last axis by default, matching nth_element.
    self._validateAgainstSorted(
        inputs, np.sort(inputs), n, dtypes.float32, dtypes.float64)

  def _testIntLargeInput(self, input_shape):
    """Checks a random integer input of `input_shape` at a random `n`."""
    # Integer bounds: randint is specified for int low/high, not floats.
    inputs = np.random.randint(-1000, 1000, input_shape)
    n = np.random.randint(input_shape[-1])
    self._validateAgainstSorted(
        inputs, np.sort(inputs), n, dtypes.int32, dtypes.int64)

  def _testLargeInput(self, input_shape):
    """Runs both the float and the integer random-input checks."""
    self._testFloatLargeInput(input_shape)
    self._testIntLargeInput(input_shape)

  def testLargeInput(self):
    self._testLargeInput([1])
    self._testLargeInput([10])
    self._testLargeInput([5, 10])
    self._testLargeInput([50, 100])
    self._testLargeInput([50, 10000])
    self._testLargeInput([50, 10, 100])
    self._testLargeInput([50, 10, 10, 100])

  def _testEnumerateN(self, input_shape):
    """Checks every valid `n` for one random float input of `input_shape`."""
    inputs = np.random.random_sample(input_shape)
    # Sort once; the reference does not depend on n.
    sorted_inputs = np.sort(inputs)
    for n in range(input_shape[-1]):
      self._validateAgainstSorted(
          inputs, sorted_inputs, n, dtypes.float32, dtypes.float64)

  def testEnumerateN(self):
    self._testEnumerateN([1])
    self._testEnumerateN([10])
    self._testEnumerateN([10, 10])
    self._testEnumerateN([10, 10, 10])
    self._testEnumerateN([10, 10, 10, 10])

  def testInvalidInput(self):
    """A rank-0 input must be rejected, both statically and at run time."""
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "at least rank 1 but is rank 0"):
      nn_ops.nth_element(5, 0)

    # With a shapeless placeholder the error can only surface at run time.
    with ops.Graph().as_default():
      with self.session(use_gpu=False):
        v = array_ops.placeholder(dtype=dtypes.int32)
        with self.assertRaisesOpError("at least rank 1 but is rank 0"):
          nn_ops.nth_element(v, 0).eval(feed_dict={v: 5})

  def testInvalidN(self):
    """`n` must be a non-negative scalar."""
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "non-negative but is -1"):
      nn_ops.nth_element([5], -1)
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "scalar but has rank 1"):
      nn_ops.nth_element([5, 6, 3], [1])

    # When n is fed at run time, validation happens inside the kernel.
    with ops.Graph().as_default():
      with self.session(use_gpu=False):
        n = array_ops.placeholder(dtypes.int32)
        values = nn_ops.nth_element([5], n)
        with self.assertRaisesOpError("non-negative but is -1"):
          values.eval(feed_dict={n: -1})

  def testNTooLarge(self):
    """`n` must be strictly smaller than the size of the last dimension."""
    inputs = [[0.1, 0.2], [0.3, 0.4]]
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "must have last dimension > n = 2"):
      nn_ops.nth_element(inputs, 2)

    # When n is fed at run time, validation happens inside the kernel.
    with ops.Graph().as_default():
      with self.session(use_gpu=False):
        n = array_ops.placeholder(dtypes.int32)
        values = nn_ops.nth_element(inputs, n)
        with self.assertRaisesOpError("must have last dimension > n = 2"):
          values.eval(feed_dict={n: 2})

  def testGradients(self):
    """Checks the gradient, including how ties share the upstream gradient."""
    x = [
        [2., -1., 1000., 3., 1000.],
        [1., 5., 2., 4., 3.],
        [2., 2., 2., 2., 2.],
    ]
    # nth_element(x, 3) has shape [3], so its upstream gradient has shape [3].
    upstream_grad = [-1., 2., 5.]
    # When several entries in a row equal the selected value, the upstream
    # gradient is split evenly among them (rows 0 and 2).
    result = [
        [0, 0, -0.5, 0, -0.5],
        [0, 0, 0, 2, 0],
        [1, 1, 1, 1, 1],
    ]
    if context.executing_eagerly():
      inputs = ops.convert_to_tensor(x)
      with backprop.GradientTape() as tape:
        tape.watch(inputs)
        values = nn_ops.nth_element(inputs, 3)
      # Feed a [3]-shaped upstream gradient matching `values`, rather than a
      # [1, 3] tensor that only works through broadcasting in the grad fn.
      grad = tape.gradient(
          values, inputs,
          output_gradients=ops.convert_to_tensor(upstream_grad))
      self.assertAllClose(grad, result)

    # Test with tf.gradients
    with ops.Graph().as_default():
      with self.session(use_gpu=False) as sess:
        inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
        values = nn_ops.nth_element(inputs, 3)
        # grad_ys holds one upstream gradient per tensor in ys.
        grads = sess.run(
            gradients_impl.gradients(
                values, inputs, grad_ys=[upstream_grad]),
            feed_dict={inputs: x})
    self.assertAllClose(grads[0], result)
184
185
186
if __name__ == "__main__":
  # Running this module directly hands control to TensorFlow's test runner.
  test.main()
189