1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/writer/writer_lib.h"
16 #include <cstdlib>
17 #include <cstring>
18 #include <unordered_map>
19 #include "tensorflow/lite/builtin_op_data.h"
20 #include "tensorflow/lite/context_util.h"
21 #include "tensorflow/lite/experimental/writer/enum_mapping.h"
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow/lite/schema/reflection/schema_generated.h"
24 #include "tensorflow/lite/version.h"
25
26 namespace tflite {
27 template <class T>
28 using Offset = flatbuffers::Offset<T>;
29 template <class T>
30 using Vector = flatbuffers::Vector<T>;
31 using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
32
CreateBuiltinUnion(FlatBufferBuilder * fbb,enum BuiltinOperator op,void * builtin_op_data)33 std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
34 FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
35 switch (op) {
36 #include "tensorflow/lite/experimental/writer/option_writer_generated.h"
37 }
38 return std::make_pair(BuiltinOptions_NONE, Offset<void>());
39 }
40
41 template <class T_OUTPUT, class T_INPUT>
ExportVector(FlatBufferBuilder * fbb,const T_INPUT & v)42 Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
43 const T_INPUT& v) {
44 std::vector<T_OUTPUT> inputs(v.begin(), v.end());
45 return fbb->template CreateVector<T_OUTPUT>(inputs);
46 }
47
ExportOperators(FlatBufferBuilder * fbb)48 Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
49 FlatBufferBuilder* fbb) {
50 std::vector<Offset<Operator>> operators;
51
52 std::vector<int> operator_to_opcode;
53 // TODO(aselle): Augment this once we put execution plan in schema.
54 operator_to_opcode.resize(interpreter_->nodes_size(), -1);
55 for (int op_index : interpreter_->execution_plan()) {
56 const auto* node_and_registration =
57 interpreter_->node_and_registration(op_index);
58 const TfLiteRegistration* registration = &node_and_registration->second;
59 if (!registration->custom_name) {
60 operator_to_opcode[op_index] =
61 GetOpCodeForBuiltin(registration->builtin_code);
62 } else {
63 operator_to_opcode[op_index] =
64 GetOpCodeForCustom(registration->custom_name);
65 }
66 }
67 // second pass serialize operators
68 for (int op_index : interpreter_->execution_plan()) {
69 const auto* node_and_registration =
70 interpreter_->node_and_registration(op_index);
71 const TfLiteNode& node = node_and_registration->first;
72 const TfLiteRegistration& registration = node_and_registration->second;
73 Offset<void> builtin_options;
74 BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
75 // Custom data
76 // TODO(aselle): Custom options format is not known by default. Just assume
77 // for now.
78 auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
79 Offset<Vector<uint8_t>> custom_options = 0;
80
81 if (!registration.custom_name) {
82 // builtin
83 auto builtin_options_and_type = CreateBuiltinUnion(
84 fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
85 node.builtin_data);
86 builtin_options = builtin_options_and_type.second;
87 builtin_options_type = builtin_options_and_type.first;
88 } else {
89 auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
90 if (custom_writer != custom_op_to_writer_.end() &&
91 custom_writer->second) {
92 // delegate to custom writer if it exists
93 custom_writer->second(fbb, interpreter_, op_index, &custom_options,
94 &custom_options_format);
95 } else {
96 // use the custom data as fact
97 custom_options = fbb->CreateVector(
98 reinterpret_cast<const uint8_t*>(node.custom_initial_data),
99 node.custom_initial_data_size);
100 }
101 }
102
103 int opcode_index = operator_to_opcode[op_index];
104 std::vector<int> written_inputs =
105 RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
106 std::vector<int> written_outputs =
107 RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
108 auto inputs = ExportVector<int32_t>(fbb, written_inputs);
109 auto outputs = ExportVector<int32_t>(fbb, written_outputs);
110 operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
111 builtin_options_type, builtin_options,
112 custom_options, custom_options_format));
113 }
114
115 return fbb->template CreateVector<Offset<Operator>>(operators);
116 }
117
ExportTensors(FlatBufferBuilder * fbb)118 Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
119 FlatBufferBuilder* fbb) {
120 // Initialized to -1.
121 // A value of -1 means this tensor will not be exported.
122 tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
123
124 std::vector<Offset<Tensor>> tensors;
125
126 // Make a map from tensor index to whether the tensor is a temporary.
127 std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
128 for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
129 const auto* node_and_registration =
130 interpreter_->node_and_registration(op_index);
131 for (auto tensor_index :
132 TfLiteIntArrayView(node_and_registration->first.temporaries))
133 tensor_is_temporary[tensor_index] = true;
134 }
135
136 // Now we need to remap all used tensor indices
137 int curr_output_index = 0;
138 for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
139 tensor_index++) {
140 // Temporary tensors and unused tensors will not be written.
141 if (!tensor_is_temporary[tensor_index] &&
142 unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
143 tensor_to_written_tensor_[tensor_index] = curr_output_index++;
144 }
145 }
146
147 for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
148 ++tensor_index) {
149 // Tensor not exported.
150 if (tensor_to_written_tensor_[tensor_index] == -1) continue;
151
152 if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
153 // We only need to convert non temporaries
154 if (tensor->allocation_type != kTfLiteArenaRw &&
155 tensor->allocation_type != kTfLiteMmapRo &&
156 tensor->allocation_type != kTfLiteArenaRwPersistent)
157 continue;
158 // Allocate a buffer index
159 int buffer_index = 0; // This is null
160 if (tensor->allocation_type == kTfLiteMmapRo) {
161 buffer_index = buffers_.size();
162 buffers_.push_back(std::make_pair(
163 reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
164 }
165 // Primitive type.
166 TensorType type = TfLiteTypeToSchemaType(tensor->type);
167 // Handle quantization
168 const Offset<Vector<float>> null_array;
169 Offset<Vector<float>> scale_array;
170 Offset<Vector<int64_t>> zero_point_array;
171 if (tensor->params.scale != 0.f) {
172 // We have quantization, make a single arugment array (multi channel
173 // quant needs updating here).
174 scale_array = fbb->CreateVector<float>({tensor->params.scale});
175 zero_point_array =
176 fbb->CreateVector<int64_t>({tensor->params.zero_point});
177 }
178 Offset<QuantizationParameters> quantization_params =
179 CreateQuantizationParameters(*fbb, null_array, null_array,
180 scale_array, zero_point_array);
181 // Shape
182 TfLiteIntArrayView shape_view(tensor->dims);
183 std::vector<int> shape =
184 std::vector<int>(shape_view.begin(), shape_view.end());
185
186 tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
187 type, buffer_index,
188 fbb->CreateString(tensor->name),
189 quantization_params, tensor->is_variable));
190 }
191 }
192 return fbb->template CreateVector<Offset<Tensor>>(tensors);
193 }
194
ExportBuffers(FlatBufferBuilder * fbb)195 Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
196 FlatBufferBuilder* fbb) {
197 std::vector<Offset<Buffer>> buffer_vector;
198 for (auto buffer : buffers_) {
199 auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
200 buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
201 }
202 return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
203 }
204
CreateOpCodeTable(FlatBufferBuilder * fbb)205 Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
206 FlatBufferBuilder* fbb) {
207 std::vector<Offset<OperatorCode>> codes;
208 for (auto it : opcodes_) {
209 const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
210 codes.push_back(CreateOperatorCodeDirect(
211 *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
212 }
213 return fbb->template CreateVector<Offset<OperatorCode>>(codes);
214 }
215
216 template <class T>
RemapTensorIndicesToWritten(const T & input)217 std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
218 const T& input) {
219 std::vector<int> output;
220 output.reserve(input.size());
221 for (int x : input) {
222 // Special value representing an optional tensor which is not present.
223 if (x == -1) {
224 output.push_back(x);
225 continue;
226 }
227 if (tensor_to_written_tensor_[x] != -1) {
228 output.push_back(tensor_to_written_tensor_[x]);
229 }
230 }
231 return output;
232 }
233
GetBuffer(std::unique_ptr<uint8_t[]> * out,size_t * size)234 TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
235 size_t* size) {
236 if (!out || !size) return kTfLiteError;
237 FlatBufferBuilder builder(/*initial_size=*/10240);
238
239 std::vector<Offset<SubGraph>> subgraphs_as_vector;
240 { // subgraph specific stuff
241 auto tensors = ExportTensors(&builder);
242 std::vector<int> written_inputs =
243 RemapTensorIndicesToWritten(interpreter_->inputs());
244 std::vector<int> written_outputs =
245 RemapTensorIndicesToWritten(interpreter_->outputs());
246 auto inputs = ExportVector<int32_t>(&builder, written_inputs);
247 auto outputs = ExportVector<int32_t>(&builder, written_outputs);
248
249 auto ops = ExportOperators(&builder);
250 subgraphs_as_vector.push_back(
251 CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
252 }
253 Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);
254
255 auto description = builder.CreateString("Exported from Interpreter.");
256
257 auto op_codes = CreateOpCodeTable(&builder);
258 auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
259 builder.CreateVector(subgraphs_as_vector),
260 description, buffers);
261 ::tflite::FinishModelBuffer(builder, model);
262 const uint8_t* buffer = builder.GetBufferPointer();
263 *size = builder.GetSize();
264 (*out).reset(new uint8_t[*size]);
265 memcpy(out->get(), buffer, *size);
266 return kTfLiteOk;
267 }
268
Write(const std::string & filename)269 TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
270 std::unique_ptr<uint8_t[]> buffer;
271 size_t size;
272 TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
273
274 FILE* fp = fopen(filename.c_str(), "wb");
275 if (!fp) return kTfLiteError;
276
277 if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
278 if (fclose(fp)) return kTfLiteError;
279
280 return kTfLiteOk;
281 }
282
RegisterCustomWriter(const std::string & custom_name,CustomWriter custom_writer)283 TfLiteStatus InterpreterWriter::RegisterCustomWriter(
284 const std::string& custom_name, CustomWriter custom_writer) {
285 if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
286 return kTfLiteError;
287 }
288 custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
289 return kTfLiteOk;
290 }
291
292 } // namespace tflite
293