# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `nn_ops.nth_element` and its registered gradient."""

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class NthElementTest(test.TestCase):
  """Checks forward values, argument validation and gradients of nth_element."""

  def _validateNthElement(self, inputs, dtype, n, reverse, expected_values):
    """Runs `nth_element` on `inputs` and compares with `expected_values`.

    Args:
      inputs: Python or NumPy array-like, converted to a tensor of `dtype`.
      dtype: TensorFlow dtype used for the input tensor.
      n: Index into the sorted last dimension.
      reverse: If True, `n` indexes the descending order instead of ascending.
      expected_values: Expected result; both its shape and values are checked.
    """
    np_expected_values = np.array(expected_values)
    with self.cached_session(use_gpu=False):
      inputs_op = ops.convert_to_tensor(inputs, dtype=dtype)
      values_op = nn_ops.nth_element(inputs_op, n, reverse=reverse)
      values = self.evaluate(values_op)

      self.assertShapeEqual(np_expected_values, values_op)
      self.assertAllClose(np_expected_values, values)

  def _validateBothOrders(self, inputs, sorted_inputs, n, ascending_dtype,
                          descending_dtype):
    """Checks `reverse=False` and `reverse=True` against a NumPy sort.

    Args:
      inputs: NumPy array of rank >= 1.
      sorted_inputs: `np.sort(inputs)`, passed in so callers that check many
        values of `n` only sort once.
      n: Index into the last dimension, in `[0, inputs.shape[-1])`.
      ascending_dtype: dtype used for the `reverse=False` check.
      descending_dtype: dtype used for the `reverse=True` check.
    """
    self._validateNthElement(inputs, ascending_dtype, n, False,
                             sorted_inputs[..., n])
    self._validateNthElement(inputs, descending_dtype, n, True,
                             sorted_inputs[..., ::-1][..., n])

  def testExample1(self):
    inputs = [2.2, 4.4, 1.1, 5.5, 3.3]
    self._validateNthElement(inputs, dtypes.float32, 1, False, 2.2)
    self._validateNthElement(inputs, dtypes.float32, 1, True, 4.4)

  def testExample2(self):
    inputs = [[2.2, 4.4, 1.1], [5.5, 3.3, 6.6]]
    self._validateNthElement(inputs, dtypes.float64, 2, False, [4.4, 6.6])
    self._validateNthElement(inputs, dtypes.float64, 2, True, [1.1, 3.3])

  def testExample3(self):
    inputs = [[[2, 4, 1], [5, -3, 6]],
              [[7, 9, -8], [9, 0, 4]]]
    self._validateNthElement(inputs, dtypes.int32, 0, False,
                             [[1, -3], [-8, 0]])
    self._validateNthElement(inputs, dtypes.int64, 0, True,
                             [[4, 6], [9, 9]])

  def _testFloatLargeInput(self, input_shape):
    """Random floats at a random `n`: float32 ascending, float64 descending."""
    inputs = np.random.random_sample(input_shape)
    n = np.random.randint(input_shape[-1])
    self._validateBothOrders(inputs, np.sort(inputs), n, dtypes.float32,
                             dtypes.float64)

  def _testIntLargeInput(self, input_shape):
    """Random ints at a random `n`: int32 ascending, int64 descending."""
    # Integer bounds: `np.random.randint` documents `low`/`high` as ints.
    inputs = np.random.randint(-1000, 1000, size=input_shape)
    n = np.random.randint(input_shape[-1])
    self._validateBothOrders(inputs, np.sort(inputs), n, dtypes.int32,
                             dtypes.int64)

  def _testLargeInput(self, input_shape):
    self._testFloatLargeInput(input_shape)
    self._testIntLargeInput(input_shape)

  def testLargeInput(self):
    self._testLargeInput([1])
    self._testLargeInput([10])
    self._testLargeInput([5, 10])
    self._testLargeInput([50, 100])
    self._testLargeInput([50, 10000])
    self._testLargeInput([50, 10, 100])
    self._testLargeInput([50, 10, 10, 100])

  def _testEnumerateN(self, input_shape):
    """Checks every valid `n` for one random float input of `input_shape`."""
    inputs = np.random.random_sample(input_shape)
    sorted_inputs = np.sort(inputs)  # Sorted once, reused for every n.
    for n in range(input_shape[-1]):
      self._validateBothOrders(inputs, sorted_inputs, n, dtypes.float32,
                               dtypes.float64)

  def testEnumerateN(self):
    self._testEnumerateN([1])
    self._testEnumerateN([10])
    self._testEnumerateN([10, 10])
    self._testEnumerateN([10, 10, 10])
    self._testEnumerateN([10, 10, 10, 10])

  def testInvalidInput(self):
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "at least rank 1 but is rank 0"):
      nn_ops.nth_element(5, 0)

    # Test with placeholders
    with ops.Graph().as_default():
      with self.session(use_gpu=False):
        v = array_ops.placeholder(dtype=dtypes.int32)
        with self.assertRaisesOpError("at least rank 1 but is rank 0"):
          nn_ops.nth_element(v, 0).eval(feed_dict={v: 5})

  def testInvalidN(self):
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "non-negative but is -1"):
      nn_ops.nth_element([5], -1)
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "scalar but has rank 1"):
      nn_ops.nth_element([5, 6, 3], [1])

    # Test with placeholders
    with ops.Graph().as_default():
      with self.session(use_gpu=False):
        n = array_ops.placeholder(dtypes.int32)
        values = nn_ops.nth_element([5], n)
        with self.assertRaisesOpError("non-negative but is -1"):
          values.eval(feed_dict={n: -1})

  def testNTooLarge(self):
    inputs = [[0.1, 0.2], [0.3, 0.4]]
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                "must have last dimension > n = 2"):
      nn_ops.nth_element(inputs, 2)

    # Test with placeholders
    with ops.Graph().as_default():
      with self.session(use_gpu=False):
        n = array_ops.placeholder(dtypes.int32)
        values = nn_ops.nth_element(inputs, n)
        with self.assertRaisesOpError("must have last dimension > n = 2"):
          values.eval(feed_dict={n: 2})

  def testGradients(self):
    x = [
        [2., -1., 1000., 3., 1000.],
        [1., 5., 2., 4., 3.],
        [2., 2., 2., 2., 2.],
    ]
    # Upstream gradient for the [3]-shaped output of nth_element(x, 3).
    output_grad = [-1., 2., 5.]
    # Each row's upstream gradient is shared equally by the input entries equal
    # to the selected value (row 0 selects 1000, which occurs twice).
    result = [
        [0, 0, -0.5, 0, -0.5],
        [0, 0, 0, 2, 0],
        [1, 1, 1, 1, 1],
    ]
    if context.executing_eagerly():
      inputs = ops.convert_to_tensor(x)
      with backprop.GradientTape() as tape:
        tape.watch(inputs)
        values = nn_ops.nth_element(inputs, 3)
      # The upstream gradient must match `values`' [3] shape. A [1, 3] one
      # would broadcast to a spurious [1, 3, 5] gradient.
      grad = tape.gradient(values, inputs, ops.convert_to_tensor(output_grad))
      self.assertAllClose(grad, result)

    # Test with tf.gradients
    with ops.Graph().as_default():
      with self.session(use_gpu=False) as sess:
        inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
        values = nn_ops.nth_element(inputs, 3)
        # `grad_ys` is a list holding one upstream gradient per element of ys.
        grads = sess.run(
            gradients_impl.gradients(values, inputs, grad_ys=[output_grad]),
            feed_dict={inputs: x})
      # `gradients` returns one gradient per source; there is a single source.
      self.assertAllClose(grads[0], result)


if __name__ == "__main__":
  test.main()