# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities for traversing the dataset construction graph."""

from absl.testing import parameterized

from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import traverse
from tensorflow.python.framework import combinations
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def _op_names(op_list):
  """Returns the set of `.name` attributes of the ops in `op_list`."""
  return {op.name for op in op_list}


class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset whose variant tensor is produced by a chain of two ops.

  The input dataset's variant is wrapped in a `PrefetchDataset` op
  (buffer_size=1), which is itself wrapped in a `ModelDataset` op. Only the
  final `ModelDataset` variant is handed to the base class, so traversal has
  to walk through it to discover the intermediate `PrefetchDataset` op.
  """

  def __init__(self, input_dataset):
    # Recorded before `self._flat_structure` is read below. Presumably the
    # flat structure is derived from the input dataset (NOTE(review):
    # confirm), so keep this ordering.
    self._input_dataset = input_dataset
    prefetch_variant = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,
        buffer_size=1,
        **self._flat_structure)
    model_variant = gen_dataset_ops.model_dataset(
        prefetch_variant, **self._flat_structure)
    super().__init__(input_dataset, model_variant)


class TraverseTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Checks which dataset ops the `traverse` utilities discover."""

  @combinations.generate(test_base.graph_only_combinations())
  def testOnlySource(self):
    ds = dataset_ops.Dataset.range(10)
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
    self.assertAllEqual(["RangeDataset"], [x.name for x in variant_tensor_ops])

  @combinations.generate(test_base.graph_only_combinations())
  def testSimplePipeline(self):
    ds = dataset_ops.Dataset.range(10).map(math_ops.square)
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
    self.assertSetEqual({"MapDataset", "RangeDataset"},
                        _op_names(variant_tensor_ops))

  @combinations.generate(test_base.graph_only_combinations())
  def testConcat(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = ds1.concatenate(ds2)
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
    # Both inputs are RangeDataset ops in the same graph, so the second one
    # is expected under a uniquified `_1` name.
    self.assertSetEqual(
        {"ConcatenateDataset", "RangeDataset", "RangeDataset_1"},
        _op_names(variant_tensor_ops))

  @combinations.generate(test_base.graph_only_combinations())
  def testZip(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
    self.assertSetEqual({"ZipDataset", "RangeDataset", "RangeDataset_1"},
                        _op_names(variant_tensor_ops))

  @combinations.generate(test_base.graph_only_combinations())
  def testMultipleVariantTensors(self):
    ds = dataset_ops.Dataset.range(10)
    ds = _TestDataset(ds)
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
    # The intermediate PrefetchDataset created inside _TestDataset must be
    # found even though only the ModelDataset variant is exposed.
    self.assertSetEqual({"RangeDataset", "ModelDataset", "PrefetchDataset"},
                        _op_names(variant_tensor_ops))

  @combinations.generate(test_base.graph_only_combinations())
  def testFlatMap(self):
    ds1 = dataset_ops.Dataset.range(10).repeat(10)

    def map_fn(ds):

      def _map(x):
        return ds.batch(x)

      return _map

    ds2 = dataset_ops.Dataset.range(20).prefetch(1)
    # `ds1` is captured by the flat_map function via closure.
    ds2 = ds2.flat_map(map_fn(ds1))
    variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds2)
    # The captured pipeline (Repeat <- Range) is expected, but the `batch`
    # created inside the function body is not in the expected set.
    self.assertSetEqual(
        {
            "FlatMapDataset", "PrefetchDataset", "RepeatDataset",
            "RangeDataset", "RangeDataset_1"
        }, _op_names(variant_tensor_ops))

  @combinations.generate(test_base.graph_only_combinations())
  def testTfDataService(self):
    ds = dataset_ops.Dataset.range(10)
    ds = ds.apply(
        data_service_ops.distribute("parallel_epochs", "grpc://foo:0"))
    capture_ops = traverse.obtain_capture_by_value_ops(ds)
    data_service_dataset_op = ("DataServiceDatasetV4"
                               if compat.forward_compatible(2022, 8, 31) else
                               "DataServiceDatasetV3")
    self.assertContainsSubset(
        ["RangeDataset", data_service_dataset_op, "DummyIterationCounter"],
        _op_names(capture_ops))


if __name__ == "__main__":
  test.main()