# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export


# TODO(b/64974358): Increase default buffer size to 256 MB.
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB

37@tf_export("data.TextLineDataset", v1=[])
38class TextLineDatasetV2(dataset_ops.DatasetSource):
39  """A `Dataset` comprising lines from one or more text files."""
40
41  def __init__(self, filenames, compression_type=None, buffer_size=None):
42    """Creates a `TextLineDataset`.
43
44    Args:
45      filenames: A `tf.string` tensor containing one or more filenames.
46      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
47        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
48      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
49        to buffer. A value of 0 results in the default buffering values chosen
50        based on the compression type.
51    """
52    self._filenames = ops.convert_to_tensor(
53        filenames, dtype=dtypes.string, name="filenames")
54    self._compression_type = convert.optional_param_to_tensor(
55        "compression_type",
56        compression_type,
57        argument_default="",
58        argument_dtype=dtypes.string)
59    self._buffer_size = convert.optional_param_to_tensor(
60        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
61    variant_tensor = gen_dataset_ops.text_line_dataset(
62        self._filenames, self._compression_type, self._buffer_size)
63    super(TextLineDatasetV2, self).__init__(variant_tensor)
64
65  @property
66  def _element_structure(self):
67    return structure.TensorStructure(dtypes.string, [])
68
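# Illustrative usage sketch (editorial addition, not part of the original
# module): in user code `TextLineDatasetV2` is reached as
# `tf.data.TextLineDataset`. The file names below are placeholders.
#
#   import tensorflow as tf
#
#   dataset = tf.data.TextLineDataset(
#       ["/tmp/file1.txt", "/tmp/file2.txt"],
#       compression_type="",       # use "GZIP" or "ZLIB" for compressed files
#       buffer_size=1024 * 1024)   # optional 1 MB read-ahead buffer
#   dataset = dataset.map(lambda line: tf.strings.strip(line))
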

@tf_export(v1=["data.TextLineDataset"])
class TextLineDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising lines from one or more text files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    wrapped = TextLineDatasetV2(filenames, compression_type, buffer_size)
    super(TextLineDatasetV1, self).__init__(wrapped)
  __init__.__doc__ = TextLineDatasetV2.__init__.__doc__

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


class _TFRecordDataset(dataset_ops.DatasetSource):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
    """
    # Force the type to string even if filenames is an empty list.
    self._filenames = ops.convert_to_tensor(
        filenames, dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
    variant_tensor = gen_dataset_ops.tf_record_dataset(
        self._filenames, self._compression_type, self._buffer_size)
    super(_TFRecordDataset, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])

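# Editorial note (not part of the original module): `_TFRecordDataset` reads
# records from a single file (or flat list of files) with one dataset op. The
# public `TFRecordDatasetV2` below composes it with `flat_map` or
# `ParallelInterleaveDataset` so that `filenames` may itself be a dataset. A
# minimal internal sketch, with a placeholder path:
#
#   ds = _TFRecordDataset("/tmp/shard-00000.tfrecord", compression_type="GZIP")
#   # Each element of `ds` is a scalar `tf.string` holding one serialized record.
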

class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.data.experimental.parallel_interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure,
                      dataset_ops.DatasetStructure):
      raise TypeError("`map_func` must return a `Dataset` object.")
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)
    variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._sloppy,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=self._map_func.function,
        **dataset_ops.flat_structure(self))
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "tf.data.experimental.parallel_interleave()"

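# Illustrative sketch (editorial addition, not part of the original module) of
# how `ParallelInterleaveDataset` is used by `TFRecordDatasetV2` below:
# `map_func` must map each input element to a `Dataset`; here every filename is
# mapped to a single-file `_TFRecordDataset`, up to `cycle_length` files are
# read concurrently, and their records are interleaved `block_length` elements
# at a time. The paths are placeholders.
#
#   filenames = dataset_ops.DatasetV2.from_tensor_slices(
#       ["/tmp/a.tfrecord", "/tmp/b.tfrecord"])
#   records = ParallelInterleaveDataset(
#       filenames,
#       lambda f: _TFRecordDataset(f),
#       cycle_length=2,
#       block_length=1,
#       sloppy=False,
#       buffer_output_elements=None,
#       prefetch_input_elements=None)
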

@tf_export("data.TFRecordDataset", v1=[])
class TFRecordDatasetV2(dataset_ops.DatasetV2):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None,
               num_parallel_reads=None):
    """Creates a `TFRecordDataset` to read one or more TFRecord files.

    NOTE: The `num_parallel_reads` argument can be used to improve performance
    when reading from a remote filesystem.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. Defaults to reading files
        sequentially.

    Raises:
      TypeError: If any argument does not have the expected type.
      ValueError: If any argument does not have the expected shape.
    """
    if isinstance(filenames, dataset_ops.DatasetV2):
      if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
        raise TypeError(
            "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
      if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with(
          tensor_shape.scalar()):
        raise ValueError(
            "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
            "elements.")
    else:
      filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string)
      filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
      filenames = dataset_ops.DatasetV2.from_tensor_slices(filenames)

    self._filenames = filenames
    self._compression_type = compression_type
    self._buffer_size = buffer_size
    self._num_parallel_reads = num_parallel_reads

    def read_one_file(filename):
      return _TFRecordDataset(filename, compression_type, buffer_size)

    if num_parallel_reads is None:
      self._impl = filenames.flat_map(read_one_file)
    else:
      self._impl = ParallelInterleaveDataset(
          filenames, read_one_file, cycle_length=num_parallel_reads,
          block_length=1, sloppy=False, buffer_output_elements=None,
          prefetch_input_elements=None)
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
    super(TFRecordDatasetV2, self).__init__(variant_tensor)

  def _clone(self,
             filenames=None,
             compression_type=None,
             buffer_size=None,
             num_parallel_reads=None):
    return TFRecordDatasetV2(filenames or self._filenames,
                             compression_type or self._compression_type,
                             buffer_size or self._buffer_size,
                             num_parallel_reads or self._num_parallel_reads)

  def _inputs(self):
    return self._impl._inputs()  # pylint: disable=protected-access

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])

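# Illustrative usage sketch (editorial addition, not part of the original
# module): in user code `TFRecordDatasetV2` is reached as
# `tf.data.TFRecordDataset`. The file patterns below are placeholders.
#
#   import tensorflow as tf
#
#   # `filenames` may be a list/tensor of paths; `num_parallel_reads` overlaps
#   # reads across files, which helps most on remote filesystems.
#   dataset = tf.data.TFRecordDataset(
#       ["gs://bucket/train-00000.tfrecord", "gs://bucket/train-00001.tfrecord"],
#       compression_type="GZIP",
#       num_parallel_reads=2)
#
#   # `filenames` may also be a dataset of filenames, e.g. from `list_files`.
#   files = tf.data.Dataset.list_files("gs://bucket/train-*.tfrecord")
#   dataset = tf.data.TFRecordDataset(files, num_parallel_reads=4)
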

@tf_export(v1=["data.TFRecordDataset"])
class TFRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None,
               num_parallel_reads=None):
    wrapped = TFRecordDatasetV2(
        filenames, compression_type, buffer_size, num_parallel_reads)
    super(TFRecordDatasetV1, self).__init__(wrapped)
  __init__.__doc__ = TFRecordDatasetV2.__init__.__doc__

  def _clone(self,
             filenames=None,
             compression_type=None,
             buffer_size=None,
             num_parallel_reads=None):
    # pylint: disable=protected-access
    return TFRecordDatasetV1(
        filenames or self._dataset._filenames,
        compression_type or self._dataset._compression_type,
        buffer_size or self._dataset._buffer_size,
        num_parallel_reads or self._dataset._num_parallel_reads)

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


@tf_export("data.FixedLengthRecordDataset", v1=[])
class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in
        each record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
    """
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._record_bytes = ops.convert_to_tensor(
        record_bytes, dtype=dtypes.int64, name="record_bytes")

    self._header_bytes = convert.optional_param_to_tensor(
        "header_bytes", header_bytes)
    self._footer_bytes = convert.optional_param_to_tensor(
        "footer_bytes", footer_bytes)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    if (self._compression_type is not None or
        compat.forward_compatible(2018, 11, 30)):
      variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
          self._filenames, self._header_bytes, self._record_bytes,
          self._footer_bytes, self._buffer_size, self._compression_type)
    else:
      variant_tensor = gen_dataset_ops.fixed_length_record_dataset(
          self._filenames, self._header_bytes, self._record_bytes,
          self._footer_bytes, self._buffer_size)
    super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])

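# Illustrative usage sketch (editorial addition, not part of the original
# module): reading fixed-length binary records, e.g. a CIFAR-10-style file in
# which every record is a 1-byte label followed by a 32x32x3 uint8 image. The
# path is a placeholder; in user code this class is reached as
# `tf.data.FixedLengthRecordDataset`.
#
#   import tensorflow as tf
#
#   record_bytes = 1 + 32 * 32 * 3
#   dataset = tf.data.FixedLengthRecordDataset(
#       ["/tmp/cifar10/data_batch_1.bin"], record_bytes=record_bytes)
#   # Each element is a scalar `tf.string` of exactly `record_bytes` bytes,
#   # which can be decoded with `tf.io.decode_raw`.
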

@tf_export(v1=["data.FixedLengthRecordDataset"])
class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None):
    wrapped = FixedLengthRecordDatasetV2(
        filenames, record_bytes, header_bytes, footer_bytes, buffer_size,
        compression_type)
    super(FixedLengthRecordDatasetV1, self).__init__(wrapped)
  __init__.__doc__ = FixedLengthRecordDatasetV2.__init__.__doc__

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# these aliases in place.
FixedLengthRecordDataset = FixedLengthRecordDatasetV1
TFRecordDataset = TFRecordDatasetV1
TextLineDataset = TextLineDatasetV1