# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental `unique()` input pipeline transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import unique
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class UniqueDatasetTest(test.TestCase):
  """Functional tests for `unique.unique()` on generator-backed datasets."""

  def _testSimpleHelper(self, dtype, test_cases):
    """Checks `unique()` against a sequence of (input, expected) pairs.

    One dataset/iterator pair is built up front. Each test case is fed in by
    replacing the list that the generator callable returns, then re-running
    the iterator's initializer.

    Args:
      dtype: The `dtype` of the elements in each test case.
      test_cases: A list of `(input, expected)` pairs of lists. `input` is
        passed through the transformation; `expected` is the exact, ordered
        sequence of elements the transformation must produce for it.
    """
    # Single-slot holder that is overwritten for every test case. The
    # generator callable reads the slot each time the iterator is
    # (re)initialized, so the same graph serves every case.
    holder = [[]]
    deduped = dataset_ops.Dataset.from_generator(
        lambda: holder[0], dtype).apply(unique.unique())
    it = deduped.make_initializable_iterator()
    get_next = it.get_next()
    # String tensors come back as bytes, so expected strings are converted.
    wants_bytes = dtype == dtypes.string

    with self.test_session() as sess:
      for inputs, outputs in test_cases:
        holder[0] = inputs
        sess.run(it.initializer)
        for want in outputs:
          if wants_bytes:
            want = compat.as_bytes(want)
          self.assertAllEqual(want, sess.run(get_next))
        # After the expected elements, the iterator must be exhausted.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testSimpleInt(self):
    # The same (input, expected) table is checked for both integer widths.
    int_cases = [
        ([], []),
        ([1], [1]),
        ([1, 1, 1, 1, 1, 1, 1], [1]),
        ([1, 2, 3, 4], [1, 2, 3, 4]),
        ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
        ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
        ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
    ]
    for int_dtype in (dtypes.int32, dtypes.int64):
      self._testSimpleHelper(int_dtype, int_cases)

  def testSimpleString(self):
    string_cases = [
        ([], []),
        (["hello"], ["hello"]),
        (["hello", "hello", "hello"], ["hello"]),
        (["hello", "world"], ["hello", "world"]),
        (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
    ]
    self._testSimpleHelper(dtypes.string, string_cases)


class UniqueSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Serialization tests for datasets that apply `unique.unique()`."""

  def testUnique(self):

    def make_deduped_range(count, modulus):
      # `count` consecutive integers folded into [0, modulus), then
      # deduplicated by `unique()`.
      return dataset_ops.Dataset.range(count).map(
          lambda v: v % modulus).apply(unique.unique())

    # 200 values mod 100 give 100 distinct elements; the trailing 100 is
    # presumably that expected element count for the first builder -- confirm
    # against `run_core_tests`.
    self.run_core_tests(lambda: make_deduped_range(200, 100),
                        lambda: make_deduped_range(40, 100), 100)


if __name__ == "__main__":
  test.main()