1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/serialization/writer_lib.h"
16
17 #include <cstdlib>
18 #include <cstring>
19 #include <unordered_map>
20 #include <unordered_set>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/lite/builtin_op_data.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/context_util.h"
26 #include "tensorflow/lite/core/subgraph.h"
27 #include "tensorflow/lite/schema/reflection/schema_generated.h"
28 #include "tensorflow/lite/schema/schema_conversion_utils.h"
29 #include "tensorflow/lite/tools/serialization/enum_mapping.h"
30 #include "tensorflow/lite/version.h"
31
32 namespace tflite {
33 namespace {
34
35 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder * fbb,std::vector<OpCode> * opcodes)36 CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
37 std::vector<OpCode>* opcodes) {
38 std::vector<flatbuffers::Offset<OperatorCode>> codes;
39 for (const auto& it : *opcodes) {
40 const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
41 codes.push_back(CreateOperatorCodeDirect(
42 *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
43 }
44 return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
45 }
46
47 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffersImpl(flatbuffers::FlatBufferBuilder * fbb,std::vector<std::pair<const uint8_t *,size_t>> * buffers)48 ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
49 std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
50 std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
51 for (auto buffer : *buffers) {
52 auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
53 buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
54 }
55 return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
56 }
57
WriteImpl(const std::string & filename,void * data,size_t size)58 TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
59 FILE* fp = fopen(filename.c_str(), "wb");
60 if (!fp) return kTfLiteError;
61
62 const int result_size = fwrite(data, 1, size, fp);
63 fclose(fp);
64 if (result_size != size) return kTfLiteError;
65
66 return kTfLiteOk;
67 }
68
// Builds the builtin-options union value for builtin operator `op` from the
// node's builtin data. The switch body is machine-generated (one case per
// builtin operator) and returns the matching {options-type, offset} pair.
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
    flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
    void* builtin_op_data, const TfLiteNode& node) {
  switch (op) {
// Generated cases; each one serializes its op's options and returns.
#include "tensorflow/lite/tools/serialization/option_writer_generated.h"
  }
  // Fallback for ops with no generated case: serialize with no options.
  return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
}
77
78 } // namespace
79
80 template <class T_OUTPUT, class T_INPUT>
ExportVector(flatbuffers::FlatBufferBuilder * fbb,const T_INPUT & v)81 flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
82 flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
83 std::vector<T_OUTPUT> inputs(v.begin(), v.end());
84 return fbb->template CreateVector<T_OUTPUT>(inputs);
85 }
86
// Serializes every node in execution_plan_ into a FlatBuffer Operator table.
// First pass resolves each node to an opcode index (registering new opcodes
// as a side effect); second pass emits the operators with their options and
// remapped tensor indices.
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Operator>>>
SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
  std::vector<flatbuffers::Offset<Operator>> operators;

  // operator_to_opcode[i] is the opcode-table index for node i, or -1 for
  // nodes that are not in the execution plan.
  std::vector<int> operator_to_opcode;
  // TODO(aselle): Augment this once we put execution plan in schema.
  operator_to_opcode.resize(subgraph_->nodes_size(), -1);
  for (int op_index : execution_plan_) {
    const auto* node_and_registration =
        subgraph_->node_and_registration(op_index);
    const TfLiteRegistration* registration = &node_and_registration->second;
    // A null custom_name distinguishes builtin from custom operators.
    if (!registration->custom_name) {
      operator_to_opcode[op_index] =
          GetOpCodeForBuiltin(registration->builtin_code);
    } else {
      operator_to_opcode[op_index] =
          GetOpCodeForCustom(registration->custom_name);
    }
  }
  // second pass serialize operators
  for (int op_index : execution_plan_) {
    const auto* node_and_registration =
        subgraph_->node_and_registration(op_index);
    const TfLiteNode& node = node_and_registration->first;
    const TfLiteRegistration& registration = node_and_registration->second;
    flatbuffers::Offset<void> builtin_options;
    BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
    // Custom data
    // TODO(aselle): Custom options format is not known by default. Just assume
    // for now.
    auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> custom_options = 0;

    if (!registration.custom_name) {
      // builtin: serialize the builtin-options union from the node's data.
      auto builtin_options_and_type = CreateBuiltinUnion(
          fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
          node.builtin_data, node);
      builtin_options = builtin_options_and_type.second;
      builtin_options_type = builtin_options_and_type.first;
    } else {
      auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
      if (custom_writer != custom_op_to_writer_.end() &&
          custom_writer->second) {
        // delegate to custom writer if it exists
        custom_writer->second(fbb, subgraph_, op_index, &custom_options,
                              &custom_options_format);
      } else {
        // use the custom data as fact
        // NOTE(review): assumes custom_initial_data is non-null whenever
        // custom_initial_data_size > 0 -- confirm against the interpreter.
        custom_options = fbb->CreateVector(
            reinterpret_cast<const uint8_t*>(node.custom_initial_data),
            node.custom_initial_data_size);
      }
    }

    int opcode_index = operator_to_opcode[op_index];
    // Tensor indices are remapped into the compacted, written-tensor index
    // space established by ExportTensors.
    std::vector<int> written_inputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
    std::vector<int> written_outputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
    auto inputs = ExportVector<int32_t>(fbb, written_inputs);
    auto outputs = ExportVector<int32_t>(fbb, written_outputs);
    operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
                                       builtin_options_type, builtin_options,
                                       custom_options, custom_options_format));
  }

  return fbb->template CreateVector<flatbuffers::Offset<Operator>>(operators);
}
156
// Serializes the subgraph's tensors into a FlatBuffer Tensor table, skipping
// temporaries and explicitly-unused tensors. Populates
// tensor_to_written_tensor_ (used later for index remapping) and appends
// constant-tensor data to buffers_ as a side effect.
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Tensor>>>
SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
  // Initialized to -1.
  // A value of -1 means this tensor will not be exported.
  tensor_to_written_tensor_.resize(subgraph_->tensors_size(), -1);

  std::vector<flatbuffers::Offset<Tensor>> tensors;

  // Make a map from tensor index to whether the tensor is a temporary.
  std::vector<bool> tensor_is_temporary(subgraph_->tensors_size(), false);
  for (int op_index = 0; op_index < subgraph_->nodes_size(); ++op_index) {
    const auto* node_and_registration =
        subgraph_->node_and_registration(op_index);
    for (auto tensor_index :
         TfLiteIntArrayView(node_and_registration->first.temporaries))
      tensor_is_temporary[tensor_index] = true;
  }

  // Now we need to remap all used tensor indices
  int curr_output_index = 0;
  for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
       tensor_index++) {
    // Temporary tensors and unused tensors will not be written.
    if (!tensor_is_temporary[tensor_index] &&
        unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
      tensor_to_written_tensor_[tensor_index] = curr_output_index++;
    }
  }

  for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
       ++tensor_index) {
    // Tensor not exported.
    if (tensor_to_written_tensor_[tensor_index] == -1) continue;

    if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
      // Allocate a buffer index
      int buffer_index = 0;  // This is null
      // Only memory-mapped read-only tensors (constants) carry data; other
      // allocation types reference the empty buffer 0.
      if (tensor->allocation_type == kTfLiteMmapRo) {
        buffer_index = buffers_->size();
        buffers_->push_back(std::make_pair(
            reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
      }
      // Primitive type.
      TensorType type = TfLiteTypeToSchemaType(tensor->type);
      // Handle quantization
      flatbuffers::Offset<QuantizationParameters> quantization_params;

      // Default-constructed offset, used for the min/max fields that are not
      // written here.
      const flatbuffers::Offset<flatbuffers::Vector<float>> null_array;
      flatbuffers::Offset<flatbuffers::Vector<float>> scale_array;
      flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point_array;

      if (tensor->quantization.type == kTfLiteAffineQuantization) {
        // NOTE(review): a nonzero legacy per-tensor scale is taken to mean
        // single-argument quantization; otherwise the per-channel params
        // struct is consulted -- confirm this matches how the interpreter
        // fills tensor->params vs tensor->quantization.params.
        if (tensor->params.scale != 0.f) {
          // Quantization with a single argument array.
          scale_array = fbb->CreateVector<float>({tensor->params.scale});
          zero_point_array =
              fbb->CreateVector<int64_t>({tensor->params.zero_point});
          quantization_params = CreateQuantizationParameters(
              *fbb, null_array, null_array, scale_array, zero_point_array);
        } else {  // Multi channel quantization.
          const TfLiteAffineQuantization* params =
              reinterpret_cast<TfLiteAffineQuantization*>(
                  tensor->quantization.params);
          const size_t num_scales = params->scale->size;

          std::vector<float> scale_vector(params->scale->data,
                                          params->scale->data + num_scales);
          std::vector<int64_t> zero_point_vector(
              params->zero_point->data, params->zero_point->data + num_scales);
          scale_array = fbb->CreateVector<float>(scale_vector);
          zero_point_array = fbb->CreateVector<int64_t>(zero_point_vector);
          quantization_params = CreateQuantizationParameters(
              *fbb, null_array, null_array, scale_array, zero_point_array,
              QuantizationDetails_NONE, 0, params->quantized_dimension);
        }
      }

      // Shape
      TfLiteIntArrayView shape_view(tensor->dims);
      std::vector<int> shape =
          std::vector<int>(shape_view.begin(), shape_view.end());

      tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
                                     type, buffer_index,
                                     fbb->CreateString(tensor->name),
                                     quantization_params, tensor->is_variable));
    }
  }
  return fbb->template CreateVector<flatbuffers::Offset<Tensor>>(tensors);
}
247
248 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffers(flatbuffers::FlatBufferBuilder * fbb)249 SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
250 return ExportBuffersImpl(fbb, buffers_);
251 }
252
253 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTable(flatbuffers::FlatBufferBuilder * fbb)254 SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
255 return CreateOpCodeTableImpl(fbb, opcodes_);
256 }
257
258 template <class T>
RemapTensorIndicesToWritten(const T & input)259 std::vector<int> SubgraphWriter::RemapTensorIndicesToWritten(const T& input) {
260 std::vector<int> output;
261 output.reserve(input.size());
262 for (int x : input) {
263 // Special value representing an optional tensor which is not present.
264 if (x == -1) {
265 output.push_back(x);
266 continue;
267 }
268 if (tensor_to_written_tensor_[x] != -1) {
269 output.push_back(tensor_to_written_tensor_[x]);
270 }
271 }
272 return output;
273 }
274
// Serializes the subgraph into a freshly allocated model buffer returned via
// `out`/`size`. Returns kTfLiteError if either output pointer is null.
TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
                                       size_t* size) {
  if (!out || !size) return kTfLiteError;
  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
  std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
  // Order matters: PopulateAndGetOffset (via ExportTensors) appends constant
  // tensor data to buffers_ and opcodes to opcodes_, so it must run before
  // ExportBuffers/CreateOpCodeTable below serialize those lists.
  subgraphs_as_vector.push_back(PopulateAndGetOffset(&builder));

  flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
      buffers = ExportBuffers(&builder);

  auto description = builder.CreateString("Exported from Subgraph.");

  auto op_codes = CreateOpCodeTable(&builder);
  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
                           builder.CreateVector(subgraphs_as_vector),
                           description, buffers);
  ::tflite::FinishModelBuffer(builder, model);
  // Copy the finished flatbuffer out of the builder into a caller-owned
  // heap allocation.
  const uint8_t* buffer = builder.GetBufferPointer();
  *size = builder.GetSize();
  (*out).reset(new uint8_t[*size]);
  memcpy(out->get(), buffer, *size);
  return kTfLiteOk;
}
298
PopulateAndGetOffset(flatbuffers::FlatBufferBuilder * builder)299 flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
300 flatbuffers::FlatBufferBuilder* builder) {
301 auto tensors = ExportTensors(builder);
302 std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
303 std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
304 auto inputs = ExportVector<int32_t>(builder, written_inputs);
305 auto outputs = ExportVector<int32_t>(builder, written_outputs);
306
307 auto ops = ExportOperators(builder);
308 return CreateSubGraph(*builder, tensors, inputs, outputs, ops, /* name */ 0);
309 }
310
Write(const std::string & filename)311 TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
312 std::unique_ptr<uint8_t[]> buffer;
313 size_t size;
314 TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
315 return WriteImpl(filename, buffer.get(), size);
316 }
317
RegisterCustomWriter(const std::string & custom_name,CustomWriter custom_writer)318 TfLiteStatus SubgraphWriter::RegisterCustomWriter(
319 const std::string& custom_name, CustomWriter custom_writer) {
320 if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
321 return kTfLiteError;
322 }
323 custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
324 return kTfLiteOk;
325 }
326
CheckInputOutput(const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & execution_plan)327 TfLiteStatus SubgraphWriter::CheckInputOutput(
328 const std::vector<int>& inputs, const std::vector<int>& outputs,
329 const std::vector<int>& execution_plan) {
330 absl::flat_hash_set<int> known_tensors(inputs.begin(), inputs.end());
331 known_tensors.insert(subgraph_->variables().begin(),
332 subgraph_->variables().end());
333 // Scan execution plan and confirm input tensors are known before each node
334 // executes. Then append output tensors to known tensors.
335 for (int op_index : execution_plan) {
336 const auto* node_and_registration =
337 subgraph_->node_and_registration(op_index);
338 const TfLiteNode& node = node_and_registration->first;
339 for (int tensor_index : TfLiteIntArrayView(node.inputs)) {
340 if (tensor_index < 0) {
341 // Skip if optional input not present.
342 if (tensor_index == kTfLiteOptionalTensor) {
343 continue;
344 } else {
345 return kTfLiteError;
346 }
347 }
348 if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
349 // Skip constant tensors.
350 if (tensor->allocation_type == kTfLiteMmapRo) {
351 continue;
352 }
353 }
354
355 if (known_tensors.find(tensor_index) == known_tensors.end()) {
356 subgraph_->context()->ReportError(
357 subgraph_->context(),
358 "Node (%d) uses an input (%d) that is not provided.", op_index,
359 tensor_index);
360 return kTfLiteError;
361 }
362 }
363 TfLiteIntArrayView outputs(node.outputs);
364 known_tensors.insert(outputs.begin(), outputs.end());
365 }
366
367 // Check if outputs are known tensors or constants.
368 for (int tensor_index : outputs) {
369 if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
370 // Skip constant tensors.
371 if (tensor->allocation_type == kTfLiteMmapRo) {
372 continue;
373 }
374 }
375
376 if (known_tensors.find(tensor_index) == known_tensors.end()) {
377 subgraph_->context()->ReportError(
378 subgraph_->context(),
379 "Output (%d) is not produced by the execution plan.", tensor_index);
380 return kTfLiteError;
381 }
382 }
383 return kTfLiteOk;
384 }
385
SetCustomInputOutput(const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & execution_plan)386 TfLiteStatus SubgraphWriter::SetCustomInputOutput(
387 const std::vector<int>& inputs, const std::vector<int>& outputs,
388 const std::vector<int>& execution_plan) {
389 TF_LITE_ENSURE_STATUS(CheckInputOutput(inputs, outputs, execution_plan));
390 inputs_ = inputs;
391 outputs_ = outputs;
392 execution_plan_ = execution_plan;
393 return kTfLiteOk;
394 }
395
396 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffers(flatbuffers::FlatBufferBuilder * fbb)397 ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
398 return ExportBuffersImpl(fbb, &buffers_);
399 }
400
401 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTable(flatbuffers::FlatBufferBuilder * fbb)402 ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
403 return CreateOpCodeTableImpl(fbb, &opcodes_);
404 }
405
// Serializes every subgraph of the interpreter into a single freshly
// allocated model buffer returned via `out`/`size`. Returns kTfLiteError if
// either output pointer is null.
TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
                                    size_t* size) {
  if (!out || !size) return kTfLiteError;
  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);

  std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
  // Each per-subgraph writer shares this ModelWriter's buffers_, opcodes_,
  // and builtin_op_to_opcode_ so buffers and opcodes are deduplicated across
  // subgraphs. This loop must run before ExportBuffers/CreateOpCodeTable
  // below, which serialize those accumulated lists.
  for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
    SubgraphWriter writer(interpreter_->subgraph(i), &buffers_, &opcodes_,
                          &builtin_op_to_opcode_);
    subgraphs_as_vector.push_back(writer.PopulateAndGetOffset(&builder));
  }

  flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
      buffers = ExportBuffers(&builder);

  auto description = builder.CreateString("Exported from Subgraph.");

  auto op_codes = CreateOpCodeTable(&builder);
  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
                           builder.CreateVector(subgraphs_as_vector),
                           description, buffers);
  ::tflite::FinishModelBuffer(builder, model);
  // Copy the finished flatbuffer out of the builder into a caller-owned
  // heap allocation.
  const uint8_t* buffer = builder.GetBufferPointer();
  *size = builder.GetSize();
  (*out).reset(new uint8_t[*size]);
  memcpy(out->get(), buffer, *size);
  return kTfLiteOk;
}
434
Write(const std::string & filename)435 TfLiteStatus ModelWriter::Write(const std::string& filename) {
436 std::unique_ptr<uint8_t[]> buffer;
437 size_t size;
438 TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
439 return WriteImpl(filename, buffer.get(), size);
440 }
441
442 } // namespace tflite
443