# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for tf.data writers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.TFRecordWriter")
@deprecation.deprecated(
    None, "To write TFRecords to disk, use `tf.io.TFRecordWriter`. To save "
    "and load the contents of a dataset, use `tf.data.experimental.save` "
    "and `tf.data.experimental.load`")
class TFRecordWriter(object):
  """Writes a dataset to a TFRecord file.

  The elements of the dataset must be scalar strings. To serialize dataset
  elements as strings, you can use the `tf.io.serialize_tensor` function.

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.map(tf.io.serialize_tensor)
  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
  writer.write(dataset)
  ```

  To read back the elements, use `TFRecordDataset`.

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
  ```

  To shard a `dataset` across multiple TFRecord files:

  ```python
  dataset = ... # dataset to be written

  def reduce_func(key, dataset):
    filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(dataset.map(lambda _, x: x))
    return tf.data.Dataset.from_tensors(filename)

  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.group_by_window(
    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
  ))

  # Iterate through the dataset to trigger data writing.
  for _ in dataset:
    pass
  ```
  """

  def __init__(self, filename, compression_type=None):
    """Initializes a `TFRecordWriter`.

    Args:
      filename: a string path indicating where to write the TFRecord data.
      compression_type: (Optional.) a string indicating what type of compression
        to use when writing the file. See `tf.io.TFRecordCompressionType` for
        what types of compression are available. Defaults to `None`.
    """
    # Convert arguments to tensors eagerly so that an invalid filename or
    # compression type fails at construction time rather than inside `write`.
    self._filename = ops.convert_to_tensor(
        filename, dtypes.string, name="filename")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)

  def write(self, dataset):
    """Writes a dataset to a TFRecord file.

    An operation that writes the content of the specified dataset to the file
    specified in the constructor.

    If the file exists, it will be overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file

    Returns:
      In graph mode, this returns an operation which when executed performs the
      write. In eager mode, the write is performed by the method itself and
      there is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
    """
    if not isinstance(dataset, dataset_ops.DatasetV2):
      raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
    # The writer op only accepts rank-0 DT_STRING elements; reject anything
    # else up front with a clear error rather than failing inside the kernel.
    if not dataset_ops.get_structure(dataset).is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.string)):
      raise TypeError(
          "`dataset` must produce scalar `DT_STRING` tensors whereas it "
          "produces shape {0} and types {1}".format(
              dataset_ops.get_legacy_output_shapes(dataset),
              dataset_ops.get_legacy_output_types(dataset)))
    # pylint: disable=protected-access
    dataset = dataset._apply_debug_options()
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)