/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
//
// Usage:
//  From command line:
//   bazel run third_party/tensorflow/lite/experimental/writer:writer
//     -- foo.tflite foo.out.tflite
//
// From C++
//   std::unique_ptr<Interpreter> interpreter;
//   // Build Interpreter however
//   // ... <omitted>
//   InterpreterWriter(interpreter.get()).Write("output.tflite");
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/version.h"

namespace tflite {

// Handles writing TensorFlow Lite running interpreter to a serialized TF lite
// file format.
42 class InterpreterWriter { 43 public: 44 typedef flatbuffers::Offset<Operator> (*CustomWriter)( 45 flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter, 46 int node_index, 47 flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options, 48 CustomOptionsFormat* custom_options_format); 49 50 // Construct an interpreter writer for the specified `interpreter`. Then, 51 // a uses .Write() or .GetBuffer(...) to extract the data. InterpreterWriter(Interpreter * interpreter)52 explicit InterpreterWriter(Interpreter* interpreter) 53 : interpreter_(interpreter) { 54 buffers_.push_back(std::make_pair(nullptr, 0)); 55 } 56 57 // Get a buffer and size of a serialized flatbuffer. 58 TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size); 59 // Write the serialized flatbuffer to the prescribed `filename`. 60 TfLiteStatus Write(const std::string& filename); 61 // Registers a custom writer for a custom op. The customization allows the 62 // caller to change the custom data. 63 TfLiteStatus RegisterCustomWriter(const std::string& custom_name, 64 CustomWriter custom_writer); 65 // Tensors that are unused and shouldn't be written. 
SetUnusedTensors(const std::set<int> & unused_tensors)66 void SetUnusedTensors(const std::set<int>& unused_tensors) { 67 unused_tensors_ = unused_tensors; 68 } 69 70 private: 71 template <class T> 72 using Offset = flatbuffers::Offset<T>; 73 template <class T_OUTPUT, class T_INPUT> 74 Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector( 75 flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v); 76 Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors( 77 flatbuffers::FlatBufferBuilder* fbb); 78 Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators( 79 flatbuffers::FlatBufferBuilder* fbb); 80 Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable( 81 flatbuffers::FlatBufferBuilder* fbb); 82 Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers( 83 flatbuffers::FlatBufferBuilder* fbb); 84 85 template <class T> 86 std::vector<int> RemapTensorIndicesToWritten(const T& input); 87 GetOpCodeForBuiltin(int builtin_op_index)88 int GetOpCodeForBuiltin(int builtin_op_index) { 89 // auto it = builtin_op_to_opcode_.find(builtin_op_index); 90 std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result = 91 builtin_op_to_opcode_.insert( 92 std::make_pair(builtin_op_index, opcodes_.size())); 93 if (result.second) { 94 opcodes_.push_back({builtin_op_index, ""}); 95 } 96 return result.first->second; 97 } 98 GetOpCodeForCustom(const std::string & custom_name)99 int GetOpCodeForCustom(const std::string& custom_name) { 100 std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result = 101 custom_op_to_opcode_.insert( 102 std::make_pair(custom_name, opcodes_.size())); 103 if (result.second) { 104 opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name}); 105 } 106 return result.first->second; 107 } 108 109 // The interpreter we are writing 110 Interpreter* interpreter_; 111 // Keep track of byte buffers 112 std::vector<std::pair<const uint8_t*, size_t>> buffers_; 113 // List of op codes and mappings from builtin or custom op to opcode 114 struct 
OpCode { 115 int builtin; 116 std::string custom; 117 }; 118 std::set<int> unused_tensors_; 119 // For every tensor index in the interpreter, the index in the written. 120 // This is different due to temporary and unused tensors not being written. 121 std::vector<int> tensor_to_written_tensor_; 122 // List of used opcodes 123 std::vector<OpCode> opcodes_; 124 std::unordered_map<int, int> builtin_op_to_opcode_; 125 std::unordered_map<std::string, int> custom_op_to_opcode_; 126 std::unordered_map<std::string, CustomWriter> custom_op_to_writer_; 127 }; 128 129 } // namespace tflite 130 131 #endif // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ 132