1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Library to write a flatbuffer of a currently loaded TFLite model/subgraph. 16 17 #ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ 18 #define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ 19 #include <iostream> 20 #include <unordered_map> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "tensorflow/lite/builtin_op_data.h" 24 #include "tensorflow/lite/c/common.h" 25 #include "tensorflow/lite/context_util.h" 26 #include "tensorflow/lite/core/subgraph.h" 27 #include "tensorflow/lite/interpreter.h" 28 #include "tensorflow/lite/schema/reflection/schema_generated.h" 29 #include "tensorflow/lite/tools/serialization/enum_mapping.h" 30 #include "tensorflow/lite/version.h" 31 32 namespace tflite { 33 34 struct OpCode { 35 int builtin; 36 std::string custom; 37 }; 38 39 // Handles writing a full TFLite model (with 1 or more subgraphs) to a 40 // serialized TF lite file format. 41 // TODO(b/174708523): Support custom I/O or unused tensors later. 42 class ModelWriter { 43 public: 44 // Construct a writer for the specified `interpreter`. Then, use 45 // .Write() or .GetBuffer(...) to extract the data. ModelWriter(Interpreter * interpreter)46 explicit ModelWriter(Interpreter* interpreter) : interpreter_(interpreter) { 47 buffers_.push_back(std::make_pair(nullptr, 0)); 48 } 49 50 // Get a buffer and size of a serialized flatbuffer. 51 TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size); 52 // Write the serialized flatbuffer to the prescribed `filename`. 53 TfLiteStatus Write(const std::string& filename); 54 55 private: 56 template <class T> 57 using Offset = flatbuffers::Offset<T>; 58 Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable( 59 flatbuffers::FlatBufferBuilder* fbb); 60 Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers( 61 flatbuffers::FlatBufferBuilder* fbb); 62 63 // ModelWriter does not take ownership of this object. 64 Interpreter* const interpreter_; 65 66 // This data corresponds to the overall model (rather than individual 67 // subgraphs), so we define common fields. Keep track of byte buffers 68 std::vector<std::pair<const uint8_t*, size_t>> buffers_; 69 // List of used opcodes 70 std::vector<OpCode> opcodes_; 71 absl::flat_hash_map<int, int> builtin_op_to_opcode_; 72 }; 73 74 // Handles writing TensorFlow Lite running subgraph to a serialized TF lite 75 // file format. 76 // TODO(b/174708523): Reconcile into ModelWriter? 77 class SubgraphWriter { 78 public: 79 friend class ModelWriter; 80 81 typedef flatbuffers::Offset<Operator> (*CustomWriter)( 82 flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index, 83 flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options, 84 CustomOptionsFormat* custom_options_format); 85 86 // Construct a subgraph writer for the specified `subgraph`. Then, use 87 // .Write() or .GetBuffer(...) to extract the data. SubgraphWriter(Subgraph * subgraph)88 explicit SubgraphWriter(Subgraph* subgraph) 89 : subgraph_(subgraph), 90 inputs_(subgraph->inputs()), 91 outputs_(subgraph->outputs()), 92 execution_plan_(subgraph->execution_plan()) { 93 buffers_ = &buffers_data_; 94 opcodes_ = &opcodes_data_; 95 builtin_op_to_opcode_ = &builtin_op_to_opcode_data_; 96 buffers_->push_back(std::make_pair(nullptr, 0)); 97 } 98 99 // Get a buffer and size of a serialized flatbuffer. 100 TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size); 101 // Write the serialized flatbuffer to the prescribed `filename`. 102 TfLiteStatus Write(const std::string& filename); 103 // Registers a custom writer for a custom op. The customization allows the 104 // caller to change the custom data. 105 TfLiteStatus RegisterCustomWriter(const std::string& custom_name, 106 CustomWriter custom_writer); 107 // Tensors that are unused and shouldn't be written. SetUnusedTensors(const std::set<int> & unused_tensors)108 void SetUnusedTensors(const std::set<int>& unused_tensors) { 109 unused_tensors_ = unused_tensors; 110 } 111 // Sets custom inputs, outputs, and execution_plan so that a portion of the 112 // subgraph is written to the buffer instead of the whole subgraph. 113 TfLiteStatus SetCustomInputOutput(const std::vector<int>& inputs, 114 const std::vector<int>& outputs, 115 const std::vector<int>& execution_plan); 116 117 private: 118 // Used by ModelWriter. SubgraphWriter(Subgraph * subgraph,std::vector<std::pair<const uint8_t *,size_t>> * external_buffers,std::vector<OpCode> * external_opcodes,absl::flat_hash_map<int,int> * external_builtin_op_to_opcode)119 explicit SubgraphWriter( 120 Subgraph* subgraph, 121 std::vector<std::pair<const uint8_t*, size_t>>* external_buffers, 122 std::vector<OpCode>* external_opcodes, 123 absl::flat_hash_map<int, int>* external_builtin_op_to_opcode) 124 : subgraph_(subgraph), 125 inputs_(subgraph->inputs()), 126 outputs_(subgraph->outputs()), 127 execution_plan_(subgraph->execution_plan()) { 128 buffers_ = external_buffers; 129 opcodes_ = external_opcodes; 130 builtin_op_to_opcode_ = external_builtin_op_to_opcode; 131 buffers_->push_back(std::make_pair(nullptr, 0)); 132 } 133 134 // Used by ModelWriter to populate data specific to this subgraph. 135 // Global stuff (like opcodes & buffers) is populated into buffers_, opcodes_, 136 // etc. & populated in the Flatbuffer by ModelWriter. 137 flatbuffers::Offset<SubGraph> PopulateAndGetOffset( 138 flatbuffers::FlatBufferBuilder* builder); 139 140 template <class T> 141 using Offset = flatbuffers::Offset<T>; 142 template <class T_OUTPUT, class T_INPUT> 143 Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector( 144 flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v); 145 Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors( 146 flatbuffers::FlatBufferBuilder* fbb); 147 Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators( 148 flatbuffers::FlatBufferBuilder* fbb); 149 Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable( 150 flatbuffers::FlatBufferBuilder* fbb); 151 Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers( 152 flatbuffers::FlatBufferBuilder* fbb); 153 154 template <class T> 155 std::vector<int> RemapTensorIndicesToWritten(const T& input); 156 157 // Checks if given `input`, `output`, and `execution_plan` represents a valid 158 // model within the Subgraph. 159 TfLiteStatus CheckInputOutput(const std::vector<int>& inputs, 160 const std::vector<int>& outputs, 161 const std::vector<int>& execution_plan); 162 GetOpCodeForBuiltin(int builtin_op_index)163 int GetOpCodeForBuiltin(int builtin_op_index) { 164 // auto it = builtin_op_to_opcode_.find(builtin_op_index); 165 std::pair<decltype(builtin_op_to_opcode_data_)::iterator, bool> result = 166 builtin_op_to_opcode_->insert( 167 std::make_pair(builtin_op_index, opcodes_->size())); 168 if (result.second) { 169 opcodes_->push_back({builtin_op_index, ""}); 170 } 171 return result.first->second; 172 } 173 GetOpCodeForCustom(const std::string & custom_name)174 int GetOpCodeForCustom(const std::string& custom_name) { 175 std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result = 176 custom_op_to_opcode_.insert( 177 std::make_pair(custom_name, opcodes_->size())); 178 if (result.second) { 179 opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name}); 180 } 181 return result.first->second; 182 } 183 184 // The subgraph we are writing 185 Subgraph* subgraph_; 186 // Input tensor indices to be written. 187 std::vector<int> inputs_; 188 // Output tensor indices to be written. 189 std::vector<int> outputs_; 190 // Order of nodes to be written. 191 std::vector<int> execution_plan_; 192 // List of op codes and mappings from builtin or custom op to opcode 193 std::set<int> unused_tensors_; 194 // For every tensor index in the subgraph, the index in the written. 195 // This is different due to temporary and unused tensors not being written. 196 std::vector<int> tensor_to_written_tensor_; 197 std::unordered_map<std::string, int> custom_op_to_opcode_; 198 std::unordered_map<std::string, CustomWriter> custom_op_to_writer_; 199 200 // We use pointers for these, since they may be provided by ModelWriter. 201 // Keep track of byte buffers 202 std::vector<std::pair<const uint8_t*, size_t>>* buffers_; 203 // List of used opcodes 204 std::vector<OpCode>* opcodes_; 205 absl::flat_hash_map<int, int>* builtin_op_to_opcode_; 206 207 // These are used if SubgraphWriter is being used directly. 208 std::vector<std::pair<const uint8_t*, size_t>> buffers_data_; 209 // List of used opcodes 210 std::vector<OpCode> opcodes_data_; 211 absl::flat_hash_map<int, int> builtin_op_to_opcode_data_; 212 }; 213 214 } // namespace tflite 215 216 #endif // TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ 217