# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset snapshot and related functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
COMPRESSION_NONE = None


class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A Dataset that captures a snapshot or reads from a snapshot."""

  def __init__(self,
               input_dataset,
               path,
               compression=None,
               reader_path_prefix=None,
               writer_path_prefix=None,
               shard_size_bytes=None,
               pending_snapshot_expiry_seconds=None,
               num_reader_threads=None,
               reader_buffer_size=None,
               num_writer_threads=None,
               writer_buffer_size=None,
               shuffle_on_read=None,
               shuffle_seed=None,
               mode=None,
               snapshot_name=None):

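    # Arguments left as `None` are normalized to the sentinel values expected
    # by the underlying snapshot kernel op: "" for strings, -1 for sizes and
    # counts, False for `shuffle_on_read`, and "auto" for `mode`.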
    self._compression = compression if compression is not None else ""
    self._reader_path_prefix = (
        reader_path_prefix if reader_path_prefix is not None else "")
    self._writer_path_prefix = (
        writer_path_prefix if writer_path_prefix is not None else "")
    self._shard_size_bytes = (
        shard_size_bytes if shard_size_bytes is not None else -1)
    self._pending_snapshot_expiry_seconds = (
        pending_snapshot_expiry_seconds
        if pending_snapshot_expiry_seconds is not None else -1)
    self._num_reader_threads = (
        num_reader_threads if num_reader_threads is not None else -1)
    self._reader_buffer_size = (
        reader_buffer_size if reader_buffer_size is not None else -1)
    self._num_writer_threads = (
        num_writer_threads if num_writer_threads is not None else -1)
    self._writer_buffer_size = (
        writer_buffer_size if writer_buffer_size is not None else -1)
    self._shuffle_on_read = (
        shuffle_on_read if shuffle_on_read is not None else False)
    self._mode = (mode if mode is not None else "auto")
    self._snapshot_name = (snapshot_name if snapshot_name is not None else "")

    self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)

    self._input_dataset = input_dataset
    self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")

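    # Build the dataset's variant tensor with the experimental SnapshotDataset
    # op, forwarding the normalized options above.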
    variant_tensor = ged_ops.snapshot_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        path=self._path,
        compression=self._compression,
        reader_path_prefix=self._reader_path_prefix,
        writer_path_prefix=self._writer_path_prefix,
        shard_size_bytes=self._shard_size_bytes,
        pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
        num_reader_threads=self._num_reader_threads,
        reader_buffer_size=self._reader_buffer_size,
        num_writer_threads=self._num_writer_threads,
        writer_buffer_size=self._writer_buffer_size,
        shuffle_on_read=self._shuffle_on_read,
        seed=self._seed,
        seed2=self._seed2,
        mode=self._mode,
        snapshot_name=self._snapshot_name,
        **self._flat_structure)

    super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)


@deprecation.deprecated(
    None, "Use `tf.data.experimental.snapshot(...)` instead.")
def legacy_snapshot(path,
                    compression=None,
                    reader_path_prefix=None,
                    writer_path_prefix=None,
                    shard_size_bytes=None,
                    pending_snapshot_expiry_seconds=None,
                    num_reader_threads=None,
                    reader_buffer_size=None,
                    num_writer_threads=None,
                    writer_buffer_size=None,
                    shuffle_on_read=None,
                    shuffle_seed=None,
                    mode=None,
                    snapshot_name=None):
  """Writes to/reads from a snapshot of a dataset.

  This function attempts to determine whether a valid snapshot exists at the
  `path`, and reads from the snapshot if so. If not, it runs the preprocessing
  pipeline as usual and writes out a snapshot of the processed data for future
  use.

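  For example, a pipeline might apply this transformation as follows (a minimal
  sketch; `snapshot_ops` stands in for however this module is imported, and the
  paths are placeholders):

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/data")
  dataset = dataset.apply(
      snapshot_ops.legacy_snapshot("/path/to/snapshot/dir", compression="GZIP"))
  ```
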
  Args:
    path: A directory where we want to save our snapshots and/or read from a
      previously saved snapshot.
    compression: The type of compression to apply to the Dataset. Currently
      supports "GZIP" or None. Defaults to None (no compression).
    reader_path_prefix: A prefix to add to the path when reading from snapshots.
      Defaults to None.
    writer_path_prefix: A prefix to add to the path when writing to snapshots.
      Defaults to None.
    shard_size_bytes: The size of each shard to be written by the snapshot
      dataset op. Defaults to 10 GiB.
    pending_snapshot_expiry_seconds: How long to wait (in seconds) before the
      snapshot op considers a previously unfinished snapshot to be stale.
    num_reader_threads: Number of threads used to parallelize reading from the
      snapshot. Especially useful if compression is turned on, since the
      decompression operation tends to be intensive. Defaults to 1. If > 1,
      this might introduce non-determinism, i.e. the order in which elements
      are read from the snapshot may differ from the order in which they were
      written.
    reader_buffer_size: Maximum number of elements we can prefetch while
      reading from the snapshot. Defaults to 1. Increasing this might improve
      performance but will increase memory consumption.
    num_writer_threads: Number of threads used to parallelize writing to the
      snapshot. We'll open up `num_writer_threads` files and write to them in
      parallel. Especially useful if compression is turned on, since the
      compression operation tends to be intensive. Defaults to 1. If > 1, this
      might introduce non-determinism, i.e. the order in which elements are
      written may differ from the order in which they are read from the
      upstream iterator.
    writer_buffer_size: Maximum number of pipeline elements to buffer before
      writing them out using `num_writer_threads` threads.
    shuffle_on_read: If this is True, then the order in which examples are
      produced when reading from a snapshot will be random. Defaults to False.
    shuffle_seed: Optional. If shuffle_seed is set, the random number generator
      used for shuffling (when shuffle_on_read is turned on) is seeded by the
      given seed. Otherwise, it is seeded by a random seed that differs for
      every run.
    mode: The mode in which snapshot should operate. Valid options are "auto",
      "read", "write", and "passthrough". The default mode is "auto", where the
      snapshot op will automatically determine what mode to operate in.
    snapshot_name: If set, use the supplied string as the snapshot name instead
      of introspecting the data pipeline and automatically generating a unique
      identifier for the snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _LegacySnapshotDataset(
        input_dataset=dataset,
        path=path,
        compression=compression,
        reader_path_prefix=reader_path_prefix,
        writer_path_prefix=writer_path_prefix,
        shard_size_bytes=shard_size_bytes,
        pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
        num_reader_threads=num_reader_threads,
        reader_buffer_size=reader_buffer_size,
        num_writer_threads=num_writer_threads,
        writer_buffer_size=writer_buffer_size,
        shuffle_on_read=shuffle_on_read,
        shuffle_seed=shuffle_seed,
        mode=mode,
        snapshot_name=snapshot_name)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.snapshot(...)`.")
@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
  """API to persist the output of the input dataset.

  The snapshot API allows users to transparently persist the output of their
  preprocessing pipeline to disk, and materialize the pre-processed data on a
  different training run.

  This API enables repeated preprocessing steps to be consolidated, and allows
  re-use of already processed data, trading off disk storage and network
  bandwidth for freeing up more valuable CPU resources and accelerator compute
  time.

  https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
  has detailed design documentation of this feature.

  Users can specify various options to control the behavior of snapshot,
  including how snapshots are read from and written to by passing in
  user-defined functions to the `reader_func` and `shard_func` parameters.

  `shard_func` is a user-specified function that maps input elements to
  snapshot shards.

  Users may want to specify this function to control how snapshot files should
  be written to disk. Below is an example of how a potential `shard_func` could
  be written.

  ```python
  dataset = ...
  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.snapshot("/path/to/snapshot/dir",
      shard_func=lambda x, y: x % NUM_SHARDS, ...))
  dataset = dataset.map(lambda x, y: y)
  ```

  `reader_func` is a user-specified function that accepts a single argument: a
  Dataset of Datasets, each representing a "split" of elements of the original
  dataset. The cardinality of the input dataset matches the number of shards
  specified by `shard_func` (see above). The function should return a Dataset
  of elements of the original dataset.

  Users may want to specify this function to control how snapshot files should
  be read from disk, including the amount of shuffling and parallelism.

  Here is an example of a standard reader function a user can define. This
  function enables both dataset shuffling and parallel reading of datasets:

  ```python
  def user_reader_func(datasets):
    # shuffle the dataset splits
    datasets = datasets.shuffle(NUM_CORES)
    # read datasets in parallel and interleave their elements
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = dataset.apply(tf.data.experimental.snapshot("/path/to/snapshot/dir",
      reader_func=user_reader_func))
  ```

  By default, snapshot parallelizes reads by the number of cores available on
  the system, but will not attempt to shuffle the data.

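  A minimal end-to-end sketch using the defaults (illustrative only;
  `preprocess_fn` and the paths are placeholders):

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/data")
  dataset = dataset.map(preprocess_fn)
  dataset = dataset.apply(
      tf.data.experimental.snapshot("/path/to/snapshot/dir"))
  ```
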
  Args:
    path: Required. A directory to use for storing / loading the snapshot to /
      from.
    compression: Optional. The type of compression to apply to the snapshot
      written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO`, or None.
      Defaults to `AUTO`, which attempts to pick an appropriate compression
      algorithm for the dataset.
    reader_func: Optional. A function to control how to read data from snapshot
      shards.
    shard_func: Optional. A function to control how to shard data when writing
      a snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Actual dataset transformation."""
    return dataset.snapshot(
        path=path,
        compression=compression,
        reader_func=reader_func,
        shard_func=shard_func)

  return _apply_fn