# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import csv
import functools
import gzip

import numpy as np

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import parsing_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util.tf_export import tf_export

_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
                         dtypes.int64, dtypes.string)


def _is_valid_int32(str_val):
  try:
    # Checks equality to prevent int32 overflow
    return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
        str_val)
  except (ValueError, OverflowError):
    return False


def _is_valid_int64(str_val):
  try:
    dtypes.int64.as_numpy_dtype(str_val)
    return True
  except (ValueError, OverflowError):
    return False


def _is_valid_float(str_val, float_dtype):
  try:
    return float_dtype.as_numpy_dtype(str_val) < np.inf
  except ValueError:
    return False


def _infer_type(str_val, na_value, prev_type):
  """Given a string, infers its tensor type.

  Infers the type of a value by picking the least 'permissive' type possible,
  while still allowing the previous type inference for this column to be valid.
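  For example (an illustrative sketch; the values below assume the default
  `na_value=""`):

    _infer_type("123", "", None)            # -> dtypes.int32
    _infer_type("4.5", "", dtypes.int32)    # -> dtypes.float32
    _infer_type("abc", "", dtypes.float32)  # -> dtypes.string
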
  Args:
    str_val: String value to infer the type of.
    na_value: Additional string to recognize as a NA/NaN CSV value.
    prev_type: Type previously inferred based on values of this column that
      we've seen up till now.
  Returns:
    Inferred dtype.
  """
  if str_val in ("", na_value):
    # If the field is null, it gives no extra information about its type
    return prev_type

  type_list = [
      dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
  ]  # list of types to try, ordered from least permissive to most

  type_functions = [
      _is_valid_int32,
      _is_valid_int64,
      lambda str_val: _is_valid_float(str_val, dtypes.float32),
      lambda str_val: _is_valid_float(str_val, dtypes.float64),
      lambda str_val: True,
  ]  # Corresponding list of validation functions

  for i in range(len(type_list)):
    validation_fn = type_functions[i]
    if validation_fn(str_val) and (prev_type is None or
                                   prev_type in type_list[:i + 1]):
      return type_list[i]


def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
                  file_io_fn):
  """Generator that yields rows of CSV file(s) in order."""
  for fn in filenames:
    with file_io_fn(fn) as f:
      rdr = csv.reader(
          f,
          delimiter=field_delim,
          quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
      if header:
        next(rdr)  # Skip header lines

      for csv_row in rdr:
        if len(csv_row) != num_cols:
          raise ValueError(
              "Problem inferring types: CSV row has different number of fields "
              "than expected.")
        yield csv_row


def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
                           na_value, header, num_rows_for_inference,
                           select_columns, file_io_fn):
  """Infers column types from the first N valid CSV records of files."""
  if select_columns is None:
    select_columns = range(num_cols)
  inferred_types = [None] * len(select_columns)

  for i, csv_row in enumerate(
      _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
                    file_io_fn)):
    if num_rows_for_inference is not None and i >= num_rows_for_inference:
      break

    for j, col_index in enumerate(select_columns):
      inferred_types[j] = _infer_type(csv_row[col_index], na_value,
                                      inferred_types[j])

  # Replace None's with a default type
  inferred_types = [t or dtypes.string for t in inferred_types]
  # Default to 0 or '' for null values
  return [
      constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
      for t in inferred_types
  ]


def _infer_column_names(filenames, field_delim, use_quote_delim, file_io_fn):
  """Infers column names from first rows of files."""
  csv_kwargs = {
      "delimiter": field_delim,
      "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
  }
  with file_io_fn(filenames[0]) as f:
    try:
      column_names = next(csv.reader(f, **csv_kwargs))
    except StopIteration:
      raise ValueError(("Received StopIteration when reading the header line "
                        "of %s.  Empty file?") % filenames[0])

  for name in filenames[1:]:
    with file_io_fn(name) as f:
      try:
        if next(csv.reader(f, **csv_kwargs)) != column_names:
          raise ValueError(
              "Files have different column names in the header row.")
      except StopIteration:
        raise ValueError(("Received StopIteration when reading the header line "
                          "of %s.  Empty file?") % name)
  return column_names


def _get_sorted_col_indices(select_columns, column_names):
  """Transforms select_columns argument into sorted column indices."""
  names_to_indices = {n: i for i, n in enumerate(column_names)}
  num_cols = len(column_names)

  results = []
  for v in select_columns:
    # If value is already an int, check if it's valid.
    if isinstance(v, int):
      if v < 0 or v >= num_cols:
        raise ValueError(
            "Column index %d specified in select_columns out of valid range." %
            v)
      results.append(v)
    # Otherwise, check that it's a valid column name and convert to the
    # relevant column index.
    elif v not in names_to_indices:
      raise ValueError(
          "Value '%s' specified in select_columns not a valid column index or "
          "name." % v)
    else:
      results.append(names_to_indices[v])

  # Sort and ensure there are no duplicates
  results = sorted(set(results))
  if len(results) != len(select_columns):
    raise ValueError("select_columns contains duplicate columns")
  return results


def _maybe_shuffle_and_repeat(
    dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
  """Optionally shuffle and repeat dataset, as requested."""
  if shuffle:
    dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
  if num_epochs != 1:
    dataset = dataset.repeat(num_epochs)
  return dataset


def make_tf_record_dataset(file_pattern,
                           batch_size,
                           parser_fn=None,
                           num_epochs=None,
                           shuffle=True,
                           shuffle_buffer_size=None,
                           shuffle_seed=None,
                           prefetch_buffer_size=None,
                           num_parallel_reads=None,
                           num_parallel_parser_calls=None,
                           drop_final_batch=False):
  """Reads and optionally parses TFRecord files into a dataset.

  Provides common functionality such as batching, optional parsing, shuffling,
  and performant defaults.

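  A minimal usage sketch (the file pattern below is illustrative, and
  `parse_fn` stands in for a user-supplied function that maps a serialized
  `tf.train.Example` to fixed-shape tensors):

  ```python
  # `parse_fn` is a hypothetical, user-defined parsing function.
  dataset = tf.data.experimental.make_tf_record_dataset(
      file_pattern="/tmp/data-*.tfrecord",
      batch_size=32,
      parser_fn=parse_fn,
      num_epochs=1)
  for batch in dataset:
    pass  # Each `batch` has an additional leading batch dimension.
  ```
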
  Args:
    file_pattern: List of files or patterns of TFRecord file paths.
      See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    parser_fn: (Optional.) A function accepting string input to parse
      and process the record contents. This function must map records
      to components of a fixed shape, so they may be batched. By
      default, uses the record contents unmodified.
    num_epochs: (Optional.) An int specifying the number of times this
      dataset is repeated.  If None (the default), cycles through the
      dataset forever.
    shuffle: (Optional.) A bool that indicates whether the input
      should be shuffled. Defaults to `True`.
    shuffle_buffer_size: (Optional.) Buffer size to use for
      shuffling. A large buffer size ensures better shuffling, but
      increases memory usage and startup time.
    shuffle_seed: (Optional.) Randomization seed to use for shuffling.
    prefetch_buffer_size: (Optional.) An int specifying the number of
      feature batches to prefetch for performance improvement.
      Defaults to auto-tune. Set to 0 to disable prefetching.
    num_parallel_reads: (Optional.) Number of threads used to read
      records from files. If greater than 1, records from different
      files are interleaved. Defaults to `24`.
    num_parallel_parser_calls: (Optional.) Number of records to
      parse in parallel. Defaults to `batch_size`.
    drop_final_batch: (Optional.) Whether the last batch should be
      dropped in case its size is smaller than `batch_size`; the
      default behavior is not to drop the smaller batch.

  Returns:
    A dataset, where each element matches the output of `parser_fn`
    except it will have an additional leading `batch-size` dimension,
    or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
    unspecified.
  """
  if num_parallel_reads is None:
    # NOTE: We considered auto-tuning this value, but there is a concern
    # that this affects the mixing of records from different files, which
    # could affect training convergence/accuracy, so we are defaulting to
    # a constant for now.
    num_parallel_reads = 24

  if num_parallel_parser_calls is None:
    # TODO(josh11b): if num_parallel_parser_calls is None, use some function
    # of num cores instead of `batch_size`.
    num_parallel_parser_calls = batch_size

  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE

  files = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  dataset = core_readers.TFRecordDataset(
      files, num_parallel_reads=num_parallel_reads)

  if shuffle_buffer_size is None:
    # TODO(josh11b): Auto-tune this value when not specified
    shuffle_buffer_size = 10000
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  drop_final_batch = drop_final_batch or num_epochs is None

  if parser_fn is None:
    dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
  else:
    dataset = dataset.map(
        parser_fn, num_parallel_calls=num_parallel_parser_calls)
    dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)

  if prefetch_buffer_size == 0:
    return dataset
  else:
    return dataset.prefetch(buffer_size=prefetch_buffer_size)


@tf_export("data.experimental.make_csv_dataset", v1=[])
def make_csv_dataset_v2(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,  # TODO(aaudibert): Change default to 1 when graduating.
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=None,
    num_parallel_reads=None,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
):
  """Reads CSV files into a dataset.

  Reads CSV files into a dataset, where each element of the dataset is a
  (features, labels) tuple that corresponds to a batch of CSV rows. The features
  dictionary maps feature column names to `Tensor`s containing the corresponding
  feature data, and labels is a `Tensor` containing the batch's label data.

  By default, the first rows of the CSV files are expected to be headers listing
  the column names. If the first rows are not headers, set `header=False` and
  provide the column names with the `column_names` argument.

  By default, the dataset is repeated indefinitely, reshuffling the order each
  time. This behavior can be modified by setting the `num_epochs` and `shuffle`
  arguments.

  For example, suppose you have a CSV file containing

  | Feature_A | Feature_B |
  | --------- | --------- |
  | 1         | "a"       |
  | 2         | "b"       |
  | 3         | "c"       |
  | 4         | "d"       |

  ```
  # No label column specified
  dataset = tf.data.experimental.make_csv_dataset(filename, batch_size=2)
  iterator = dataset.as_numpy_iterator()
  print(dict(next(iterator)))
  # prints a dictionary of batched features:
  # OrderedDict([('Feature_A', array([1, 4], dtype=int32)),
  #              ('Feature_B', array([b'a', b'd'], dtype=object))])
  ```

  ```
  # Set Feature_B as label column
  dataset = tf.data.experimental.make_csv_dataset(
      filename, batch_size=2, label_name="Feature_B")
  iterator = dataset.as_numpy_iterator()
  print(next(iterator))
  # prints (features, labels) tuple:
  # (OrderedDict([('Feature_A', array([1, 2], dtype=int32))]),
  #  array([b'a', b'b'], dtype=object))
  ```
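
  A subset of columns can also be selected and given explicit types; the
  following sketch (using the same hypothetical `filename` as above) parses
  only `Feature_A` as `tf.int32`:

  ```
  # Select one column and provide its type explicitly
  dataset = tf.data.experimental.make_csv_dataset(
      filename, batch_size=2,
      select_columns=["Feature_A"],
      column_defaults=[tf.int32],
      num_epochs=1,
      shuffle=False)
  ```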

  See the
  [Load CSV data guide](https://www.tensorflow.org/tutorials/load_data/csv) for
  more examples of using `make_csv_dataset` to read CSV data.

  Args:
    file_pattern: List of files or patterns of file paths containing CSV
      records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    column_names: An optional list of strings that corresponds to the CSV
      columns, in order. One per column of the input record. If this is not
      provided, infers the column names from the first row of the records.
      These names will be the keys of the features dict of each dataset element.
    column_defaults: An optional list of default values for the CSV fields. One
      item per selected column of the input record. Each item in the list is
      either a valid CSV dtype (float32, float64, int32, int64, or string), or a
      `Tensor` with one of the aforementioned types. The tensor can either be
      a scalar default value (if the column is optional), or an empty tensor (if
      the column is required). If a dtype is provided instead of a tensor, the
      column is also treated as required. If this list is not provided, tries
      to infer types based on reading the first num_rows_for_inference rows of
      files specified, and assumes all columns are optional, defaulting to `0`
      for numeric values and `""` for string values. If both this and
      `select_columns` are specified, these must have the same lengths, and
      `column_defaults` is assumed to be sorted in order of increasing column
      index.
    label_name: An optional string corresponding to the label column. If
      provided, the data for this column is returned as a separate `Tensor` from
      the features dictionary, so that the dataset complies with the format
      expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
      function.
    select_columns: An optional list of integer indices or string column
      names, that specifies a subset of columns of CSV data to select. If
      column names are provided, these must correspond to names provided in
      `column_names` or inferred from the file header lines. When this argument
      is specified, only a subset of CSV columns will be parsed and returned,
      corresponding to the columns specified. Using this results in faster
      parsing and lower memory usage. If both this and `column_defaults` are
      specified, these must have the same lengths, and `column_defaults` is
      assumed to be sorted in order of increasing column index.
    field_delim: An optional `string`. Defaults to `","`. Char delimiter to
      separate fields in a record.
    use_quote_delim: An optional bool. Defaults to `True`. If false, treats
      double quotation marks as regular characters inside of the string fields.
    na_value: Additional string to recognize as NA/NaN.
    header: A bool that indicates whether the first rows of provided CSV files
      correspond to header lines with column names, and should not be included
      in the data.
    num_epochs: An int specifying the number of times this dataset is repeated.
      If None, cycles through the dataset forever.
    shuffle: A bool that indicates whether the input should be shuffled.
    shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
      ensures better shuffling, but increases memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: An int specifying the number of feature
      batches to prefetch for performance improvement. Recommended value is the
      number of batches consumed per training step. Defaults to auto-tune.
    num_parallel_reads: Number of threads used to read CSV records from files.
      If >1, the results will be interleaved. Defaults to `1`.
    sloppy: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`; note that if the seed is set, the order
      of elements after shuffling is deterministic). Defaults to `False`.
    num_rows_for_inference: Number of rows of a file to use for type inference
      if column_defaults is not provided. If None, reads all the rows of all
      the files. Defaults to 100.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
    ignore_errors: (Optional.) If `True`, ignores errors with CSV file parsing,
      such as malformed data or empty lines, and moves on to the next valid
      CSV record. Otherwise, the dataset raises an error and stops processing
      when encountering any invalid records. Defaults to `False`.

  Returns:
    A dataset, where each element is a (features, labels) tuple that corresponds
    to a batch of `batch_size` CSV rows. The features dictionary maps feature
    column names to `Tensor`s containing the corresponding column data, and
    labels is a `Tensor` containing the column data for the label column
    specified by `label_name`.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  if num_parallel_reads is None:
    num_parallel_reads = 1

  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE

  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Clean arguments; figure out column names and defaults
  if column_names is None or column_defaults is None:
    # Determine which I/O function to use to open the files
    file_io_fn = lambda filename: file_io.FileIO(filename, "r")
    if compression_type is not None:
      compression_type_value = tensor_util.constant_value(compression_type)
      if compression_type_value is None:
        raise ValueError("Received unknown compression_type")
      if compression_type_value == "GZIP":
        file_io_fn = lambda filename: gzip.open(filename, "rt")
      elif compression_type_value == "ZLIB":
        raise ValueError(
            "compression_type (%s) is not supported for probing columns" %
            compression_type)
      elif compression_type_value != "":
        raise ValueError("compression_type (%s) is not supported" %
                         compression_type)
  if column_names is None:
    if not header:
      raise ValueError("Cannot infer column names without a header line.")
    # If column names are not provided, infer from the header lines
    column_names = _infer_column_names(filenames, field_delim, use_quote_delim,
                                       file_io_fn)
  if len(column_names) != len(set(column_names)):
    raise ValueError("Cannot have duplicate column names.")

  if select_columns is not None:
    select_columns = _get_sorted_col_indices(select_columns, column_names)

  if column_defaults is not None:
    column_defaults = [
        constant_op.constant([], dtype=x)
        if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x
        for x in column_defaults
    ]
  else:
    # If column defaults are not provided, infer from records at graph
    # construction time
    column_defaults = _infer_column_defaults(filenames, len(column_names),
                                             field_delim, use_quote_delim,
                                             na_value, header,
                                             num_rows_for_inference,
                                             select_columns, file_io_fn)

  if select_columns is not None and len(column_defaults) != len(select_columns):
    raise ValueError(
        "If specified, column_defaults and select_columns must have same "
        "length."
    )
  if select_columns is not None and len(column_names) > len(select_columns):
    # Pick the relevant subset of column names
    column_names = [column_names[i] for i in select_columns]

  if label_name is not None and label_name not in column_names:
    raise ValueError("`label_name` provided must be one of the columns.")

  def filename_to_dataset(filename):
    dataset = CsvDataset(
        filename,
        record_defaults=column_defaults,
        field_delim=field_delim,
        use_quote_delim=use_quote_delim,
        na_value=na_value,
        select_cols=select_columns,
        header=header,
        compression_type=compression_type
    )
    if ignore_errors:
      dataset = dataset.apply(error_ops.ignore_errors())
    return dataset

  def map_fn(*columns):
    """Organizes columns into a features dictionary.

    Args:
      *columns: list of `Tensor`s corresponding to one csv record.
    Returns:
      An OrderedDict of feature names to values for that particular record. If
      label_name is provided, extracts the label feature to be returned as the
      second element of the tuple.
    """
    features = collections.OrderedDict(zip(column_names, columns))
    if label_name is not None:
      label = features.pop(label_name)
      return features, label
    return features

  if num_parallel_reads == dataset_ops.AUTOTUNE:
    dataset = dataset.interleave(
        filename_to_dataset, num_parallel_calls=num_parallel_reads)
    options = options_lib.Options()
    options.deterministic = not sloppy
    dataset = dataset.with_options(options)
  else:
    # Read files sequentially (if num_parallel_reads=1) or in parallel
    def apply_fn(dataset):
      return core_readers.ParallelInterleaveDataset(
          dataset,
          filename_to_dataset,
          cycle_length=num_parallel_reads,
          block_length=1,
          sloppy=sloppy,
          buffer_output_elements=None,
          prefetch_input_elements=None)

    dataset = dataset.apply(apply_fn)

  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # Apply batch before map for perf, because map has high overhead relative
  # to the size of the computation in each map.
  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(batch_size=batch_size,
                          drop_remainder=num_epochs is None)
  dataset = dataset_ops.MapDataset(
      dataset, map_fn, use_inter_op_parallelism=False)
  dataset = dataset.prefetch(prefetch_buffer_size)

  return dataset


@tf_export(v1=["data.experimental.make_csv_dataset"])
def make_csv_dataset_v1(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=None,
    num_parallel_reads=None,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
):  # pylint: disable=missing-docstring
  return dataset_ops.DatasetV1Adapter(make_csv_dataset_v2(
      file_pattern, batch_size, column_names, column_defaults, label_name,
      select_columns, field_delim, use_quote_delim, na_value, header,
      num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
      prefetch_buffer_size, num_parallel_reads, sloppy, num_rows_for_inference,
      compression_type, ignore_errors))
make_csv_dataset_v1.__doc__ = make_csv_dataset_v2.__doc__


_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024  # 4 MB


@tf_export("data.experimental.CsvDataset", v1=[])
class CsvDatasetV2(dataset_ops.DatasetSource):
  r"""A Dataset comprising lines from one or more CSV files.

  The `tf.data.experimental.CsvDataset` class provides a minimal CSV Dataset
  interface. There is also a richer `tf.data.experimental.make_csv_dataset`
  function which provides additional convenience features such as column header
  parsing, column type-inference, automatic shuffling, and file interleaving.

  The elements of this dataset correspond to records from the file(s).
  RFC 4180 format is expected for CSV files
  (https://tools.ietf.org/html/rfc4180)
  Note that we allow leading and trailing spaces for int or float fields.

  For example, suppose we have a file 'my_file0.csv' with four CSV columns of
  different data types:

  >>> with open('/tmp/my_file0.csv', 'w') as f:
  ...   f.write('abcdefg,4.28E10,5.55E6,12\n')
  ...   f.write('hijklmn,-5.3E14,,2\n')

  We can construct a CsvDataset from it as follows:

  >>> dataset = tf.data.experimental.CsvDataset(
  ...   "/tmp/my_file0.csv",
  ...   [tf.float32,  # Required field, use dtype or empty tensor
  ...    tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
  ...    tf.int32,  # Required field, use dtype or empty tensor
  ...   ],
  ...   select_cols=[1,2,3]  # Only parse last three columns
  ... )

  The expected output of its iterations is:

  >>> for element in dataset.as_numpy_iterator():
  ...   print(element)
  (4.28e10, 5.55e6, 12)
  (-5.3e14, 0.0, 2)

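  Columns can also be dropped by index with `exclude_cols` instead of
  `select_cols` (a sketch for the same file; the three defaults cover the
  columns that remain after excluding column 0):

  >>> dataset = tf.data.experimental.CsvDataset(
  ...   "/tmp/my_file0.csv",
  ...   [tf.float32,
  ...    tf.constant([0.0], dtype=tf.float32),
  ...    tf.int32,
  ...   ],
  ...   exclude_cols=[0])
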
  See
  https://www.tensorflow.org/tutorials/load_data/csv#tfdataexperimentalcsvdataset
  for more in-depth example usage.
  """

  def __init__(self,
               filenames,
               record_defaults,
               compression_type=None,
               buffer_size=None,
               header=False,
               field_delim=",",
               use_quote_delim=True,
               na_value="",
               select_cols=None,
               exclude_cols=None):
    """Creates a `CsvDataset` by reading and decoding CSV files.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_columns` are specified, these must have the same
        lengths, and `column_defaults` is assumed to be sorted in order of
        increasing column index. If both this and 'exclude_cols' are specified,
        the sum of lengths of record_defaults and exclude_cols should equal
        the total number of columns in the CSV file.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
        double quotation marks as regular characters inside of string fields
        (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will
        be treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns. At most one of `select_cols`
        and `exclude_cols` can be specified.
      exclude_cols: (Optional.) A sorted list of column indices to exclude from
        the input data. If specified, only the complement of this set of columns
        will be parsed. Defaults to parsing all columns. At most one of
        `select_cols` and `exclude_cols` can be specified.

    Raises:
      InvalidArgumentError: If exclude_cols is not None and
        len(exclude_cols) + len(record_defaults) does not match the total
        number of columns in the file(s).
    """
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    record_defaults = [
        constant_op.constant([], dtype=x)
        if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x
        for x in record_defaults
    ]
    self._record_defaults = ops.convert_n_to_tensor(
        record_defaults, name="record_defaults")
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._header = ops.convert_to_tensor(
        header, dtype=dtypes.bool, name="header")
    self._field_delim = ops.convert_to_tensor(
        field_delim, dtype=dtypes.string, name="field_delim")
    self._use_quote_delim = ops.convert_to_tensor(
        use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
    self._na_value = ops.convert_to_tensor(
        na_value, dtype=dtypes.string, name="na_value")
    self._select_cols = convert.optional_param_to_tensor(
        "select_cols",
        select_cols,
        argument_default=[],
        argument_dtype=dtypes.int64,
    )
    self._exclude_cols = convert.optional_param_to_tensor(
        "exclude_cols",
        exclude_cols,
        argument_default=[],
        argument_dtype=dtypes.int64,
    )
    self._element_spec = tuple(
        tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults)
    variant_tensor = gen_experimental_dataset_ops.csv_dataset_v2(
        filenames=self._filenames,
        record_defaults=self._record_defaults,
        buffer_size=self._buffer_size,
        header=self._header,
        output_shapes=self._flat_shapes,
        field_delim=self._field_delim,
        use_quote_delim=self._use_quote_delim,
        na_value=self._na_value,
        select_cols=self._select_cols,
        exclude_cols=self._exclude_cols,
        compression_type=self._compression_type)
    super(CsvDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


@tf_export(v1=["data.experimental.CsvDataset"])
class CsvDatasetV1(dataset_ops.DatasetV1Adapter):
  """A Dataset comprising lines from one or more CSV files."""

  @functools.wraps(CsvDatasetV2.__init__, ("__module__", "__name__"))
  def __init__(self,
               filenames,
               record_defaults,
               compression_type=None,
               buffer_size=None,
               header=False,
               field_delim=",",
               use_quote_delim=True,
               na_value="",
               select_cols=None):
    """Creates a `CsvDataset` by reading and decoding CSV files.

    The elements of this dataset correspond to records from the file(s).
    RFC 4180 format is expected for CSV files
    (https://tools.ietf.org/html/rfc4180)
    Note that we allow leading and trailing spaces for int or float fields.

    For example, suppose we have a file 'my_file0.csv' with four CSV columns of
    different data types:
    ```
    abcdefg,4.28E10,5.55E6,12
    hijklmn,-5.3E14,,2
    ```

    We can construct a CsvDataset from it as follows:

    ```python
    dataset = tf.data.experimental.CsvDataset(
        "my_file*.csv",
        [tf.float32,  # Required field, use dtype or empty tensor
         tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
         tf.int32,  # Required field, use dtype or empty tensor
         ],
        select_cols=[1,2,3]  # Only parse last three columns
    )
    ```

    The expected output of its iterations is:

    ```python
    for element in dataset:
      print(element)

    >> (4.28e10, 5.55e6, 12)
    >> (-5.3e14, 0.0, 2)
    ```

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_columns` are specified, these must have the same
        lengths, and `column_defaults` is assumed to be sorted in order of
        increasing column index. If both this and 'exclude_cols' are specified,
        the sum of lengths of record_defaults and exclude_cols should equal the
        total number of columns in the CSV file.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats double
        quotation marks as regular characters inside of string fields (ignoring
        RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will be
        treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns. At most one of `select_cols`
        and `exclude_cols` can be specified.
    """
    wrapped = CsvDatasetV2(filenames, record_defaults, compression_type,
                           buffer_size, header, field_delim, use_quote_delim,
                           na_value, select_cols)
    super(CsvDatasetV1, self).__init__(wrapped)


@tf_export("data.experimental.make_batched_features_dataset", v1=[])
def make_batched_features_dataset_v2(file_pattern,
                                     batch_size,
                                     features,
                                     reader=None,
                                     label_key=None,
                                     reader_args=None,
                                     num_epochs=None,
                                     shuffle=True,
                                     shuffle_buffer_size=10000,
                                     shuffle_seed=None,
                                     prefetch_buffer_size=None,
                                     reader_num_threads=None,
                                     parser_num_threads=None,
                                     sloppy_ordering=False,
                                     drop_final_batch=False):
  """Returns a `Dataset` of feature dictionaries from `Example` protos.

  If the `label_key` argument is provided, returns a `Dataset` of tuples
  comprising feature dictionaries and labels.

  Example:

  ```
  serialized_examples = [
    features {
      feature { key: "age" value { int64_list { value: [ 0 ] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
    },
    features {
      feature { key: "age" value { int64_list { value: [] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
    }
  ]
  ```

  We can use arguments:

  ```
  features: {
    "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": FixedLenFeature([], dtype=tf.string),
    "kws": VarLenFeature(dtype=tf.string),
  }
  ```

  And the expected output is:

  ```python
  {
    "age": [[0], [-1]],
    "gender": [["f"], ["f"]],
    "kws": SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["code", "art", "sports"],
      dense_shape=[2, 2]),
  }
  ```

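  A sketch of typical usage (assuming TFRecord files of serialized
  `tf.train.Example` protos matching a hypothetical pattern
  `"/path/to/data-*.tfrecord"`):

  ```python
  # The file pattern is illustrative; `features` mirrors the example above.
  dataset = tf.data.experimental.make_batched_features_dataset(
      file_pattern="/path/to/data-*.tfrecord",
      batch_size=2,
      features={
          "age": tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
          "gender": tf.io.FixedLenFeature([], dtype=tf.string),
          "kws": tf.io.VarLenFeature(dtype=tf.string),
      },
      label_key="age",
      num_epochs=1)
  for features, labels in dataset:
    pass  # `features` is a dict of batched tensors; `labels` holds "age".
  ```
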
  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.io.parse_example`.
    reader: A function or class that can be
      called with a `filenames` tensor and (optional) `reader_args` and returns
      a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
    label_key: (Optional.) A string corresponding to the key under which labels
      are stored in the `tf.Example` protos. If provided, it must be one of the
      `features` keys, otherwise a `ValueError` is raised.
    reader_args: Additional arguments to pass to the reader class.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. Defaults to `None`.
    shuffle: A boolean, indicates whether the input should be shuffled. Defaults
      to `True`.
    shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
      ensures better shuffling but increases memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: Number of feature batches to prefetch in order to
      improve performance. Recommended value is the number of batches consumed
      per training step. Defaults to auto-tune.
    reader_num_threads: Number of threads used to read `Example` records. If >1,
      the results will be interleaved. Defaults to `1`.
    parser_num_threads: Number of threads to use for parsing `Example` tensors
      into a dictionary of `Feature` tensors. Defaults to `2`.
    sloppy_ordering: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`; note that if the seed is set, the order
      of elements after shuffling is deterministic). Defaults to `False`.
    drop_final_batch: If `True`, and the batch size does not evenly divide the
      input dataset size, the final smaller batch will be dropped. Defaults to
      `False`.

  Returns:
    A dataset of `dict` elements, (or a tuple of `dict` elements and label).
    Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.

  Raises:
    TypeError: If `reader` is of the wrong type.
    ValueError: If `label_key` is not one of the `features` keys.
  """
  if reader is None:
    reader = core_readers.TFRecordDataset

  if reader_num_threads is None:
    reader_num_threads = 1
  if parser_num_threads is None:
    parser_num_threads = 2
  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE

  # Create dataset of all matching filenames
  dataset = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  if isinstance(reader, type) and issubclass(reader, io_ops.ReaderBase):
    raise TypeError("The `reader` argument must return a `Dataset` object. "
                    "`tf.ReaderBase` subclasses are not supported. For "
                    "example, pass `tf.data.TFRecordDataset` instead of "
                    "`tf.TFRecordReader`.")

  # Read `Example` records from files as tensor objects.
  if reader_args is None:
    reader_args = []

  if reader_num_threads == dataset_ops.AUTOTUNE:
    dataset = dataset.interleave(
        lambda filename: reader(filename, *reader_args),
        num_parallel_calls=reader_num_threads)
    options = options_lib.Options()
    options.deterministic = not sloppy_ordering
    dataset = dataset.with_options(options)
  else:
    # Read files sequentially (if reader_num_threads=1) or in parallel
    def apply_fn(dataset):
      return core_readers.ParallelInterleaveDataset(
          dataset,
          lambda filename: reader(filename, *reader_args),
          cycle_length=reader_num_threads,
          block_length=1,
          sloppy=sloppy_ordering,
          buffer_output_elements=None,
          prefetch_input_elements=None)

    dataset = dataset.apply(apply_fn)

  # Extract values if the `Example` tensors are stored as key-value tuples.
  if dataset_ops.get_legacy_output_types(dataset) == (
      dtypes.string, dtypes.string):
    dataset = dataset_ops.MapDataset(
        dataset, lambda _, v: v, use_inter_op_parallelism=False)

  # Apply dataset repeat and shuffle transformations.
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

  # Parse `Example` tensors to a dictionary of `Feature` tensors.
  dataset = dataset.apply(
      parsing_ops.parse_example_dataset(
          features, num_parallel_calls=parser_num_threads))

  if label_key:
    if label_key not in features:
      raise ValueError(
          "The `label_key` provided (%r) must be one of the `features` keys." %
          label_key)
    dataset = dataset.map(lambda x: (x, x.pop(label_key)))

  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset


@tf_export(v1=["data.experimental.make_batched_features_dataset"])
def make_batched_features_dataset_v1(file_pattern,  # pylint: disable=missing-docstring
                                     batch_size,
                                     features,
                                     reader=None,
                                     label_key=None,
                                     reader_args=None,
                                     num_epochs=None,
                                     shuffle=True,
                                     shuffle_buffer_size=10000,
                                     shuffle_seed=None,
                                     prefetch_buffer_size=None,
                                     reader_num_threads=None,
                                     parser_num_threads=None,
                                     sloppy_ordering=False,
                                     drop_final_batch=False):
  return dataset_ops.DatasetV1Adapter(make_batched_features_dataset_v2(
      file_pattern, batch_size, features, reader, label_key, reader_args,
      num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
      prefetch_buffer_size, reader_num_threads, parser_num_threads,
      sloppy_ordering, drop_final_batch))
make_batched_features_dataset_v1.__doc__ = (
    make_batched_features_dataset_v2.__doc__)


def _get_file_names(file_pattern, shuffle):
  """Parse list of file names from pattern, optionally shuffled.

  Args:
    file_pattern: File glob pattern, or list of glob patterns.
    shuffle: Whether to shuffle the order of file names.

  Returns:
    List of file names matching `file_pattern`.

  Raises:
    ValueError: If `file_pattern` is empty, or pattern matches no files.
  """
  if isinstance(file_pattern, list):
    if not file_pattern:
      raise ValueError("File pattern is empty.")
    file_names = []
    for entry in file_pattern:
      file_names.extend(gfile.Glob(entry))
  else:
    file_names = list(gfile.Glob(file_pattern))

  if not file_names:
    raise ValueError("No files match %s." % file_pattern)

  # Sort files so it will be deterministic for unit tests.
  if not shuffle:
    file_names = sorted(file_names)
  return file_names


@tf_export("data.experimental.SqlDataset", v1=[])
class SqlDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` consisting of the results from a SQL query.

  `SqlDataset` allows a user to read data from the result set of a SQL query.
  For example:

  ```python
  dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                            "SELECT name, age FROM people",
                                            (tf.string, tf.int32))
  # Prints the rows of the result set of the above query.
  for element in dataset:
    print(element)
  ```
  """

  def __init__(self, driver_name, data_source_name, query, output_types):
    """Creates a `SqlDataset`.

    Args:
      driver_name: A 0-D `tf.string` tensor containing the database type.
        Currently, the only supported value is 'sqlite'.
      data_source_name: A 0-D `tf.string` tensor containing a connection string
        to connect to the database.
      query: A 0-D `tf.string` tensor containing the SQL query to execute.
      output_types: A tuple of `tf.DType` objects representing the types of the
        columns returned by `query`.
    """
    self._driver_name = ops.convert_to_tensor(
        driver_name, dtype=dtypes.string, name="driver_name")
    self._data_source_name = ops.convert_to_tensor(
        data_source_name, dtype=dtypes.string, name="data_source_name")
    self._query = ops.convert_to_tensor(
        query, dtype=dtypes.string, name="query")
    self._element_spec = nest.map_structure(
        lambda dtype: tensor_spec.TensorSpec([], dtype), output_types)
    variant_tensor = gen_experimental_dataset_ops.sql_dataset(
        self._driver_name, self._data_source_name, self._query,
        **self._flat_structure)
    super(SqlDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


@tf_export(v1=["data.experimental.SqlDataset"])
class SqlDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` consisting of the results from a SQL query."""

  @functools.wraps(SqlDatasetV2.__init__)
  def __init__(self, driver_name, data_source_name, query, output_types):
    wrapped = SqlDatasetV2(driver_name, data_source_name, query, output_types)
    super(SqlDatasetV1, self).__init__(wrapped)


if tf2.enabled():
  CsvDataset = CsvDatasetV2
  SqlDataset = SqlDatasetV2
  make_batched_features_dataset = make_batched_features_dataset_v2
  make_csv_dataset = make_csv_dataset_v2
else:
  CsvDataset = CsvDatasetV1
  SqlDataset = SqlDatasetV1
  make_batched_features_dataset = make_batched_features_dataset_v1
  make_csv_dataset = make_csv_dataset_v1