# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for tf.data writers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.TFRecordWriter")
@deprecation.deprecated(
    None, "To write TFRecords to disk, use `tf.io.TFRecordWriter`. To save "
    "and load the contents of a dataset, use `tf.data.experimental.save` "
    "and `tf.data.experimental.load`")
class TFRecordWriter(object):
  """Writes a dataset to a TFRecord file.

  The elements of the dataset must be scalar strings. To serialize dataset
  elements as strings, you can use the `tf.io.serialize_tensor` function.

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.map(tf.io.serialize_tensor)
  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
  writer.write(dataset)
  ```

  To read back the elements, use `TFRecordDataset`.

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
  ```

  To shard a `dataset` across multiple TFRecord files:

  ```python
  dataset = ...  # dataset to be written

  def reduce_func(key, dataset):
    filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(dataset.map(lambda _, x: x))
    return tf.data.Dataset.from_tensors(filename)

  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.group_by_window(
    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
  ))

  # Iterate through the dataset to trigger data writing.
  for _ in dataset:
    pass
  ```
  """

  def __init__(self, filename, compression_type=None):
    """Initializes a `TFRecordWriter`.

    Args:
      filename: a string path indicating where to write the TFRecord data.
      compression_type: (Optional.) a string indicating what type of compression
        to use when writing the file. See `tf.io.TFRecordCompressionType` for
        what types of compression are available. Defaults to `None`.
    """
    self._filename = ops.convert_to_tensor(
        filename, dtypes.string, name="filename")
    # An unset compression_type is normalized to an empty-string tensor, which
    # the underlying kernel interprets as "no compression".
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)

  def write(self, dataset):
    """Writes a dataset to a TFRecord file.

    An operation that writes the content of the specified dataset to the file
    specified in the constructor.

    If the file exists, it will be overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file.

    Returns:
      In graph mode, this returns an operation which when executed performs the
      write. In eager mode, the write is performed by the method itself and
      there is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
    """
    if not isinstance(dataset, dataset_ops.DatasetV2):
      raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
    if not dataset_ops.get_structure(dataset).is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.string)):
      raise TypeError(
          "`dataset` must produce scalar `DT_STRING` tensors whereas it "
          "produces shape {0} and types {1}".format(
              dataset_ops.get_legacy_output_shapes(dataset),
              dataset_ops.get_legacy_output_types(dataset)))
    # Apply any user-set debug options (e.g. determinism settings) before
    # handing the dataset's variant tensor to the write kernel.
    # pylint: disable=protected-access
    dataset = dataset._apply_debug_options()
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)