1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python API for save and loading a dataset.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import multiprocessing 22import os 23 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import gen_experimental_dataset_ops 28from tensorflow.python.platform import gfile 29from tensorflow.python.util import lazy_loader 30from tensorflow.python.util.tf_export import tf_export 31 32COMPRESSION_GZIP = "GZIP" 33COMPRESSION_SNAPPY = "NONE" 34DATASET_SPEC_FILENAME = "dataset_spec.pb" 35# TODO(b/176933539): Use the regular import. 36nested_structure_coder = lazy_loader.LazyLoader( 37 "nested_structure_coder", globals(), 38 "tensorflow.python.saved_model.nested_structure_coder") 39 40 41@tf_export("data.experimental.save", v1=[]) 42def save(dataset, path, compression=None, shard_func=None): 43 """Saves the content of the given dataset. 44 45 Example usage: 46 47 >>> import tempfile 48 >>> path = os.path.join(tempfile.gettempdir(), "saved_data") 49 >>> # Save a dataset 50 >>> dataset = tf.data.Dataset.range(2) 51 >>> tf.data.experimental.save(dataset, path) 52 >>> new_dataset = tf.data.experimental.load(path, 53 ... tf.TensorSpec(shape=(), dtype=tf.int64)) 54 >>> for elem in new_dataset: 55 ... print(elem) 56 tf.Tensor(0, shape=(), dtype=int64) 57 tf.Tensor(1, shape=(), dtype=int64) 58 59 The saved dataset is saved in multiple file "shards". By default, the dataset 60 output is divided to shards in a round-robin fashion but custom sharding can 61 be specified via the `shard_func` function. For example, you can save the 62 dataset to using a single shard as follows: 63 64 ```python 65 dataset = make_dataset() 66 def custom_shard_func(element): 67 return 0 68 dataset = tf.data.experimental.save( 69 path="/path/to/data", ..., shard_func=custom_shard_func) 70 ``` 71 72 NOTE: The directory layout and file format used for saving the dataset is 73 considered an implementation detail and may change. For this reason, datasets 74 saved through `tf.data.experimental.save` should only be consumed through 75 `tf.data.experimental.load`, which is guaranteed to be backwards compatible. 76 77 Args: 78 dataset: The dataset to save. 79 path: Required. A directory to use for saving the dataset. 80 compression: Optional. The algorithm to use to compress data when writing 81 it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`. 82 shard_func: Optional. A function to control the mapping of dataset elements 83 to file shards. The function is expected to map elements of the input 84 dataset to int64 shard IDs. If present, the function will be traced and 85 executed as graph computation. 86 """ 87 88 if shard_func is None: 89 use_shard_func = False 90 shard_func = lambda *x: None # a dummy function that will not be used 91 else: 92 use_shard_func = True 93 94 wrapped_func = dataset_ops.StructuredFunctionWrapper( 95 shard_func, 96 "save()", 97 input_structure=dataset.element_spec, 98 add_to_graph=False) 99 100 coder = nested_structure_coder.StructureCoder() 101 encoded = coder.encode_structure(dataset.element_spec) 102 gfile.MakeDirs(path) 103 with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f: 104 f.write(encoded.SerializeToString()) 105 106 path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path") 107 shard_func = wrapped_func.function 108 shard_func.add_to_graph(ops.get_default_graph()) 109 110 # pylint: disable=protected-access 111 dataset = dataset._apply_options() 112 gen_experimental_dataset_ops.save_dataset( 113 dataset._variant_tensor, 114 path=path, 115 shard_func_other_args=shard_func.captured_inputs, 116 compression=compression, 117 shard_func=shard_func, 118 use_shard_func=use_shard_func) 119 120 121class _LoadDataset(dataset_ops.DatasetSource): 122 """A dataset that loads previously saved dataset.""" 123 124 def __init__(self, path, element_spec=None, compression=None, 125 reader_func=None): 126 127 if reader_func is None: 128 reader_func = lambda datasets: datasets.interleave( # pylint:disable=g-long-lambda 129 lambda x: x, 130 cycle_length=multiprocessing.cpu_count(), 131 num_parallel_calls=dataset_ops.AUTOTUNE) 132 133 self._path = path 134 if element_spec is None: 135 with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f: 136 encoded_spec = f.read() 137 struct_pb = nested_structure_coder.struct_pb2.StructuredValue() 138 struct_pb.ParseFromString(encoded_spec) 139 coder = nested_structure_coder.StructureCoder() 140 spec = coder.decode_proto(struct_pb) 141 self._element_spec = spec 142 else: 143 self._element_spec = element_spec 144 self._compression = compression 145 self._reader_func = dataset_ops.StructuredFunctionWrapper( 146 reader_func, 147 "load()", 148 # Dataset of datasets of input elements 149 input_structure=dataset_ops.DatasetSpec( 150 dataset_ops.DatasetSpec(self._element_spec))) 151 152 variant_tensor = gen_experimental_dataset_ops.load_dataset( 153 path, 154 reader_func_other_args=self._reader_func.function.captured_inputs, 155 compression=compression, 156 reader_func=self._reader_func.function, 157 **self._flat_structure) 158 super(_LoadDataset, self).__init__(variant_tensor) 159 160 def _functions(self): 161 return [self._reader_func] 162 163 @property 164 def element_spec(self): 165 return self._element_spec 166 167 168@tf_export("data.experimental.load", v1=[]) 169def load(path, element_spec=None, compression=None, reader_func=None): 170 """Loads a previously saved dataset. 171 172 Example usage: 173 174 >>> import tempfile 175 >>> path = os.path.join(tempfile.gettempdir(), "saved_data") 176 >>> # Save a dataset 177 >>> dataset = tf.data.Dataset.range(2) 178 >>> tf.data.experimental.save(dataset, path) 179 >>> new_dataset = tf.data.experimental.load(path) 180 >>> for elem in new_dataset: 181 ... print(elem) 182 tf.Tensor(0, shape=(), dtype=int64) 183 tf.Tensor(1, shape=(), dtype=int64) 184 185 186 Note that to load a previously saved dataset, you need to specify 187 `element_spec` -- a type signature of the elements of the saved dataset, which 188 can be obtained via `tf.data.Dataset.element_spec`. This requirement exists so 189 that shape inference of the loaded dataset does not need to perform I/O. 190 191 If the default option of sharding the saved dataset was used, the element 192 order of the saved dataset will be preserved when loading it. 193 194 The `reader_func` argument can be used to specify a custom order in which 195 elements should be loaded from the individual shards. The `reader_func` is 196 expected to take a single argument -- a dataset of datasets, each containing 197 elements of one of the shards -- and return a dataset of elements. For 198 example, the order of shards can be shuffled when loading them as follows: 199 200 ```python 201 def custom_reader_func(datasets): 202 datasets = datasets.shuffle(NUM_SHARDS) 203 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE) 204 205 dataset = tf.data.experimental.load( 206 path="/path/to/data", ..., reader_func=custom_reader_func) 207 ``` 208 209 Args: 210 path: Required. A path pointing to a previously saved dataset. 211 element_spec: Optional. A nested structure of `tf.TypeSpec` objects matching 212 the structure of an element of the saved dataset and specifying the type 213 of individual element components. If not provided, the nested structure of 214 `tf.TypeSpec` saved with the saved dataset is used. 215 compression: Optional. The algorithm to use to decompress the data when 216 reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`. 217 reader_func: Optional. A function to control how to read data from shards. 218 If present, the function will be traced and executed as graph computation. 219 220 Returns: 221 A `tf.data.Dataset` instance. 222 223 Raises: 224 FileNotFoundError: If `element_spec` is not specified and the saved nested 225 structure of `tf.TypeSpec` can not be located with the saved dataset. 226 """ 227 228 return _LoadDataset( 229 path=path, 230 element_spec=element_spec, 231 compression=compression, 232 reader_func=reader_func) 233