1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_ 17 #define TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/framework/dataset.h" 22 #include "tensorflow/core/lib/core/status.h" 23 24 namespace tensorflow { 25 namespace data { 26 27 // Reads dataset elements from the checkpoint reader using the given key prefix. 28 Status ReadElementsFromCheckpoint(IteratorContext* ctx, 29 IteratorStateReader* reader, 30 StringPiece key_prefix, 31 std::vector<std::vector<Tensor>>* elements); 32 33 // Writes dataset elements to the checkpoint writer using the given key prefix. 34 // The elements can be read back by passing the same key prefix to 35 // ReadElementsFromCheckpoint. Only one list of elements can be written under 36 // the same key_prefix. 37 Status WriteElementsToCheckpoint( 38 IteratorStateWriter* writer, StringPiece key_prefix, 39 const std::vector<std::vector<Tensor>>& elements); 40 41 // Helper class for reading data from a vector of VariantTensorData objects. 42 class VariantTensorDataReader : public IteratorStateReader { 43 public: 44 explicit VariantTensorDataReader( 45 const std::vector<const VariantTensorData*>& data); 46 47 bool Contains(StringPiece key) const override; 48 bool Contains(StringPiece name, StringPiece key) const override; 49 50 Status ReadScalar(StringPiece key, int64* val) const override; 51 Status ReadScalar(StringPiece name, StringPiece key, 52 int64* val) const override; 53 Status ReadScalar(StringPiece key, tstring* val) const override; 54 Status ReadScalar(StringPiece name, StringPiece key, 55 tstring* val) const override; 56 Status ReadTensor(StringPiece key, Tensor* val) const override; 57 Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key, 58 Tensor* val) const override; 59 Status ReadTensor(StringPiece name, StringPiece key, 60 Tensor* val) const override; 61 Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name, 62 StringPiece key, Tensor* val) const override; 63 64 private: 65 template <typename T> 66 Status ReadScalarInternal(StringPiece name, StringPiece key, T* val) const; 67 Status ReadTensorInternal(FunctionLibraryRuntime* flr, StringPiece name, 68 StringPiece key, Tensor* val) const; 69 Status ReadDatasetInternal(FunctionLibraryRuntime* flr, StringPiece name, 70 StringPiece key, Tensor* val) const; 71 72 std::map<string, std::map<string, size_t>> map_; 73 std::map<string, const VariantTensorData*> data_; // Not owned. 74 }; 75 76 // Helper class used to build a list of VariantTensorData objects, one for each 77 // iterator which is determined from the key supplied from the Write* calls. 78 // Sample usage: 79 // VariantTensorDataWriter writer; 80 // writer.WriteScalar(full_name("buffer_size"), buffer_.size()); 81 // writer.WriteScalar(full_name("num_threads"), threadpool_.size()); 82 // .... 83 // std::vector<std::unique_ptr<VariantTensorData>> variants; 84 // writer.ReleaseData(&variants); 85 // Now the VariantTensorData objects can be used to serialize. 86 class VariantTensorDataWriter : public IteratorStateWriter { 87 public: 88 Status WriteScalar(StringPiece key, const int64_t val) override; 89 Status WriteScalar(StringPiece name, StringPiece key, 90 const int64_t val) override; 91 92 Status WriteScalar(StringPiece key, const tstring& val) override; 93 Status WriteScalar(StringPiece name, StringPiece key, 94 const tstring& val) override; 95 96 Status WriteTensor(StringPiece key, const Tensor& val) override; 97 Status WriteTensor(StringPiece name, StringPiece key, 98 const Tensor& val) override; 99 100 // Releases the built VariantTensorData's to `variants`. Clears out all 101 // class state. 102 void ReleaseData(std::vector<std::unique_ptr<VariantTensorData>>* variants); 103 104 // Obtains a read-only version of the VariantTensorData's built. 105 void GetData(std::vector<const VariantTensorData*>* variants); 106 107 private: 108 void MaybeFlush(); 109 void Reset(); 110 111 template <typename T> 112 Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val); 113 Status WriteTensorInternal(StringPiece name, StringPiece key, 114 const Tensor& val); 115 Status WriteDatasetInternal(StringPiece name, StringPiece key, 116 const DatasetBase* dataset); 117 118 bool is_flushed_ = false; 119 std::map<string, std::unique_ptr<VariantTensorData>> data_; 120 std::map<string, std::vector<string>> keys_; 121 }; 122 123 // Returns a GraphDef representation of the given dataset. 124 Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset, 125 SerializationContext&& serialization_ctx, 126 GraphDef* graph_def); 127 128 // Returns a GraphDef representation of the given dataset using the minimal 129 // serialization parameters (i.e. ignoring external state, not serializing 130 // data tensors, not failing if there are datasets which do not have AsGraphDef 131 // implemented). Sets the `dataset_node` parameter to the dataset's 132 // node name in the resulting GraphDef. 133 Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input, 134 std::vector<std::pair<string, Tensor>>* input_list, 135 GraphDef* result, string* dataset_node); 136 137 } // namespace data 138 } // namespace tensorflow 139 140 #endif // TENSORFLOW_CORE_KERNELS_SERIALIZATION_UTILS_H_ 141