# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for input pipeline modifications for distribution strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import input_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class AutoShardDatasetTest(test.TestCase):

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 4
    self._num_shards = 2
    self._shard_index = 0
    self._record_bytes = 10

  def _getNext(self, dataset):
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return lambda: get_next

  def _record(self, r, f):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _text_line(self, r, f):
    return compat.as_bytes("Text line %d of file %d" % (r, f))

  def _fixed_length_record(self, r, f):
    return compat.as_bytes(str((r * f) % 10) * self._record_bytes)

  def _createTFRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        record = self._record(j, i)
        writer.write(record)
      writer.close()
    return filenames

  def _createTextFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(self._num_records):
        contents.append(self._text_line(j, i))
        # Write a newline after every record. Except for the first file, skip
        # the newline after the last record, so that files both with and
        # without a trailing newline are exercised.
        if j + 1 != self._num_records or i == 0:
          contents.append(b"\r\n")
      contents = b"".join(contents)

      with open(fn, "wb") as f:
        f.write(contents)
    return filenames

  def _createFixedLengthRecordFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_records):
          f.write(self._fixed_length_record(j, i))
    return filenames

  def _verifySimpleShardingOutput(self, dataset, record_fn):
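    """Checks that `dataset` yields exactly this shard's records, in order.

    Args:
      dataset: the (auto-sharded) dataset under test.
      record_fn: maps a (record index, file index) pair to the expected value.
    """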
    next_element_fn = self._getNext(dataset)
    with self.cached_session():
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(record_fn(r, f), self.evaluate(next_element_fn()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTFRecordDataset(self):
    dataset = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testInterleave(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.interleave(
        readers.TFRecordDataset,
        cycle_length=4,
        block_length=self._num_records)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Since block_length == num records in each file, the output will still
    # contain records in order of files.
    self._verifySimpleShardingOutput(dataset, self._record)

  @test_util.run_in_graph_and_eager_modes
  def testListfiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(self.evaluate(next_element_fn()))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
    self.assertAllEqual(expected, actual)

  @test_util.run_in_graph_and_eager_modes
  def testComplexPipeline(self):
    # Set up a complex input pipeline.
    batch_size = 2
    num_epochs = 5
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.shuffle(buffer_size=self._num_files)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.shuffle(2 * self._num_files * self._num_records)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(lambda x: x)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)

    # Auto shard.
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Verify output.
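    # Each shard reads num_files / num_shards of the files, so over num_epochs
    # epochs the sharded pipeline should yield exactly
    # (num_files * num_records * num_epochs) / (num_shards * batch_size)
    # batches before raising OutOfRangeError.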
    next_element_fn = self._getNext(dataset)
    actual = []
    num_iterations = (self._num_files * self._num_records * num_epochs) // (
        self._num_shards * batch_size)
    for _ in range(num_iterations):
      actual.extend(self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

    expected = []
    for f in range(0, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs

    self.assertAllEqual(sorted(expected), sorted(actual))

  @test_util.run_in_graph_and_eager_modes
  def testZip(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
    self._verifySimpleShardingOutput(dataset, record_fn)

  @test_util.run_in_graph_and_eager_modes
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReader(self):
    dataset = readers.TextLineDataset(self._createTextFiles())

    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testTextLineReaderWithFlatMap(self):
    dataset = readers.TextLineDataset(self._createTextFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)

  @test_util.run_in_graph_and_eager_modes
  def testFixedLengthReaderWithFlatMap(self):
    dataset = readers.FixedLengthRecordDataset(
        self._createFixedLengthRecordFiles(), self._record_bytes)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._fixed_length_record)


# A dataset that creates two variant tensors.
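# It is used by testMultipleVariantTensors below to check that
# `input_ops._clone_dataset` handles a dataset whose construction emits more
# than one variant tensor (here, a prefetch op followed by a model op).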
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):

  def __init__(self, input_dataset):
    self._input_dataset = input_dataset
    temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=1,
        **self._flat_structure)
    variant_tensor = gen_dataset_ops.model_dataset(
        temp_variant_tensor, **self._flat_structure)
    super(_TestDataset, self).__init__(input_dataset, variant_tensor)


class CloneDatasetTest(test.TestCase):

  def _assert_datasets_equal(self, ds1, ds2):
    # First, assert that the structures are the same.
    self.assertTrue(
        structure.are_compatible(ds1.element_spec, ds2.element_spec))

    # Now create iterators on both and assert they produce the same values.
    it1 = dataset_ops.make_initializable_iterator(ds1)
    it2 = dataset_ops.make_initializable_iterator(ds2)

    get_next1 = it1.get_next()
    get_next2 = it2.get_next()

    with self.cached_session():
      self.evaluate([it1.initializer, it2.initializer])
      val1, val2 = self.evaluate([get_next1, get_next2])
      self.assertEqual(val1, val2)

  @test_util.run_deprecated_v1
  def testOnlySource(self):
    ds = dataset_ops.Dataset.range(10)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testSimplePipeline(self):
    ds = dataset_ops.Dataset.range(10).map(math_ops.square)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testConcat(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = ds1.concatenate(ds2)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testZip(self):
    ds1 = dataset_ops.Dataset.range(10)
    ds2 = dataset_ops.Dataset.range(10)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)

  @test_util.run_deprecated_v1
  def testMultipleVariantTensors(self):
    ds = dataset_ops.Dataset.range(10)
    ds = _TestDataset(ds)
    cloned_ds = input_ops._clone_dataset(ds)
    self._assert_datasets_equal(ds, cloned_ds)


if __name__ == "__main__":
  test.main()