/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Library to write a flatbuffer of a currently loaded TFLite model/subgraph.

#ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_
#include <iostream>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/tools/serialization/enum_mapping.h"
#include "tensorflow/lite/version.h"

namespace tflite {

struct OpCode {
  int builtin;
  std::string custom;
};

// Handles writing a full TFLite model (with one or more subgraphs) to the
// serialized TFLite file format.
// TODO(b/174708523): Support custom I/O or unused tensors later.
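//
// Example usage (a sketch, assuming `interpreter` is a loaded
// tflite::Interpreter):
//
//   ModelWriter writer(&interpreter);
//   writer.Write("/tmp/model.tflite");
//
// Or, to get the serialized bytes directly:
//
//   std::unique_ptr<uint8_t[]> buffer;
//   size_t size;
//   writer.GetBuffer(&buffer, &size);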
class ModelWriter {
 public:
  // Construct a writer for the specified `interpreter`. Then, use
  // .Write() or .GetBuffer(...) to extract the data.
  explicit ModelWriter(Interpreter* interpreter) : interpreter_(interpreter) {
    buffers_.push_back(std::make_pair(nullptr, 0));
  }

  // Get a buffer and size of a serialized flatbuffer.
  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
  // Write the serialized flatbuffer to the prescribed `filename`.
  TfLiteStatus Write(const std::string& filename);

 private:
  template <class T>
  using Offset = flatbuffers::Offset<T>;
  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
      flatbuffers::FlatBufferBuilder* fbb);

  // ModelWriter does not take ownership of this object.
  Interpreter* const interpreter_;
  // This data corresponds to the overall model (rather than individual
  // subgraphs), so these fields are shared across subgraphs.
  // Keeps track of byte buffers.
  std::vector<std::pair<const uint8_t*, size_t>> buffers_;
  // List of used opcodes.
  std::vector<OpCode> opcodes_;
  absl::flat_hash_map<int, int> builtin_op_to_opcode_;
};

// Handles writing a currently-running TensorFlow Lite subgraph to the
// serialized TFLite file format.
// TODO(b/174708523): Reconcile into ModelWriter?
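//
// Example usage (a sketch, assuming `subgraph` points to a loaded Subgraph):
//
//   SubgraphWriter writer(subgraph);
//   writer.Write("/tmp/subgraph.tflite");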
class SubgraphWriter {
 public:
  friend class ModelWriter;

  typedef flatbuffers::Offset<Operator> (*CustomWriter)(
      flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph, int node_index,
      flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
      CustomOptionsFormat* custom_options_format);
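
  // A matching custom writer might look like this (a sketch; `WriteMyCustomOp`
  // and its body are hypothetical):
  //
  //   flatbuffers::Offset<Operator> WriteMyCustomOp(
  //       flatbuffers::FlatBufferBuilder* fbb, Subgraph* subgraph,
  //       int node_index,
  //       flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
  //       CustomOptionsFormat* custom_options_format) {
  //     // Serialize this op's custom data and store it in `*output_options`.
  //     ...
  //   }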

  // Construct a subgraph writer for the specified `subgraph`. Then, use
  // .Write() or .GetBuffer(...) to extract the data.
  explicit SubgraphWriter(Subgraph* subgraph)
      : subgraph_(subgraph),
        inputs_(subgraph->inputs()),
        outputs_(subgraph->outputs()),
        execution_plan_(subgraph->execution_plan()) {
    buffers_ = &buffers_data_;
    opcodes_ = &opcodes_data_;
    builtin_op_to_opcode_ = &builtin_op_to_opcode_data_;
    buffers_->push_back(std::make_pair(nullptr, 0));
  }

  // Get a buffer and size of a serialized flatbuffer.
  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
  // Write the serialized flatbuffer to the prescribed `filename`.
  TfLiteStatus Write(const std::string& filename);
  // Registers a custom writer for a custom op. The customization allows the
  // caller to change the custom data.
  TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
                                    CustomWriter custom_writer);
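
  // Example (a sketch; `WriteMyCustomOp` is the hypothetical custom writer
  // from the CustomWriter sketch above):
  //
  //   writer.RegisterCustomWriter("MyCustomOp", WriteMyCustomOp);
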
  // Registers the set of tensors that are unused and shouldn't be written.
  void SetUnusedTensors(const std::set<int>& unused_tensors) {
    unused_tensors_ = unused_tensors;
  }
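
  // Example (a sketch; tensor indices are illustrative):
  //
  //   writer.SetUnusedTensors({2, 5});
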
  // Sets custom inputs, outputs, and execution_plan so that a portion of the
  // subgraph is written to the buffer instead of the whole subgraph.
  TfLiteStatus SetCustomInputOutput(const std::vector<int>& inputs,
                                    const std::vector<int>& outputs,
                                    const std::vector<int>& execution_plan);
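
  // Example (a sketch; tensor and node indices are illustrative):
  //
  //   writer.SetCustomInputOutput(/*inputs=*/{0}, /*outputs=*/{3},
  //                               /*execution_plan=*/{0, 1});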

 private:
  // Used by ModelWriter.
  explicit SubgraphWriter(
      Subgraph* subgraph,
      std::vector<std::pair<const uint8_t*, size_t>>* external_buffers,
      std::vector<OpCode>* external_opcodes,
      absl::flat_hash_map<int, int>* external_builtin_op_to_opcode)
      : subgraph_(subgraph),
        inputs_(subgraph->inputs()),
        outputs_(subgraph->outputs()),
        execution_plan_(subgraph->execution_plan()) {
    buffers_ = external_buffers;
    opcodes_ = external_opcodes;
    builtin_op_to_opcode_ = external_builtin_op_to_opcode;
    buffers_->push_back(std::make_pair(nullptr, 0));
  }

  // Used by ModelWriter to populate data specific to this subgraph.
  // Global data (such as opcodes and buffers) is collected into buffers_,
  // opcodes_, etc. and written into the flatbuffer by ModelWriter.
  flatbuffers::Offset<SubGraph> PopulateAndGetOffset(
      flatbuffers::FlatBufferBuilder* builder);

  template <class T>
  using Offset = flatbuffers::Offset<T>;
  template <class T_OUTPUT, class T_INPUT>
  Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
      flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
  Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
      flatbuffers::FlatBufferBuilder* fbb);
  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
      flatbuffers::FlatBufferBuilder* fbb);

  template <class T>
  std::vector<int> RemapTensorIndicesToWritten(const T& input);

  // Checks whether the given `inputs`, `outputs`, and `execution_plan`
  // represent a valid model within the Subgraph.
  TfLiteStatus CheckInputOutput(const std::vector<int>& inputs,
                                const std::vector<int>& outputs,
                                const std::vector<int>& execution_plan);

  // Returns the opcode index for the given builtin op, adding a new entry to
  // the opcode table if the op has not been seen before.
  int GetOpCodeForBuiltin(int builtin_op_index) {
    std::pair<decltype(builtin_op_to_opcode_data_)::iterator, bool> result =
        builtin_op_to_opcode_->insert(
            std::make_pair(builtin_op_index, opcodes_->size()));
    if (result.second) {
      opcodes_->push_back({builtin_op_index, ""});
    }
    return result.first->second;
  }

  // Returns the opcode index for the given custom op name, adding a new
  // CUSTOM entry to the opcode table if the name has not been seen before.
  int GetOpCodeForCustom(const std::string& custom_name) {
    std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
        custom_op_to_opcode_.insert(
            std::make_pair(custom_name, opcodes_->size()));
    if (result.second) {
      opcodes_->push_back({BuiltinOperator_CUSTOM, custom_name});
    }
    return result.first->second;
  }

  // The subgraph we are writing.
  Subgraph* subgraph_;
  // Input tensor indices to be written.
  std::vector<int> inputs_;
  // Output tensor indices to be written.
  std::vector<int> outputs_;
  // Order of nodes to be written.
  std::vector<int> execution_plan_;
  // Tensor indices that are unused and should not be written.
  std::set<int> unused_tensors_;
  // For every tensor index in the subgraph, the corresponding index in the
  // written file. These differ because temporary and unused tensors are not
  // written.
  std::vector<int> tensor_to_written_tensor_;
  std::unordered_map<std::string, int> custom_op_to_opcode_;
  std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;

  // We use pointers for these, since they may be provided by ModelWriter.
  // Keeps track of byte buffers.
  std::vector<std::pair<const uint8_t*, size_t>>* buffers_;
  // List of used opcodes.
  std::vector<OpCode>* opcodes_;
  absl::flat_hash_map<int, int>* builtin_op_to_opcode_;

  // These are used if SubgraphWriter is being used directly.
  std::vector<std::pair<const uint8_t*, size_t>> buffers_data_;
  // List of used opcodes.
  std::vector<OpCode> opcodes_data_;
  absl::flat_hash_map<int, int> builtin_op_to_opcode_data_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_