• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the `unique()` dataset transformation in `tf.contrib.data`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
21from tensorflow.contrib.data.python.ops import unique
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import errors
25from tensorflow.python.platform import test
26from tensorflow.python.util import compat
27
28
class UniqueDatasetTest(test.TestCase):
  """Checks that `unique.unique()` drops repeats, keeping first occurrences."""

  def _testSimpleHelper(self, dtype, test_cases):
    """Runs `unique()` over each input in `test_cases` and checks the outputs.

    One initializable iterator is built up front. Each test case is fed by
    rebinding the list that the generator closure returns and then re-running
    the iterator's initializer.

    Args:
      dtype: The `dtype` of the elements in each test case.
      test_cases: A list of `(inputs, expected)` pairs of lists. `inputs` is
        passed through the transformation; `expected` is the exact sequence of
        elements it must produce before signalling end of input.
    """
    # Python closures bind late: the lambda reads whatever list `case_inputs`
    # refers to at the moment the generator is invoked, so rebinding the name
    # inside the loop below switches the input for the next initialization.
    case_inputs = []
    deduplicated = dataset_ops.Dataset.from_generator(
        lambda: case_inputs, dtype).apply(unique.unique())
    iterator = deduplicated.make_initializable_iterator()
    get_next = iterator.get_next()
    is_string = dtype == dtypes.string

    with self.test_session() as sess:
      for inputs, expected in test_cases:
        case_inputs = inputs
        sess.run(iterator.initializer)
        for want in expected:
          # Expected strings are converted to bytes so they compare equal to
          # the values fetched for `tf.string` elements.
          if is_string:
            want = compat.as_bytes(want)
          self.assertAllEqual(want, sess.run(get_next))
        # All expected elements have been consumed; one more fetch must fail.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testSimpleInt(self):
    # The same cases are exercised for both integer dtypes.
    int_cases = [
        ([], []),
        ([1], [1]),
        ([1, 1, 1, 1, 1, 1, 1], [1]),
        ([1, 2, 3, 4], [1, 2, 3, 4]),
        ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
        ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
        ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
    ]
    for int_dtype in (dtypes.int32, dtypes.int64):
      self._testSimpleHelper(int_dtype, int_cases)

  def testSimpleString(self):
    string_cases = [
        ([], []),
        (["hello"], ["hello"]),
        (["hello", "hello", "hello"], ["hello"]),
        (["hello", "world"], ["hello", "world"]),
        (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
    ]
    self._testSimpleHelper(dtypes.string, string_cases)
80
81
class UniqueSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Serialization tests for the `unique()` transformation."""

  def testUnique(self):

    def build_dataset(num_elements, unique_elem_range):
      # Values are `x % unique_elem_range` over `range(num_elements)`, so
      # `unique()` keeps exactly `min(num_elements, unique_elem_range)` of them.
      ranged = dataset_ops.Dataset.range(num_elements)
      folded = ranged.map(lambda x: x % unique_elem_range)
      return folded.apply(unique.unique())

    def make_primary():
      return build_dataset(200, 100)

    def make_alternate():
      return build_dataset(40, 100)

    # `make_primary` yields 100 distinct values. The third argument is
    # presumably the expected output count of the first dataset; confirm
    # against `DatasetSerializationTestBase.run_core_tests`.
    self.run_core_tests(make_primary, make_alternate, 100)
93
94
if __name__ == "__main__":
  # When this file is executed directly, hand control to TensorFlow's test
  # runner (`tensorflow.python.platform.test.main`).
  test.main()
97