# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for input pipeline modifications for distribution strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import input_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


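# Checks that auto_shard_dataset() rewrites a file-based input pipeline so
# that each shard reads a disjoint, deterministic subset of the input files.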
class AutoShardDatasetTest(test.TestCase):

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 4
    self._num_shards = 2
    self._shard_index = 0
    self._record_bytes = 10

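  # Returns a callable that produces the next element of `dataset`: in eager
  # mode via the iterator's internal next method, in graph mode via a
  # one-shot iterator's get_next() tensor.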
  def _getNext(self, dataset):
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return lambda: get_next

  def _record(self, r, f):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _text_line(self, r, f):
    return compat.as_bytes("Text line %d of file %d" % (r, f))

  def _fixed_length_record(self, r, f):
    return compat.as_bytes(str((r * f) % 10) * self._record_bytes)

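  # The _create*Files() helpers below each write self._num_files files with
  # self._num_records records apiece and return the list of file names.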
  def _createTFRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        record = self._record(j, i)
        writer.write(record)
      writer.close()
    return filenames

  def _createTextFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(self._num_records):
        contents.append(self._text_line(j, i))
        # Write a newline after every record except the last record of each
        # file. The first file keeps its trailing newline, so both terminated
        # and unterminated final lines are exercised.
        if j + 1 != self._num_records or i == 0:
          contents.append(b"\r\n")
      contents = b"".join(contents)

      with open(fn, "wb") as f:
        f.write(contents)
    return filenames

  def _createFixedLengthRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_records):
          f.write(self._fixed_length_record(j, i))
    return filenames

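  # Asserts that `dataset` yields, in order, every record of every file
  # assigned to this shard (files self._shard_index, self._shard_index +
  # self._num_shards, ...) and then raises OutOfRangeError.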
  def _verifySimpleShardingOutput(self, dataset, record_fn):
    next_element_fn = self._getNext(dataset)
    with self.cached_session():
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(record_fn(r, f), self.evaluate(next_element_fn()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTFRecordDataset(self):
    dataset = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testInterleave(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.interleave(
        readers.TFRecordDataset, cycle_length=4, block_length=self._num_records)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Since block_length equals the number of records per file, interleaving
    # still emits all records of one file before moving on to the next, so
    # the simple per-file verification applies.
    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testListFiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(self.evaluate(next_element_fn()))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
    self.assertAllEqual(expected, actual)

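  # Applies auto-sharding on top of a deep pipeline (shuffle, flat_map,
  # prefetch, repeat, map, batch). Output order is no longer deterministic,
  # so the sorted multisets of records are compared instead.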
  @test_util.run_in_graph_and_eager_modes
  def testComplexPipeline(self):
    # Set up a complex input pipeline.
    batch_size = 2
    num_epochs = 5
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.shuffle(buffer_size=self._num_files)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.shuffle(2 * self._num_files * self._num_records)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(lambda x: x)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)

    # Auto shard.
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Verify the output: each shard reads 1 / num_shards of the files, so it
    # produces num_files * num_records * num_epochs / num_shards records in
    # total, grouped into batches of batch_size.
    next_element_fn = self._getNext(dataset)
    actual = []
    num_iterations = (self._num_files * self._num_records * num_epochs) // (
        self._num_shards * batch_size)
    for _ in range(num_iterations):
      actual.extend(self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

    expected = []
    for f in range(0, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs

    self.assertAllEqual(sorted(expected), sorted(actual))

  @test_util.run_in_graph_and_eager_modes
  def testZip(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
    self._verifySimpleShardingOutput(dataset, record_fn)

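  # Both branches of the concatenation are sharded, so each shard first
  # yields its TFRecord files and then its text files.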
  @test_util.run_in_graph_and_eager_modes
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReader(self):
    dataset = readers.TextLineDataset(self._createTextFiles())

    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReaderWithFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
    dataset = dataset.flat_map(readers.TextLineDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testFixedLengthReaderWithFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createFixedLengthRecordFiles())
    dataset = dataset.flat_map(
        lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._fixed_length_record)


# A dataset wrapper whose construction creates two variant tensors, used to
# check that dataset cloning handles ops with multiple variant tensors.
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):

  def __init__(self, input_dataset):
    self._input_dataset = input_dataset
    temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=1,
        **self._flat_structure)
    variant_tensor = gen_dataset_ops.model_dataset(
        temp_variant_tensor, **self._flat_structure)
    super(_TestDataset, self).__init__(input_dataset, variant_tensor)


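# Checks that input_ops._clone_dataset() rebuilds a dataset graph that is
# structurally compatible with, and yields the same values as, the original.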
class CloneDatasetTest(test.TestCase):

  def _assert_datasets_equal(self, ds1, ds2):
    # First, assert that the element structures are compatible.
    self.assertTrue(
        structure.are_compatible(ds1.element_spec, ds2.element_spec))

    # Now create iterators on both and assert they produce the same values.
    it1 = dataset_ops.make_initializable_iterator(ds1)
    it2 = dataset_ops.make_initializable_iterator(ds2)

    get_next1 = it1.get_next()
    get_next2 = it2.get_next()

    with self.cached_session():
      self.evaluate([it1.initializer, it2.initializer])
      val1, val2 = self.evaluate([get_next1, get_next2])
      self.assertEqual(val1, val2)

  @test_util.run_deprecated_v1
  def testOnlySource(self):
    ds = dataset_ops.Dataset.range(10)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testSimplePipeline(self):
    ds = dataset_ops.Dataset.range(10).map(math_ops.square)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testConcat(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = ds1.concatenate(ds2)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testZip(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testMultipleVariantTensors(self):
    ds = dataset_ops.Dataset.range(10)
    ds = _TestDataset(ds)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)


if __name__ == "__main__":
  test.main()