# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset snapshot and related functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
COMPRESSION_NONE = None


class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A Dataset that captures a snapshot or reads from a snapshot."""

  def __init__(self,
               input_dataset,
               path,
               compression=None,
               reader_path_prefix=None,
               writer_path_prefix=None,
               shard_size_bytes=None,
               pending_snapshot_expiry_seconds=None,
               num_reader_threads=None,
               reader_buffer_size=None,
               num_writer_threads=None,
               writer_buffer_size=None,
               shuffle_on_read=None,
               shuffle_seed=None,
               mode=None,
               snapshot_name=None):

    self._compression = compression if compression is not None else ""
    self._reader_path_prefix = (
        reader_path_prefix if reader_path_prefix is not None else "")
    self._writer_path_prefix = (
        writer_path_prefix if writer_path_prefix is not None else "")
    self._shard_size_bytes = (
        shard_size_bytes if shard_size_bytes is not None else -1)
    self._pending_snapshot_expiry_seconds = (
        pending_snapshot_expiry_seconds
        if pending_snapshot_expiry_seconds is not None else -1)
    self._num_reader_threads = (
        num_reader_threads if num_reader_threads is not None else -1)
    self._reader_buffer_size = (
        reader_buffer_size if reader_buffer_size is not None else -1)
    self._num_writer_threads = (
        num_writer_threads if num_writer_threads is not None else -1)
    self._writer_buffer_size = (
        writer_buffer_size if writer_buffer_size is not None else -1)
    self._shuffle_on_read = (
        shuffle_on_read if shuffle_on_read is not None else False)
    self._mode = (mode if mode is not None else "auto")
    self._snapshot_name = (snapshot_name if snapshot_name is not None else "")

    self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)

    self._input_dataset = input_dataset
    self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")

    variant_tensor = ged_ops.snapshot_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        path=self._path,
        compression=self._compression,
        reader_path_prefix=self._reader_path_prefix,
        writer_path_prefix=self._writer_path_prefix,
        shard_size_bytes=self._shard_size_bytes,
        pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
        num_reader_threads=self._num_reader_threads,
        reader_buffer_size=self._reader_buffer_size,
        num_writer_threads=self._num_writer_threads,
        writer_buffer_size=self._writer_buffer_size,
        shuffle_on_read=self._shuffle_on_read,
        seed=self._seed,
        seed2=self._seed2,
        mode=self._mode,
        snapshot_name=self._snapshot_name,
        **self._flat_structure)

    super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)


@deprecation.deprecated(
    None, "Use `tf.data.experimental.snapshot(...)` instead.")
def legacy_snapshot(path,
                    compression=None,
                    reader_path_prefix=None,
                    writer_path_prefix=None,
                    shard_size_bytes=None,
                    pending_snapshot_expiry_seconds=None,
                    num_reader_threads=None,
                    reader_buffer_size=None,
                    num_writer_threads=None,
                    writer_buffer_size=None,
                    shuffle_on_read=None,
                    shuffle_seed=None,
                    mode=None,
                    snapshot_name=None):
  """Writes to/reads from a snapshot of a dataset.

  This function attempts to determine whether a valid snapshot exists at the
  `path`, and reads from the snapshot if so. If not, it will run the
  preprocessing pipeline as usual, and write out a snapshot of the data
  processed for future use.

  Args:
    path: A directory where we want to save our snapshots and/or read from a
      previously saved snapshot.
    compression: The type of compression to apply to the Dataset. Currently
      supports "GZIP" or None. Defaults to None (no compression).
    reader_path_prefix: A prefix to add to the path when reading from
      snapshots. Defaults to None.
    writer_path_prefix: A prefix to add to the path when writing to snapshots.
      Defaults to None.
    shard_size_bytes: The size of each shard to be written by the snapshot
      dataset op. Defaults to 10 GiB.
    pending_snapshot_expiry_seconds: How long to wait (in seconds) before the
      snapshot op considers a previously unfinished snapshot to be stale.
    num_reader_threads: Number of threads used to parallelize reading from the
      snapshot. Especially useful if compression is turned on, since the
      decompression operation tends to be intensive. Defaults to 1. If > 1,
      this might introduce non-determinism, i.e. the order in which elements
      are read from the snapshot may differ from the order in which they were
      written.
    reader_buffer_size: Maximum number of elements we can prefetch when
      reading from the snapshot. Defaults to 1. Increasing this might improve
      performance but will increase memory consumption.
    num_writer_threads: Number of threads used to parallelize writing the
      snapshot. We'll open up `num_writer_threads` files and write to them in
      parallel. Especially useful if compression is turned on, since the
      compression operation tends to be intensive. Defaults to 1. If > 1, this
      might introduce non-determinism, i.e. the order in which elements are
      read from the upstream iterator may differ from the order in which they
      are written.
    writer_buffer_size: Maximum number of pipeline elements to fill up the
      buffer before writing them out using `num_writer_threads`.
    shuffle_on_read: If this is True, then the order in which examples are
      produced when reading from a snapshot will be random. Defaults to False.
    shuffle_seed: Optional. If shuffle_seed is set, the random number
      generator used for shuffling (when shuffle_on_read is turned on) is
      seeded by the given seed.
      Otherwise, it is seeded by a random seed that differs for every run.
    mode: The mode in which snapshot should operate. Valid options are "auto",
      "read", "write", and "passthrough". The default mode is "auto", where
      the snapshot op will automatically determine what mode to operate in.
    snapshot_name: If set, use the supplied string as a named snapshot name
      instead of introspecting the data pipeline and automatically generating
      a unique identifier for the snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _LegacySnapshotDataset(
        input_dataset=dataset,
        path=path,
        compression=compression,
        reader_path_prefix=reader_path_prefix,
        writer_path_prefix=writer_path_prefix,
        shard_size_bytes=shard_size_bytes,
        pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
        num_reader_threads=num_reader_threads,
        reader_buffer_size=reader_buffer_size,
        num_writer_threads=num_writer_threads,
        writer_buffer_size=writer_buffer_size,
        shuffle_on_read=shuffle_on_read,
        shuffle_seed=shuffle_seed,
        mode=mode,
        snapshot_name=snapshot_name)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.snapshot(...)`.")
@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
  """API to persist the output of the input dataset.

  The snapshot API allows users to transparently persist the output of their
  preprocessing pipeline to disk, and materialize the pre-processed data on a
  different training run.

  This API enables repeated preprocessing steps to be consolidated, and allows
  re-use of already processed data, trading off disk storage and network
  bandwidth for freeing up more valuable CPU resources and accelerator compute
  time.

  https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
  has detailed design documentation of this feature.

  Users can specify various options to control the behavior of snapshot,
  including how snapshots are read from and written to, by passing in
  user-defined functions to the `reader_func` and `shard_func` parameters.

  `shard_func` is a user-specified function that maps input elements to
  snapshot shards.

  Users may want to specify this function to control how snapshot files should
  be written to disk. Below is an example of how a potential `shard_func`
  could be written.

  ```python
  dataset = ...
  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.snapshot(
      "/path/to/snapshot/dir", shard_func=lambda x, y: x % NUM_SHARDS, ...))
  dataset = dataset.map(lambda x, y: y)
  ```

  `reader_func` is a user-specified function that accepts a single argument:
  a Dataset of Datasets, each representing a "split" of elements of the
  original dataset. The cardinality of the input dataset matches the number
  of shards specified in the `shard_func` (see above). The function should
  return a Dataset of elements of the original dataset.

  Users may want to specify this function to control how snapshot files
  should be read from disk, including the amount of shuffling and parallelism.

  Here is an example of a standard reader function a user can define.
  This function enables both dataset shuffling and parallel reading of
  datasets:

  ```python
  def user_reader_func(datasets):
    # shuffle the datasets splits
    datasets = datasets.shuffle(NUM_CORES)
    # read datasets in parallel and interleave their elements
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = dataset.apply(tf.data.experimental.snapshot(
      "/path/to/snapshot/dir", reader_func=user_reader_func))
  ```

  By default, snapshot parallelizes reads by the number of cores available on
  the system, but will not attempt to shuffle the data.

  Args:
    path: Required. A directory to use for storing / loading the snapshot to /
      from.
    compression: Optional. The type of compression to apply to the snapshot
      written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
      Defaults to `AUTO`, which attempts to pick an appropriate compression
      algorithm for the dataset.
    reader_func: Optional. A function to control how to read data from
      snapshot shards.
    shard_func: Optional. A function to control how to shard data when
      writing a snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Actual dataset transformation."""
    return dataset.snapshot(
        path=path,
        compression=compression,
        reader_func=reader_func,
        shard_func=shard_func)

  return _apply_fn
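# Illustrative usage (a minimal sketch, not part of this module's API surface):
# it shows how the `snapshot` transformation defined above is typically applied
# around an expensive preprocessing step. `preprocess_fn`, the input file, and
# the snapshot directory below are hypothetical placeholders.
#
#   import tensorflow as tf
#
#   dataset = tf.data.TFRecordDataset("/path/to/input.tfrecord")
#   dataset = dataset.map(preprocess_fn)  # expensive work worth persisting
#   dataset = dataset.apply(
#       tf.data.experimental.snapshot("/path/to/snapshot/dir",
#                                     compression="AUTO"))
#   dataset = dataset.shuffle(1024).batch(32).prefetch(
#       tf.data.experimental.AUTOTUNE)
#
# On the first run the pipeline executes as usual and the snapshot is written;
# on subsequent runs the pre-processed elements are read back from
# "/path/to/snapshot/dir" instead of being recomputed.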