# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `tf.data.experimental.{save,load}` operations."""
import os
import shutil
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class IOTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Covers round-trips through `Dataset.save` and `Dataset.load`."""

  def setUp(self):
    super().setUp()
    base = self.get_temp_dir()
    # Target directory for the save/load round-trip tests.
    self._test_dir = os.path.join(base, "io_test")
    os.mkdir(self._test_dir)
    # Directory that the checkpointing tests write their checkpoints into.
    self._checkpoint_prefix = os.path.join(base, "ckpt")
    os.mkdir(self._checkpoint_prefix)
    # Directory used by the checkpoint-test subclasses below.
    self._save_dir = os.path.join(base, "save")
    os.mkdir(self._save_dir)

  def tearDown(self):
    super().tearDown()
    for directory in (self._test_dir, self._checkpoint_prefix, self._save_dir):
      shutil.rmtree(directory)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "GZIP"])))
  def testBasic(self, compression):
    original = dataset_ops.Dataset.range(42)
    original.save(self._test_dir, compression=compression)
    restored = dataset_ops.Dataset.load(
        self._test_dir, original.element_spec, compression=compression)
    self.assertDatasetProduces(restored, range(42))

  @combinations.generate(test_base.eager_only_combinations())
  def testCardinality(self):
    original = dataset_ops.Dataset.range(42)
    original.save(self._test_dir)
    restored = dataset_ops.Dataset.load(self._test_dir, original.element_spec)
    self.assertEqual(self.evaluate(restored.cardinality()), 42)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomShardFunction(self):
    original = dataset_ops.Dataset.range(42)
    # Split the range into two shards: [0, 21) and [21, 42).
    original.save(self._test_dir, shard_func=lambda x: x // 21)
    restored = dataset_ops.Dataset.load(self._test_dir, original.element_spec)
    # The default reader interleaves the two shards element by element:
    # 0, 21, 1, 22, ...
    expected = [e for pair in zip(range(21), range(21, 42)) for e in pair]
    self.assertDatasetProduces(restored, expected)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomReaderFunction(self):
    original = dataset_ops.Dataset.range(42)
    # Seven shards, one per residue class modulo 7.
    original.save(self._test_dir, shard_func=lambda x: x % 7)
    restored = dataset_ops.Dataset.load(
        self._test_dir,
        original.element_spec,
        reader_func=lambda x: x.flat_map(lambda y: y))
    # flat_map concatenates the shards back to back rather than interleaving.
    expected = [e for shard in range(7) for e in range(shard, 42, 7)]
    self.assertDatasetProduces(restored, expected)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "GZIP"])))
  def testSaveInsideFunction(self, compression):
    dataset = dataset_ops.Dataset.range(42)

    @def_function.function
    def save_fn():
      # Saving must also work when traced into a tf.function.
      dataset.save(self._test_dir, compression=compression)

    save_fn()
    dataset = dataset_ops.Dataset.load(
        self._test_dir, dataset.element_spec, compression=compression)
    self.assertDatasetProduces(dataset, range(42))

  @combinations.generate(test_base.eager_only_combinations())
  def testElementSpecOptional(self):
    # In eager mode, `load` can recover the element spec from disk, so
    # passing it explicitly is optional even for structured elements.
    range_dataset = dataset_ops.Dataset.range(42)
    dict_dataset = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2],
                                                           "b": [3, 4]})
    tuple_dataset = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4]))
    zipped = dataset_ops.Dataset.zip(
        (range_dataset, dict_dataset, tuple_dataset))
    zipped.save(self._test_dir)
    reloaded = dataset_ops.Dataset.load(self._test_dir)
    self.assertDatasetsEqual(zipped, reloaded)

  @combinations.generate(test_base.graph_only_combinations())
  def testElementSpecRequired(self):
    # In graph mode the spec cannot be inferred, so omitting it must fail.
    dataset = dataset_ops.Dataset.range(42)
    dataset.save(self._test_dir)
    with self.assertRaises(ValueError):
      _ = dataset_ops.Dataset.load(self._test_dir)

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatAndPrefetch(self):
    """Reproduces github.com/tensorflow/tensorflow/issues/49165."""
    source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
    source.save(self._test_dir)
    dataset = (
        dataset_ops.Dataset.load(self._test_dir)
        .shuffle(buffer_size=16)
        .batch(16)
        .repeat()
        .prefetch(1))
    next_element = self.getNext(dataset)
    # Pull more batches than one epoch provides to exercise the repeat.
    for _ in range(30):
      self.evaluate(next_element())


class LoadCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Verifies that a loaded dataset's iterator state is checkpointable."""

  def _build_ds(self):
    # Re-creates the dataset from the snapshot written by `test` below.
    return dataset_ops.Dataset.load(self._save_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    source = dataset_ops.Dataset.range(42)
    source.save(self._save_dir)
    verify_fn(self, self._build_ds, num_outputs=42)


class SaveCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):
  """Tests the `checkpoint_args` integration of `Dataset.save`.

  Passing `checkpoint_args` makes `save` periodically checkpoint its own
  progress so an interrupted save can be resumed.
  """

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPI(self):
    dataset = dataset_ops.Dataset.range(40)
    checkpoint_args = {"directory": self._checkpoint_prefix, "max_to_keep": 50}
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # By default, we checkpoint every increment. Each checkpoint writes a
    # file containing the data and a file containing the index. There is
    # also an overall checkpoint file. Thus, we expect (2 * 40) + 1 files.
    self.assertEqual(81, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPICustomCheckpointInterval(self):
    dataset = dataset_ops.Dataset.range(40)
    step_counter = variables.Variable(0, trainable=False)
    checkpoint_args = {
        "checkpoint_interval": 5,
        "step_counter": step_counter,
        "directory": self._checkpoint_prefix,
        "max_to_keep": 10,
    }
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # Checkpointing every 5 of the 40 elements yields 8 checkpoints, each
    # with a data file and an index file, plus the overall checkpoint file:
    # we expect (2 * 8) + 1 files.
    self.assertEqual(17, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPIIncorrectArgs(self):
    dataset = dataset_ops.Dataset.range(42)
    checkpoint_args = {
        "directory": self._checkpoint_prefix,
        "incorrect_arg": "incorrect_arg"
    }
    with self.assertRaises(TypeError):
      # BUG FIX: the previous call was `dataset.save(dataset, self._save_dir,
      # checkpoint_args=...)` — a leftover from the free-function
      # `tf.data.experimental.save(dataset, path)` signature — which passed
      # the dataset itself as `path` and the directory as `compression`.
      # The TypeError could therefore come from the bogus positional
      # arguments instead of from the invalid checkpoint argument under
      # test. Calling with the correct positionals isolates the failure to
      # `incorrect_arg`.
      dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
# Allow running this test file directly.
if __name__ == "__main__":
  test.main()