# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Python API for save and loading a dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import os

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "NONE"
DATASET_SPEC_FILENAME = "dataset_spec.pb"
# TODO(b/176933539): Use the regular import.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


@tf_export("data.experimental.save", v1=[])
def save(dataset, path, compression=None, shard_func=None):
  """Saves the content of the given dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path,
  ...     tf.TensorSpec(shape=(), dtype=tf.int64))
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The saved dataset is stored in multiple file "shards". By default, the
  dataset output is divided into shards in a round-robin fashion, but custom
  sharding can be specified via the `shard_func` argument. For example, you
  can save the dataset using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return 0
  tf.data.experimental.save(
      dataset, path="/path/to/data", ..., shard_func=custom_shard_func)
  ```
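
  The `compression` argument can be used to compress the shards as they are
  written. For example (a minimal sketch; the same `compression` value should
  then be passed to `tf.data.experimental.load`):

  ```python
  tf.data.experimental.save(dataset, path, compression="GZIP")
  loaded = tf.data.experimental.load(
      path, element_spec=dataset.element_spec, compression="GZIP")
  ```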

  NOTE: The directory layout and file format used for saving the dataset are
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
  """

  if shard_func is None:
    use_shard_func = False
    shard_func = lambda *x: None  # a dummy function that will not be used
  else:
    use_shard_func = True

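  # Trace `shard_func` (or the placeholder) over the dataset's element
  # structure; `add_to_graph=False` defers graph registration, which is done
  # explicitly below after the element spec has been written.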
  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      shard_func,
      "save()",
      input_structure=dataset.element_spec,
      add_to_graph=False)

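  # Serialize the dataset's element spec alongside the data so that load() can
  # reconstruct it when the caller does not provide an `element_spec`.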
  coder = nested_structure_coder.StructureCoder()
  encoded = coder.encode_structure(dataset.element_spec)
  gfile.MakeDirs(path)
  with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
    f.write(encoded.SerializeToString())

  path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
  shard_func = wrapped_func.function
  shard_func.add_to_graph(ops.get_default_graph())

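  # Apply the dataset's options and write it out via the SaveDataset op; when
  # `use_shard_func` is set, elements are routed to shards by `shard_func`.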
  # pylint: disable=protected-access
  dataset = dataset._apply_options()
  gen_experimental_dataset_ops.save_dataset(
      dataset._variant_tensor,
      path=path,
      shard_func_other_args=shard_func.captured_inputs,
      compression=compression,
      shard_func=shard_func,
      use_shard_func=use_shard_func)


class _LoadDataset(dataset_ops.DatasetSource):
  """A dataset that loads a previously saved dataset."""

  def __init__(self, path, element_spec=None, compression=None,
               reader_func=None):

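    # By default, interleave elements from all shards, reading up to
    # cpu_count() shards in parallel.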
    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._path = path
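    # If no `element_spec` was provided, recover it from the proto that save()
    # wrote next to the data.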
    if element_spec is None:
      with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f:
        encoded_spec = f.read()
      struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
      struct_pb.ParseFromString(encoded_spec)
      coder = nested_structure_coder.StructureCoder()
      spec = coder.decode_proto(struct_pb)
      self._element_spec = spec
    else:
      self._element_spec = element_spec
    self._compression = compression
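    # `reader_func` is traced over a dataset whose elements are themselves
    # datasets, one per shard.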
    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        "load()",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(self._element_spec)))

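    # The LoadDataset op reads the shards under `path` and uses the traced
    # `reader_func` to combine them into a single dataset.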
    variant_tensor = gen_experimental_dataset_ops.load_dataset(
        path,
        reader_func_other_args=self._reader_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        **self._flat_structure)
    super(_LoadDataset, self).__init__(variant_tensor)

  def _functions(self):
    return [self._reader_func]

  @property
  def element_spec(self):
    return self._element_spec


@tf_export("data.experimental.load", v1=[])
def load(path, element_spec=None, compression=None, reader_func=None):
  """Loads a previously saved dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  Note that the `element_spec` argument -- a type signature of the elements of
  the saved dataset, which can be obtained via `tf.data.Dataset.element_spec`
  -- is optional: if it is omitted, it is read from a file written alongside
  the saved data, at the cost of performing I/O during shape inference.

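  For example (a minimal sketch, assuming the `range(2)` dataset saved above),
  the spec can be provided explicitly as follows:

  ```python
  loaded = tf.data.experimental.load(
      path, element_spec=tf.TensorSpec(shape=(), dtype=tf.int64))
  ```
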
  If the default option of sharding the saved dataset was used, the element
  order of the saved dataset will be preserved when loading it.

  The `reader_func` argument can be used to specify a custom order in which
  elements should be loaded from the individual shards. The `reader_func` is
  expected to take a single argument -- a dataset of datasets, each containing
  elements of one of the shards -- and return a dataset of elements. For
  example, the order of shards can be shuffled when loading them as follows:

  ```python
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(NUM_SHARDS)
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = tf.data.experimental.load(
      path="/path/to/data", ..., reader_func=custom_reader_func)
  ```

  Args:
    path: Required. A path pointing to a previously saved dataset.
    element_spec: Optional. A nested structure of `tf.TypeSpec` objects matching
      the structure of an element of the saved dataset and specifying the type
      of individual element components. If not provided, the nested structure of
      `tf.TypeSpec` saved with the saved dataset is used.
    compression: Optional. The algorithm to use to decompress the data when
      reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    reader_func: Optional. A function to control how to read data from shards.
      If present, the function will be traced and executed as graph computation.

  Returns:
    A `tf.data.Dataset` instance.

  Raises:
    FileNotFoundError: If `element_spec` is not specified and the saved nested
      structure of `tf.TypeSpec` cannot be located alongside the saved dataset.
  """

  return _LoadDataset(
      path=path,
      element_spec=element_spec,
      compression=compression,
      reader_func=reader_func)