# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `tf.data.experimental.{save,load}` operations."""
import os
import shutil

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class IOTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `Dataset.save` and `Dataset.load`."""

  def setUp(self):
    super(IOTest, self).setUp()
    # Three separate scratch directories: one for plain save/load round-trips,
    # one for checkpoint files written via `checkpoint_args`, and one for the
    # saved dataset used by the checkpointing subclasses below.
    tmpdir = self.get_temp_dir()
    tmpdir = os.path.join(tmpdir, "io_test")
    os.mkdir(tmpdir)
    self._test_dir = tmpdir
    self._checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    os.mkdir(self._checkpoint_prefix)
    self._save_dir = os.path.join(self.get_temp_dir(), "save")
    os.mkdir(self._save_dir)

  def tearDown(self):
    super(IOTest, self).tearDown()
    shutil.rmtree(self._test_dir)
    shutil.rmtree(self._checkpoint_prefix)
    shutil.rmtree(self._save_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "GZIP"])))
  def testBasic(self, compression):
    """Save/load round-trip preserves the elements, with and without GZIP."""
    dataset = dataset_ops.Dataset.range(42)
    dataset.save(self._test_dir, compression=compression)
    dataset2 = dataset_ops.Dataset.load(
        self._test_dir, dataset.element_spec, compression=compression)
    self.assertDatasetProduces(dataset2, range(42))

  @combinations.generate(test_base.eager_only_combinations())
  def testCardinality(self):
    """A loaded dataset reports the cardinality of the saved dataset."""
    dataset = dataset_ops.Dataset.range(42)
    dataset.save(self._test_dir)
    dataset2 = dataset_ops.Dataset.load(self._test_dir, dataset.element_spec)
    self.assertEqual(self.evaluate(dataset2.cardinality()), 42)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomShardFunction(self):
    """A custom `shard_func` controls the order elements are read back in."""
    dataset = dataset_ops.Dataset.range(42)
    # Two shards: elements [0, 21) and [21, 42). The default reader
    # interleaves the shards, so we expect pairs (i, i + 21).
    dataset.save(self._test_dir, shard_func=lambda x: x // 21)
    dataset2 = dataset_ops.Dataset.load(self._test_dir, dataset.element_spec)
    expected = []
    for i in range(21):
      expected.extend([i, i + 21])
    self.assertDatasetProduces(dataset2, expected)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomReaderFunction(self):
    """A custom `reader_func` controls how shards are combined on load."""
    dataset = dataset_ops.Dataset.range(42)
    # Seven shards keyed by `x % 7`; `flat_map` concatenates them in shard
    # order, so shard i contributes the subsequence i, i+7, i+14, ...
    dataset.save(self._test_dir, shard_func=lambda x: x % 7)
    dataset2 = dataset_ops.Dataset.load(
        self._test_dir,
        dataset.element_spec,
        reader_func=lambda x: x.flat_map(lambda y: y))
    expected = []
    for i in range(7):
      expected.extend(range(i, 42, 7))
    self.assertDatasetProduces(dataset2, expected)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "GZIP"])))
  def testSaveInsideFunction(self, compression):
    """`save` works when invoked from inside a `tf.function`."""
    dataset = dataset_ops.Dataset.range(42)

    @def_function.function
    def save_fn():
      dataset.save(self._test_dir, compression=compression)

    save_fn()
    dataset = dataset_ops.Dataset.load(
        self._test_dir, dataset.element_spec, compression=compression)
    self.assertDatasetProduces(dataset, range(42))

  @combinations.generate(test_base.eager_only_combinations())
  def testElementSpecOptional(self):
    """In eager mode, `load` infers the element spec from the saved data."""
    range_dataset = dataset_ops.Dataset.range(42)
    dict_dataset = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2],
                                                           "b": [3, 4]})
    tuple_dataset = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4]))
    dataset = dataset_ops.Dataset.zip((range_dataset, dict_dataset,
                                       tuple_dataset))
    dataset.save(self._test_dir)
    # No `element_spec` argument: it should be recovered from disk.
    dataset_loaded = dataset_ops.Dataset.load(self._test_dir)
    self.assertDatasetsEqual(dataset, dataset_loaded)

  @combinations.generate(test_base.graph_only_combinations())
  def testElementSpecRequired(self):
    """In graph mode, omitting `element_spec` on `load` is an error."""
    dataset = dataset_ops.Dataset.range(42)
    dataset.save(self._test_dir)
    with self.assertRaises(ValueError):
      _ = dataset_ops.Dataset.load(self._test_dir)

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatAndPrefetch(self):
    """This test reproduces github.com/tensorflow/tensorflow/issues/49165."""
    dataset1 = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
    dataset1.save(self._test_dir)
    dataset = dataset_ops.Dataset.load(self._test_dir)
    dataset = dataset.shuffle(buffer_size=16)
    dataset = dataset.batch(16)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(1)
    next_element = self.getNext(dataset)
    # Iterating past the first epoch exercises the repeat + prefetch path
    # that previously crashed.
    for _ in range(30):
      self.evaluate(next_element())


class LoadCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Checkpointing tests for iterators over a loaded dataset."""

  def _build_ds(self):
    return dataset_ops.Dataset.load(self._save_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    dataset = dataset_ops.Dataset.range(42)
    dataset.save(self._save_dir)
    verify_fn(self, self._build_ds, num_outputs=42)


class SaveCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Tests for the `checkpoint_args` option of `Dataset.save`."""

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPI(self):
    dataset = dataset_ops.Dataset.range(40)
    checkpoint_args = {"directory": self._checkpoint_prefix, "max_to_keep": 50}
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # By default, we checkpoint every increment. Each checkpoint writes a
    # file containing the data and a file containing the index. There is
    # also an overall checkpoint file. Thus, we expect (2 * 40) + 1 files.
    self.assertEqual(81, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPICustomCheckpointInterval(self):
    dataset = dataset_ops.Dataset.range(40)
    step_counter = variables.Variable(0, trainable=False)
    checkpoint_args = {
        "checkpoint_interval": 5,
        "step_counter": step_counter,
        "directory": self._checkpoint_prefix,
        "max_to_keep": 10,
    }
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # Checkpointing every 5 of 40 steps yields 8 checkpoints, each with a
    # data file and an index file, plus the overall checkpoint file:
    # we expect (2 * 8) + 1 files.
    self.assertEqual(17, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPIIncorrectArgs(self):
    dataset = dataset_ops.Dataset.range(42)
    checkpoint_args = {
        "directory": self._checkpoint_prefix,
        "incorrect_arg": "incorrect_arg"
    }
    with self.assertRaises(TypeError):
      # NOTE: the original test passed `dataset` as an extra positional
      # argument (`dataset.save(dataset, self._save_dir, ...)`), which made
      # the TypeError come from the wrong call arity rather than from the
      # invalid checkpoint argument this test is meant to exercise.
      dataset.save(self._save_dir, checkpoint_args=checkpoint_args)


if __name__ == "__main__":
  test.main()