• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for utilities for traversing the dataset construction graph."""
16
17from absl.testing import parameterized
18
19from tensorflow.python.compat import compat
20from tensorflow.python.data.experimental.ops import data_service_ops
21from tensorflow.python.data.kernel_tests import test_base
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.util import traverse
24from tensorflow.python.framework import combinations
25from tensorflow.python.ops import gen_dataset_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.platform import test
28
29
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """Test dataset whose single Python object hides two chained variant ops.

  The input dataset's variant is first wrapped in a `PrefetchDataset` op
  (buffer size 1) and then in a `ModelDataset` op. Only the final
  `ModelDataset` variant is passed to the base class, so the intermediate
  `PrefetchDataset` op has no Python `Dataset` object of its own. Traversal
  tests use this to check that such intermediate ops are still discovered.
  """

  def __init__(self, input_dataset):
    # NOTE(review): `_flat_structure` comes from the base class and presumably
    # derives its kwargs from the input dataset's element structure, which is
    # why `_input_dataset` is recorded before it is read. Confirm against
    # `UnaryUnchangedStructureDataset`.
    self._input_dataset = input_dataset
    structure_kwargs = self._flat_structure

    prefetched = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor, buffer_size=1, **structure_kwargs)
    modeled = gen_dataset_ops.model_dataset(prefetched, **structure_kwargs)

    super(_TestDataset, self).__init__(input_dataset, modeled)
41
42
class TraverseTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for the dataset-graph traversal helpers in `traverse`.

  Every test is generated with `test_base.graph_only_combinations()`.
  """

  def _assertVariantTensorOpNames(self, dataset, expected_names):
    """Checks the names of the variant-tensor ops reachable from `dataset`.

    The comparison is set-based, so op ordering is ignored.

    Args:
      dataset: The dataset passed to
        `traverse.obtain_all_variant_tensor_ops`.
      expected_names: Iterable of the op names expected in the result.
    """
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
    self.assertSetEqual(
        set(expected_names), {op.name for op in variant_tensor_ops})

  @combinations.generate(test_base.graph_only_combinations())
  def testOnlySource(self):
    ds = dataset_ops.Dataset.range(10)
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
    # A list comparison (not a set) also pins that exactly one op is found.
    self.assertAllEqual(["RangeDataset"], [x.name for x in variant_tensor_ops])

  @combinations.generate(test_base.graph_only_combinations())
  def testSimplePipeline(self):
    ds = dataset_ops.Dataset.range(10).map(math_ops.square)
    self._assertVariantTensorOpNames(ds, {"MapDataset", "RangeDataset"})

  @combinations.generate(test_base.graph_only_combinations())
  def testConcat(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = ds1.concatenate(ds2)
    # The second `RangeDataset` op in the graph is expected with a `_1` suffix.
    self._assertVariantTensorOpNames(
        ds, {"ConcatenateDataset", "RangeDataset", "RangeDataset_1"})

  @combinations.generate(test_base.graph_only_combinations())
  def testZip(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self._assertVariantTensorOpNames(
        ds, {"ZipDataset", "RangeDataset", "RangeDataset_1"})

  @combinations.generate(test_base.graph_only_combinations())
  def testMultipleVariantTensors(self):
    ds = dataset_ops.Dataset.range(10)
    # `_TestDataset` only exposes its `ModelDataset` variant; the intermediate
    # `PrefetchDataset` op it creates must still be found.
    ds = _TestDataset(ds)
    self._assertVariantTensorOpNames(
        ds, {"RangeDataset", "ModelDataset", "PrefetchDataset"})

  @combinations.generate(test_base.graph_only_combinations())
  def testFlatMap(self):
    ds1 = dataset_ops.Dataset.range(10).repeat(10)

    def map_fn(ds):

      def _map(x):
        return ds.batch(x)

      return _map

    ds2 = dataset_ops.Dataset.range(20).prefetch(1)
    ds2 = ds2.flat_map(map_fn(ds1))
    # `ds1` is referenced from inside the flat_map function: its ops
    # (`RepeatDataset` and one `RangeDataset`) are expected, while the
    # `batch` op built inside the function is not.
    self._assertVariantTensorOpNames(ds2, {
        "FlatMapDataset", "PrefetchDataset", "RepeatDataset", "RangeDataset",
        "RangeDataset_1"
    })

  @combinations.generate(test_base.graph_only_combinations())
  def testTfDataService(self):
    ds = dataset_ops.Dataset.range(10)
    # NOTE(review): the address looks like a placeholder; presumably no server
    # is contacted because the dataset is only constructed, never iterated.
    ds = ds.apply(
        data_service_ops.distribute("parallel_epochs", "grpc://foo:0"))
    capture_ops = traverse.obtain_capture_by_value_ops(ds)
    # The expected data service op name depends on the forward-compatibility
    # horizon of the TensorFlow build under test.
    data_service_dataset_op = ("DataServiceDatasetV4"
                               if compat.forward_compatible(2022, 8, 31) else
                               "DataServiceDatasetV3")
    self.assertContainsSubset(
        ["RangeDataset", data_service_dataset_op, "DummyIterationCounter"],
        {op.name for op in capture_ops})
120
121
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's platform `test` entry point.
  test.main()
124