• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
16 //
17 // Usage:
18 //  From command line:
19 //   bazel run third_party/tensorflow/lite/experimental/writer:writer
20 //     -- foo.tflite foo.out.tflite
21 //
22 // From C++
23 //   std::unique_ptr<Interpreter> interpreter;
24 //   // Build Interpreter however
25 //   // ... <omitted>
26 //   InterpreterWriter(interpreter.get()).Write("output.tflite");
27 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
28 #define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/version.h"
37 
38 namespace tflite {
39 
40 // Handles writing TensorFlow Lite running interpreter to a serialized TF lite
41 // file format.
42 class InterpreterWriter {
43  public:
44   typedef flatbuffers::Offset<Operator> (*CustomWriter)(
45       flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
46       int node_index,
47       flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
48       CustomOptionsFormat* custom_options_format);
49 
50   // Construct an interpreter writer for the specified `interpreter`. Then,
51   // a uses .Write() or .GetBuffer(...)  to extract the data.
InterpreterWriter(Interpreter * interpreter)52   explicit InterpreterWriter(Interpreter* interpreter)
53       : interpreter_(interpreter) {
54     buffers_.push_back(std::make_pair(nullptr, 0));
55   }
56 
57   // Get a buffer and size of a serialized flatbuffer.
58   TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
59   // Write the serialized flatbuffer to the prescribed `filename`.
60   TfLiteStatus Write(const std::string& filename);
61   // Registers a custom writer for a custom op. The customization allows the
62   // caller to change the custom data.
63   TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
64                                     CustomWriter custom_writer);
65   // Tensors that are unused and shouldn't be written.
SetUnusedTensors(const std::set<int> & unused_tensors)66   void SetUnusedTensors(const std::set<int>& unused_tensors) {
67     unused_tensors_ = unused_tensors;
68   }
69 
70  private:
71   template <class T>
72   using Offset = flatbuffers::Offset<T>;
73   template <class T_OUTPUT, class T_INPUT>
74   Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
75       flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
76   Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
77       flatbuffers::FlatBufferBuilder* fbb);
78   Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
79       flatbuffers::FlatBufferBuilder* fbb);
80   Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
81       flatbuffers::FlatBufferBuilder* fbb);
82   Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
83       flatbuffers::FlatBufferBuilder* fbb);
84 
85   template <class T>
86   std::vector<int> RemapTensorIndicesToWritten(const T& input);
87 
GetOpCodeForBuiltin(int builtin_op_index)88   int GetOpCodeForBuiltin(int builtin_op_index) {
89     // auto it = builtin_op_to_opcode_.find(builtin_op_index);
90     std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
91         builtin_op_to_opcode_.insert(
92             std::make_pair(builtin_op_index, opcodes_.size()));
93     if (result.second) {
94       opcodes_.push_back({builtin_op_index, ""});
95     }
96     return result.first->second;
97   }
98 
GetOpCodeForCustom(const std::string & custom_name)99   int GetOpCodeForCustom(const std::string& custom_name) {
100     std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
101         custom_op_to_opcode_.insert(
102             std::make_pair(custom_name, opcodes_.size()));
103     if (result.second) {
104       opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
105     }
106     return result.first->second;
107   }
108 
109   // The interpreter we are writing
110   Interpreter* interpreter_;
111   // Keep track of byte buffers
112   std::vector<std::pair<const uint8_t*, size_t>> buffers_;
113   // List of op codes and mappings from builtin or custom op to opcode
114   struct OpCode {
115     int builtin;
116     std::string custom;
117   };
118   std::set<int> unused_tensors_;
119   // For every tensor index in the interpreter, the index in the written.
120   // This is different due to temporary and unused tensors not being written.
121   std::vector<int> tensor_to_written_tensor_;
122   // List of used opcodes
123   std::vector<OpCode> opcodes_;
124   std::unordered_map<int, int> builtin_op_to_opcode_;
125   std::unordered_map<std::string, int> custom_op_to_opcode_;
126   std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
127 };
128 
129 }  // namespace tflite
130 
131 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
132