1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrappers for reader Datasets.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21 22from tensorflow.python import tf2 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.data.util import convert 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.framework import tensor_spec 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import gen_dataset_ops 31from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 32from tensorflow.python.util import nest 33from tensorflow.python.util.tf_export import tf_export 34 35_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB 36 37 38def _normalise_fspath(path): 39 """Convert pathlib-like objects to str (__fspath__ compatibility, PEP 519).""" 40 return os.fspath(path) if isinstance(path, os.PathLike) else path 41 42 43def _create_or_validate_filenames_dataset(filenames): 44 """Creates (or validates) a dataset of filenames. 45 46 Args: 47 filenames: Either a list or dataset of filenames. If it is a list, it is 48 convert to a dataset. If it is a dataset, its type and shape is validated. 49 50 Returns: 51 A dataset of filenames. 52 """ 53 if isinstance(filenames, dataset_ops.DatasetV2): 54 if dataset_ops.get_legacy_output_types(filenames) != dtypes.string: 55 raise TypeError( 56 "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.") 57 if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with( 58 tensor_shape.TensorShape([])): 59 raise TypeError( 60 "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` " 61 "elements.") 62 else: 63 filenames = nest.map_structure(_normalise_fspath, filenames) 64 filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string) 65 if filenames.dtype != dtypes.string: 66 raise TypeError( 67 "`filenames` must be a `tf.Tensor` of dtype `tf.string` dtype." 68 " Got {}".format(filenames.dtype)) 69 filenames = array_ops.reshape(filenames, [-1], name="flat_filenames") 70 filenames = dataset_ops.DatasetV2.from_tensor_slices(filenames) 71 72 return filenames 73 74 75def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None): 76 """Creates a dataset that reads the given files using the given reader. 77 78 Args: 79 dataset_creator: A function that takes in a single file name and returns a 80 dataset. 81 filenames: A `tf.data.Dataset` containing one or more filenames. 82 num_parallel_reads: The number of parallel reads we should do. 83 84 Returns: 85 A `Dataset` that reads data from `filenames`. 86 """ 87 88 def read_one_file(filename): 89 filename = ops.convert_to_tensor(filename, dtypes.string, name="filename") 90 return dataset_creator(filename) 91 92 if num_parallel_reads is None: 93 return filenames.flat_map(read_one_file) 94 elif num_parallel_reads == dataset_ops.AUTOTUNE: 95 return filenames.interleave( 96 read_one_file, num_parallel_calls=num_parallel_reads) 97 else: 98 return ParallelInterleaveDataset( 99 filenames, 100 read_one_file, 101 cycle_length=num_parallel_reads, 102 block_length=1, 103 sloppy=False, 104 buffer_output_elements=None, 105 prefetch_input_elements=None) 106 107 108class _TextLineDataset(dataset_ops.DatasetSource): 109 """A `Dataset` comprising records from one or more text files.""" 110 111 def __init__(self, filenames, compression_type=None, buffer_size=None): 112 """Creates a `TextLineDataset`. 113 114 Args: 115 filenames: A `tf.string` tensor containing one or more filenames. 116 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 117 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 118 buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes 119 to buffer. A value of 0 results in the default buffering values chosen 120 based on the compression type. 121 """ 122 self._filenames = filenames 123 self._compression_type = convert.optional_param_to_tensor( 124 "compression_type", 125 compression_type, 126 argument_default="", 127 argument_dtype=dtypes.string) 128 self._buffer_size = convert.optional_param_to_tensor( 129 "buffer_size", 130 buffer_size, 131 argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES) 132 variant_tensor = gen_dataset_ops.text_line_dataset(self._filenames, 133 self._compression_type, 134 self._buffer_size) 135 super(_TextLineDataset, self).__init__(variant_tensor) 136 137 @property 138 def element_spec(self): 139 return tensor_spec.TensorSpec([], dtypes.string) 140 141 142@tf_export("data.TextLineDataset", v1=[]) 143class TextLineDatasetV2(dataset_ops.DatasetSource): 144 """A `Dataset` comprising lines from one or more text files.""" 145 146 def __init__(self, 147 filenames, 148 compression_type=None, 149 buffer_size=None, 150 num_parallel_reads=None): 151 r"""Creates a `TextLineDataset`. 152 153 The elements of the dataset will be the lines of the input files, using 154 the newline character '\n' to denote line splits. The newline characters 155 will be stripped off of each element. 156 157 Args: 158 filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or 159 more filenames. 160 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 161 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 162 buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes 163 to buffer. A value of 0 results in the default buffering values chosen 164 based on the compression type. 165 num_parallel_reads: (Optional.) A `tf.int64` scalar representing the 166 number of files to read in parallel. If greater than one, the records of 167 files read in parallel are outputted in an interleaved order. If your 168 input pipeline is I/O bottlenecked, consider setting this parameter to a 169 value greater than one to parallelize the I/O. If `None`, files will be 170 read sequentially. 171 """ 172 filenames = _create_or_validate_filenames_dataset(filenames) 173 self._filenames = filenames 174 self._compression_type = compression_type 175 self._buffer_size = buffer_size 176 177 def creator_fn(filename): 178 return _TextLineDataset(filename, compression_type, buffer_size) 179 180 self._impl = _create_dataset_reader(creator_fn, filenames, 181 num_parallel_reads) 182 variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access 183 184 super(TextLineDatasetV2, self).__init__(variant_tensor) 185 186 @property 187 def element_spec(self): 188 return tensor_spec.TensorSpec([], dtypes.string) 189 190 191@tf_export(v1=["data.TextLineDataset"]) 192class TextLineDatasetV1(dataset_ops.DatasetV1Adapter): 193 """A `Dataset` comprising lines from one or more text files.""" 194 195 def __init__(self, 196 filenames, 197 compression_type=None, 198 buffer_size=None, 199 num_parallel_reads=None): 200 wrapped = TextLineDatasetV2(filenames, compression_type, buffer_size, 201 num_parallel_reads) 202 super(TextLineDatasetV1, self).__init__(wrapped) 203 204 __init__.__doc__ = TextLineDatasetV2.__init__.__doc__ 205 206 @property 207 def _filenames(self): 208 return self._dataset._filenames # pylint: disable=protected-access 209 210 @_filenames.setter 211 def _filenames(self, value): 212 self._dataset._filenames = value # pylint: disable=protected-access 213 214 215class _TFRecordDataset(dataset_ops.DatasetSource): 216 """A `Dataset` comprising records from one or more TFRecord files.""" 217 218 def __init__(self, filenames, compression_type=None, buffer_size=None): 219 """Creates a `TFRecordDataset`. 220 221 Args: 222 filenames: A `tf.string` tensor containing one or more filenames. 223 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 224 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 225 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 226 bytes in the read buffer. 0 means no buffering. 227 """ 228 self._filenames = filenames 229 self._compression_type = convert.optional_param_to_tensor( 230 "compression_type", 231 compression_type, 232 argument_default="", 233 argument_dtype=dtypes.string) 234 self._buffer_size = convert.optional_param_to_tensor( 235 "buffer_size", 236 buffer_size, 237 argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES) 238 variant_tensor = gen_dataset_ops.tf_record_dataset(self._filenames, 239 self._compression_type, 240 self._buffer_size) 241 super(_TFRecordDataset, self).__init__(variant_tensor) 242 243 @property 244 def element_spec(self): 245 return tensor_spec.TensorSpec([], dtypes.string) 246 247 248class ParallelInterleaveDataset(dataset_ops.UnaryDataset): 249 """A `Dataset` that maps a function over its input and flattens the result.""" 250 251 def __init__(self, input_dataset, map_func, cycle_length, block_length, 252 sloppy, buffer_output_elements, prefetch_input_elements): 253 """See `tf.data.experimental.parallel_interleave()` for details.""" 254 self._input_dataset = input_dataset 255 self._map_func = dataset_ops.StructuredFunctionWrapper( 256 map_func, self._transformation_name(), dataset=input_dataset) 257 if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec): 258 raise TypeError("`map_func` must return a `Dataset` object.") 259 self._element_spec = self._map_func.output_structure._element_spec # pylint: disable=protected-access 260 self._cycle_length = ops.convert_to_tensor( 261 cycle_length, dtype=dtypes.int64, name="cycle_length") 262 self._block_length = ops.convert_to_tensor( 263 block_length, dtype=dtypes.int64, name="block_length") 264 self._buffer_output_elements = convert.optional_param_to_tensor( 265 "buffer_output_elements", 266 buffer_output_elements, 267 argument_default=2 * block_length) 268 self._prefetch_input_elements = convert.optional_param_to_tensor( 269 "prefetch_input_elements", 270 prefetch_input_elements, 271 argument_default=2 * cycle_length) 272 if sloppy is None: 273 self._deterministic = "default" 274 elif sloppy: 275 self._deterministic = "false" 276 else: 277 self._deterministic = "true" 278 variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2( 279 self._input_dataset._variant_tensor, # pylint: disable=protected-access 280 self._map_func.function.captured_inputs, 281 self._cycle_length, 282 self._block_length, 283 self._buffer_output_elements, 284 self._prefetch_input_elements, 285 f=self._map_func.function, 286 deterministic=self._deterministic, 287 **self._flat_structure) 288 super(ParallelInterleaveDataset, self).__init__(input_dataset, 289 variant_tensor) 290 291 def _functions(self): 292 return [self._map_func] 293 294 @property 295 def element_spec(self): 296 return self._element_spec 297 298 def _transformation_name(self): 299 return "tf.data.experimental.parallel_interleave()" 300 301 302@tf_export("data.TFRecordDataset", v1=[]) 303class TFRecordDatasetV2(dataset_ops.DatasetV2): 304 """A `Dataset` comprising records from one or more TFRecord files.""" 305 306 def __init__(self, 307 filenames, 308 compression_type=None, 309 buffer_size=None, 310 num_parallel_reads=None): 311 """Creates a `TFRecordDataset` to read one or more TFRecord files. 312 313 Each element of the dataset will contain a single TFRecord. 314 315 Args: 316 filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or 317 more filenames. 318 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 319 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 320 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 321 bytes in the read buffer. If your input pipeline is I/O bottlenecked, 322 consider setting this parameter to a value 1-100 MBs. If `None`, a 323 sensible default for both local and remote file systems is used. 324 num_parallel_reads: (Optional.) A `tf.int64` scalar representing the 325 number of files to read in parallel. If greater than one, the records of 326 files read in parallel are outputted in an interleaved order. If your 327 input pipeline is I/O bottlenecked, consider setting this parameter to a 328 value greater than one to parallelize the I/O. If `None`, files will be 329 read sequentially. 330 331 Raises: 332 TypeError: If any argument does not have the expected type. 333 ValueError: If any argument does not have the expected shape. 334 """ 335 filenames = _create_or_validate_filenames_dataset(filenames) 336 337 self._filenames = filenames 338 self._compression_type = compression_type 339 self._buffer_size = buffer_size 340 self._num_parallel_reads = num_parallel_reads 341 342 def creator_fn(filename): 343 return _TFRecordDataset(filename, compression_type, buffer_size) 344 345 self._impl = _create_dataset_reader(creator_fn, filenames, 346 num_parallel_reads) 347 variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access 348 super(TFRecordDatasetV2, self).__init__(variant_tensor) 349 350 def _clone(self, 351 filenames=None, 352 compression_type=None, 353 buffer_size=None, 354 num_parallel_reads=None): 355 return TFRecordDatasetV2(filenames or self._filenames, compression_type or 356 self._compression_type, buffer_size or 357 self._buffer_size, num_parallel_reads or 358 self._num_parallel_reads) 359 360 def _inputs(self): 361 return self._impl._inputs() # pylint: disable=protected-access 362 363 @property 364 def element_spec(self): 365 return tensor_spec.TensorSpec([], dtypes.string) 366 367 368@tf_export(v1=["data.TFRecordDataset"]) 369class TFRecordDatasetV1(dataset_ops.DatasetV1Adapter): 370 """A `Dataset` comprising records from one or more TFRecord files.""" 371 372 def __init__(self, 373 filenames, 374 compression_type=None, 375 buffer_size=None, 376 num_parallel_reads=None): 377 wrapped = TFRecordDatasetV2(filenames, compression_type, buffer_size, 378 num_parallel_reads) 379 super(TFRecordDatasetV1, self).__init__(wrapped) 380 381 __init__.__doc__ = TFRecordDatasetV2.__init__.__doc__ 382 383 def _clone(self, 384 filenames=None, 385 compression_type=None, 386 buffer_size=None, 387 num_parallel_reads=None): 388 # pylint: disable=protected-access 389 return TFRecordDatasetV1( 390 filenames or self._dataset._filenames, compression_type or 391 self._dataset._compression_type, buffer_size or 392 self._dataset._buffer_size, num_parallel_reads or 393 self._dataset._num_parallel_reads) 394 395 @property 396 def _filenames(self): 397 return self._dataset._filenames # pylint: disable=protected-access 398 399 @_filenames.setter 400 def _filenames(self, value): 401 self._dataset._filenames = value # pylint: disable=protected-access 402 403 404class _FixedLengthRecordDataset(dataset_ops.DatasetSource): 405 """A `Dataset` of fixed-length records from one or more binary files.""" 406 407 def __init__(self, 408 filenames, 409 record_bytes, 410 header_bytes=None, 411 footer_bytes=None, 412 buffer_size=None, 413 compression_type=None): 414 """Creates a `FixedLengthRecordDataset`. 415 416 Args: 417 filenames: A `tf.string` tensor containing one or more filenames. 418 record_bytes: A `tf.int64` scalar representing the number of bytes in each 419 record. 420 header_bytes: (Optional.) A `tf.int64` scalar representing the number of 421 bytes to skip at the start of a file. 422 footer_bytes: (Optional.) A `tf.int64` scalar representing the number of 423 bytes to ignore at the end of a file. 424 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 425 bytes to buffer when reading. 426 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 427 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 428 """ 429 self._filenames = filenames 430 self._record_bytes = ops.convert_to_tensor( 431 record_bytes, dtype=dtypes.int64, name="record_bytes") 432 self._header_bytes = convert.optional_param_to_tensor( 433 "header_bytes", header_bytes) 434 self._footer_bytes = convert.optional_param_to_tensor( 435 "footer_bytes", footer_bytes) 436 self._buffer_size = convert.optional_param_to_tensor( 437 "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) 438 self._compression_type = convert.optional_param_to_tensor( 439 "compression_type", 440 compression_type, 441 argument_default="", 442 argument_dtype=dtypes.string) 443 variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2( 444 self._filenames, self._header_bytes, self._record_bytes, 445 self._footer_bytes, self._buffer_size, self._compression_type) 446 super(_FixedLengthRecordDataset, self).__init__(variant_tensor) 447 448 @property 449 def element_spec(self): 450 return tensor_spec.TensorSpec([], dtypes.string) 451 452 453@tf_export("data.FixedLengthRecordDataset", v1=[]) 454class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource): 455 """A `Dataset` of fixed-length records from one or more binary files.""" 456 457 def __init__(self, 458 filenames, 459 record_bytes, 460 header_bytes=None, 461 footer_bytes=None, 462 buffer_size=None, 463 compression_type=None, 464 num_parallel_reads=None): 465 """Creates a `FixedLengthRecordDataset`. 466 467 Args: 468 filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or 469 more filenames. 470 record_bytes: A `tf.int64` scalar representing the number of bytes in each 471 record. 472 header_bytes: (Optional.) A `tf.int64` scalar representing the number of 473 bytes to skip at the start of a file. 474 footer_bytes: (Optional.) A `tf.int64` scalar representing the number of 475 bytes to ignore at the end of a file. 476 buffer_size: (Optional.) A `tf.int64` scalar representing the number of 477 bytes to buffer when reading. 478 compression_type: (Optional.) A `tf.string` scalar evaluating to one of 479 `""` (no compression), `"ZLIB"`, or `"GZIP"`. 480 num_parallel_reads: (Optional.) A `tf.int64` scalar representing the 481 number of files to read in parallel. If greater than one, the records of 482 files read in parallel are outputted in an interleaved order. If your 483 input pipeline is I/O bottlenecked, consider setting this parameter to a 484 value greater than one to parallelize the I/O. If `None`, files will be 485 read sequentially. 486 """ 487 filenames = _create_or_validate_filenames_dataset(filenames) 488 489 self._filenames = filenames 490 self._record_bytes = record_bytes 491 self._header_bytes = header_bytes 492 self._footer_bytes = footer_bytes 493 self._buffer_size = buffer_size 494 self._compression_type = compression_type 495 496 def creator_fn(filename): 497 return _FixedLengthRecordDataset(filename, record_bytes, header_bytes, 498 footer_bytes, buffer_size, 499 compression_type) 500 501 self._impl = _create_dataset_reader(creator_fn, filenames, 502 num_parallel_reads) 503 variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access 504 super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor) 505 506 @property 507 def element_spec(self): 508 return tensor_spec.TensorSpec([], dtypes.string) 509 510 511@tf_export(v1=["data.FixedLengthRecordDataset"]) 512class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter): 513 """A `Dataset` of fixed-length records from one or more binary files.""" 514 515 def __init__(self, 516 filenames, 517 record_bytes, 518 header_bytes=None, 519 footer_bytes=None, 520 buffer_size=None, 521 compression_type=None, 522 num_parallel_reads=None): 523 wrapped = FixedLengthRecordDatasetV2(filenames, record_bytes, header_bytes, 524 footer_bytes, buffer_size, 525 compression_type, num_parallel_reads) 526 super(FixedLengthRecordDatasetV1, self).__init__(wrapped) 527 528 __init__.__doc__ = FixedLengthRecordDatasetV2.__init__.__doc__ 529 530 @property 531 def _filenames(self): 532 return self._dataset._filenames # pylint: disable=protected-access 533 534 @_filenames.setter 535 def _filenames(self, value): 536 self._dataset._filenames = value # pylint: disable=protected-access 537 538 539if tf2.enabled(): 540 FixedLengthRecordDataset = FixedLengthRecordDatasetV2 541 TFRecordDataset = TFRecordDatasetV2 542 TextLineDataset = TextLineDatasetV2 543else: 544 FixedLengthRecordDataset = FixedLengthRecordDatasetV1 545 TFRecordDataset = TFRecordDatasetV1 546 TextLineDataset = TextLineDatasetV1 547