• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/serialization/writer_lib.h"
16 
17 #include <cstdlib>
18 #include <cstring>
19 #include <unordered_map>
20 #include <unordered_set>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/lite/builtin_op_data.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/context_util.h"
26 #include "tensorflow/lite/core/subgraph.h"
27 #include "tensorflow/lite/schema/reflection/schema_generated.h"
28 #include "tensorflow/lite/schema/schema_conversion_utils.h"
29 #include "tensorflow/lite/tools/serialization/enum_mapping.h"
30 #include "tensorflow/lite/version.h"
31 
32 namespace tflite {
33 namespace {
34 
35 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder * fbb,std::vector<OpCode> * opcodes)36 CreateOpCodeTableImpl(flatbuffers::FlatBufferBuilder* fbb,
37                       std::vector<OpCode>* opcodes) {
38   std::vector<flatbuffers::Offset<OperatorCode>> codes;
39   for (const auto& it : *opcodes) {
40     const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
41     codes.push_back(CreateOperatorCodeDirect(
42         *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
43   }
44   return fbb->template CreateVector<flatbuffers::Offset<OperatorCode>>(codes);
45 }
46 
47 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffersImpl(flatbuffers::FlatBufferBuilder * fbb,std::vector<std::pair<const uint8_t *,size_t>> * buffers)48 ExportBuffersImpl(flatbuffers::FlatBufferBuilder* fbb,
49                   std::vector<std::pair<const uint8_t*, size_t>>* buffers) {
50   std::vector<flatbuffers::Offset<Buffer>> buffer_vector;
51   for (auto buffer : *buffers) {
52     auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
53     buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
54   }
55   return fbb->template CreateVector<flatbuffers::Offset<Buffer>>(buffer_vector);
56 }
57 
WriteImpl(const std::string & filename,void * data,size_t size)58 TfLiteStatus WriteImpl(const std::string& filename, void* data, size_t size) {
59   FILE* fp = fopen(filename.c_str(), "wb");
60   if (!fp) return kTfLiteError;
61 
62   const int result_size = fwrite(data, 1, size, fp);
63   fclose(fp);
64   if (result_size != size) return kTfLiteError;
65 
66   return kTfLiteOk;
67 }
68 
CreateBuiltinUnion(flatbuffers::FlatBufferBuilder * fbb,enum BuiltinOperator op,void * builtin_op_data,const TfLiteNode & node)69 std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
70     flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
71     void* builtin_op_data, const TfLiteNode& node) {
72   switch (op) {
73 #include "tensorflow/lite/tools/serialization/option_writer_generated.h"
74   }
75   return std::make_pair(BuiltinOptions_NONE, flatbuffers::Offset<void>());
76 }
77 
78 }  // namespace
79 
80 template <class T_OUTPUT, class T_INPUT>
ExportVector(flatbuffers::FlatBufferBuilder * fbb,const T_INPUT & v)81 flatbuffers::Offset<flatbuffers::Vector<T_OUTPUT>> SubgraphWriter::ExportVector(
82     flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v) {
83   std::vector<T_OUTPUT> inputs(v.begin(), v.end());
84   return fbb->template CreateVector<T_OUTPUT>(inputs);
85 }
86 
87 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Operator>>>
ExportOperators(flatbuffers::FlatBufferBuilder * fbb)88 SubgraphWriter::ExportOperators(flatbuffers::FlatBufferBuilder* fbb) {
89   std::vector<flatbuffers::Offset<Operator>> operators;
90 
91   std::vector<int> operator_to_opcode;
92   // TODO(aselle): Augment this once we put execution plan in schema.
93   operator_to_opcode.resize(subgraph_->nodes_size(), -1);
94   for (int op_index : execution_plan_) {
95     const auto* node_and_registration =
96         subgraph_->node_and_registration(op_index);
97     const TfLiteRegistration* registration = &node_and_registration->second;
98     if (!registration->custom_name) {
99       operator_to_opcode[op_index] =
100           GetOpCodeForBuiltin(registration->builtin_code);
101     } else {
102       operator_to_opcode[op_index] =
103           GetOpCodeForCustom(registration->custom_name);
104     }
105   }
106   // second pass serialize operators
107   for (int op_index : execution_plan_) {
108     const auto* node_and_registration =
109         subgraph_->node_and_registration(op_index);
110     const TfLiteNode& node = node_and_registration->first;
111     const TfLiteRegistration& registration = node_and_registration->second;
112     flatbuffers::Offset<void> builtin_options;
113     BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
114     // Custom data
115     // TODO(aselle): Custom options format is not known by default. Just assume
116     // for now.
117     auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
118     flatbuffers::Offset<flatbuffers::Vector<uint8_t>> custom_options = 0;
119 
120     if (!registration.custom_name) {
121       // builtin
122       auto builtin_options_and_type = CreateBuiltinUnion(
123           fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
124           node.builtin_data, node);
125       builtin_options = builtin_options_and_type.second;
126       builtin_options_type = builtin_options_and_type.first;
127     } else {
128       auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
129       if (custom_writer != custom_op_to_writer_.end() &&
130           custom_writer->second) {
131         // delegate to custom writer if it exists
132         custom_writer->second(fbb, subgraph_, op_index, &custom_options,
133                               &custom_options_format);
134       } else {
135         // use the custom data as fact
136         custom_options = fbb->CreateVector(
137             reinterpret_cast<const uint8_t*>(node.custom_initial_data),
138             node.custom_initial_data_size);
139       }
140     }
141 
142     int opcode_index = operator_to_opcode[op_index];
143     std::vector<int> written_inputs =
144         RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
145     std::vector<int> written_outputs =
146         RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
147     auto inputs = ExportVector<int32_t>(fbb, written_inputs);
148     auto outputs = ExportVector<int32_t>(fbb, written_outputs);
149     operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
150                                        builtin_options_type, builtin_options,
151                                        custom_options, custom_options_format));
152   }
153 
154   return fbb->template CreateVector<flatbuffers::Offset<Operator>>(operators);
155 }
156 
157 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Tensor>>>
ExportTensors(flatbuffers::FlatBufferBuilder * fbb)158 SubgraphWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
159   // Initialized to -1.
160   // A value of -1 means this tensor will not be exported.
161   tensor_to_written_tensor_.resize(subgraph_->tensors_size(), -1);
162 
163   std::vector<flatbuffers::Offset<Tensor>> tensors;
164 
165   // Make a map from tensor index to whether the tensor is a temporary.
166   std::vector<bool> tensor_is_temporary(subgraph_->tensors_size(), false);
167   for (int op_index = 0; op_index < subgraph_->nodes_size(); ++op_index) {
168     const auto* node_and_registration =
169         subgraph_->node_and_registration(op_index);
170     for (auto tensor_index :
171          TfLiteIntArrayView(node_and_registration->first.temporaries))
172       tensor_is_temporary[tensor_index] = true;
173   }
174 
175   // Now we need to remap all used tensor indices
176   int curr_output_index = 0;
177   for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
178        tensor_index++) {
179     // Temporary tensors and unused tensors will not be written.
180     if (!tensor_is_temporary[tensor_index] &&
181         unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
182       tensor_to_written_tensor_[tensor_index] = curr_output_index++;
183     }
184   }
185 
186   for (int tensor_index = 0; tensor_index < subgraph_->tensors_size();
187        ++tensor_index) {
188     // Tensor not exported.
189     if (tensor_to_written_tensor_[tensor_index] == -1) continue;
190 
191     if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
192       // Allocate a buffer index
193       int buffer_index = 0;  // This is null
194       if (tensor->allocation_type == kTfLiteMmapRo) {
195         buffer_index = buffers_->size();
196         buffers_->push_back(std::make_pair(
197             reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
198       }
199       // Primitive type.
200       TensorType type = TfLiteTypeToSchemaType(tensor->type);
201       // Handle quantization
202       flatbuffers::Offset<QuantizationParameters> quantization_params;
203 
204       const flatbuffers::Offset<flatbuffers::Vector<float>> null_array;
205       flatbuffers::Offset<flatbuffers::Vector<float>> scale_array;
206       flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point_array;
207 
208       if (tensor->quantization.type == kTfLiteAffineQuantization) {
209         if (tensor->params.scale != 0.f) {
210           // Quantization with a single argument array.
211           scale_array = fbb->CreateVector<float>({tensor->params.scale});
212           zero_point_array =
213               fbb->CreateVector<int64_t>({tensor->params.zero_point});
214           quantization_params = CreateQuantizationParameters(
215               *fbb, null_array, null_array, scale_array, zero_point_array);
216         } else {  // Multi channel quantization.
217           const TfLiteAffineQuantization* params =
218               reinterpret_cast<TfLiteAffineQuantization*>(
219                   tensor->quantization.params);
220           const size_t num_scales = params->scale->size;
221 
222           std::vector<float> scale_vector(params->scale->data,
223                                           params->scale->data + num_scales);
224           std::vector<int64_t> zero_point_vector(
225               params->zero_point->data, params->zero_point->data + num_scales);
226           scale_array = fbb->CreateVector<float>(scale_vector);
227           zero_point_array = fbb->CreateVector<int64_t>(zero_point_vector);
228           quantization_params = CreateQuantizationParameters(
229               *fbb, null_array, null_array, scale_array, zero_point_array,
230               QuantizationDetails_NONE, 0, params->quantized_dimension);
231         }
232       }
233 
234       // Shape
235       TfLiteIntArrayView shape_view(tensor->dims);
236       std::vector<int> shape =
237           std::vector<int>(shape_view.begin(), shape_view.end());
238 
239       tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
240                                      type, buffer_index,
241                                      fbb->CreateString(tensor->name),
242                                      quantization_params, tensor->is_variable));
243     }
244   }
245   return fbb->template CreateVector<flatbuffers::Offset<Tensor>>(tensors);
246 }
247 
248 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffers(flatbuffers::FlatBufferBuilder * fbb)249 SubgraphWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
250   return ExportBuffersImpl(fbb, buffers_);
251 }
252 
253 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTable(flatbuffers::FlatBufferBuilder * fbb)254 SubgraphWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
255   return CreateOpCodeTableImpl(fbb, opcodes_);
256 }
257 
258 template <class T>
RemapTensorIndicesToWritten(const T & input)259 std::vector<int> SubgraphWriter::RemapTensorIndicesToWritten(const T& input) {
260   std::vector<int> output;
261   output.reserve(input.size());
262   for (int x : input) {
263     // Special value representing an optional tensor which is not present.
264     if (x == -1) {
265       output.push_back(x);
266       continue;
267     }
268     if (tensor_to_written_tensor_[x] != -1) {
269       output.push_back(tensor_to_written_tensor_[x]);
270     }
271   }
272   return output;
273 }
274 
GetBuffer(std::unique_ptr<uint8_t[]> * out,size_t * size)275 TfLiteStatus SubgraphWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
276                                        size_t* size) {
277   if (!out || !size) return kTfLiteError;
278   flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
279   std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
280   subgraphs_as_vector.push_back(PopulateAndGetOffset(&builder));
281 
282   flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
283       buffers = ExportBuffers(&builder);
284 
285   auto description = builder.CreateString("Exported from Subgraph.");
286 
287   auto op_codes = CreateOpCodeTable(&builder);
288   auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
289                            builder.CreateVector(subgraphs_as_vector),
290                            description, buffers);
291   ::tflite::FinishModelBuffer(builder, model);
292   const uint8_t* buffer = builder.GetBufferPointer();
293   *size = builder.GetSize();
294   (*out).reset(new uint8_t[*size]);
295   memcpy(out->get(), buffer, *size);
296   return kTfLiteOk;
297 }
298 
PopulateAndGetOffset(flatbuffers::FlatBufferBuilder * builder)299 flatbuffers::Offset<SubGraph> SubgraphWriter::PopulateAndGetOffset(
300     flatbuffers::FlatBufferBuilder* builder) {
301   auto tensors = ExportTensors(builder);
302   std::vector<int> written_inputs = RemapTensorIndicesToWritten(inputs_);
303   std::vector<int> written_outputs = RemapTensorIndicesToWritten(outputs_);
304   auto inputs = ExportVector<int32_t>(builder, written_inputs);
305   auto outputs = ExportVector<int32_t>(builder, written_outputs);
306 
307   auto ops = ExportOperators(builder);
308   return CreateSubGraph(*builder, tensors, inputs, outputs, ops, /* name */ 0);
309 }
310 
Write(const std::string & filename)311 TfLiteStatus SubgraphWriter::Write(const std::string& filename) {
312   std::unique_ptr<uint8_t[]> buffer;
313   size_t size;
314   TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
315   return WriteImpl(filename, buffer.get(), size);
316 }
317 
RegisterCustomWriter(const std::string & custom_name,CustomWriter custom_writer)318 TfLiteStatus SubgraphWriter::RegisterCustomWriter(
319     const std::string& custom_name, CustomWriter custom_writer) {
320   if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
321     return kTfLiteError;
322   }
323   custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
324   return kTfLiteOk;
325 }
326 
CheckInputOutput(const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & execution_plan)327 TfLiteStatus SubgraphWriter::CheckInputOutput(
328     const std::vector<int>& inputs, const std::vector<int>& outputs,
329     const std::vector<int>& execution_plan) {
330   absl::flat_hash_set<int> known_tensors(inputs.begin(), inputs.end());
331   known_tensors.insert(subgraph_->variables().begin(),
332                        subgraph_->variables().end());
333   // Scan execution plan and confirm input tensors are known before each node
334   // executes. Then append output tensors to known tensors.
335   for (int op_index : execution_plan) {
336     const auto* node_and_registration =
337         subgraph_->node_and_registration(op_index);
338     const TfLiteNode& node = node_and_registration->first;
339     for (int tensor_index : TfLiteIntArrayView(node.inputs)) {
340       if (tensor_index < 0) {
341         // Skip if optional input not present.
342         if (tensor_index == kTfLiteOptionalTensor) {
343           continue;
344         } else {
345           return kTfLiteError;
346         }
347       }
348       if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
349         // Skip constant tensors.
350         if (tensor->allocation_type == kTfLiteMmapRo) {
351           continue;
352         }
353       }
354 
355       if (known_tensors.find(tensor_index) == known_tensors.end()) {
356         subgraph_->context()->ReportError(
357             subgraph_->context(),
358             "Node (%d) uses an input (%d) that is not provided.", op_index,
359             tensor_index);
360         return kTfLiteError;
361       }
362     }
363     TfLiteIntArrayView outputs(node.outputs);
364     known_tensors.insert(outputs.begin(), outputs.end());
365   }
366 
367   // Check if outputs are known tensors or constants.
368   for (int tensor_index : outputs) {
369     if (TfLiteTensor* tensor = subgraph_->tensor(tensor_index)) {
370       // Skip constant tensors.
371       if (tensor->allocation_type == kTfLiteMmapRo) {
372         continue;
373       }
374     }
375 
376     if (known_tensors.find(tensor_index) == known_tensors.end()) {
377       subgraph_->context()->ReportError(
378           subgraph_->context(),
379           "Output (%d) is not produced by the execution plan.", tensor_index);
380       return kTfLiteError;
381     }
382   }
383   return kTfLiteOk;
384 }
385 
SetCustomInputOutput(const std::vector<int> & inputs,const std::vector<int> & outputs,const std::vector<int> & execution_plan)386 TfLiteStatus SubgraphWriter::SetCustomInputOutput(
387     const std::vector<int>& inputs, const std::vector<int>& outputs,
388     const std::vector<int>& execution_plan) {
389   TF_LITE_ENSURE_STATUS(CheckInputOutput(inputs, outputs, execution_plan));
390   inputs_ = inputs;
391   outputs_ = outputs;
392   execution_plan_ = execution_plan;
393   return kTfLiteOk;
394 }
395 
396 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
ExportBuffers(flatbuffers::FlatBufferBuilder * fbb)397 ModelWriter::ExportBuffers(flatbuffers::FlatBufferBuilder* fbb) {
398   return ExportBuffersImpl(fbb, &buffers_);
399 }
400 
401 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>>
CreateOpCodeTable(flatbuffers::FlatBufferBuilder * fbb)402 ModelWriter::CreateOpCodeTable(flatbuffers::FlatBufferBuilder* fbb) {
403   return CreateOpCodeTableImpl(fbb, &opcodes_);
404 }
405 
GetBuffer(std::unique_ptr<uint8_t[]> * out,size_t * size)406 TfLiteStatus ModelWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
407                                     size_t* size) {
408   if (!out || !size) return kTfLiteError;
409   flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
410 
411   std::vector<flatbuffers::Offset<SubGraph>> subgraphs_as_vector;
412   for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
413     SubgraphWriter writer(interpreter_->subgraph(i), &buffers_, &opcodes_,
414                           &builtin_op_to_opcode_);
415     subgraphs_as_vector.push_back(writer.PopulateAndGetOffset(&builder));
416   }
417 
418   flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
419       buffers = ExportBuffers(&builder);
420 
421   auto description = builder.CreateString("Exported from Subgraph.");
422 
423   auto op_codes = CreateOpCodeTable(&builder);
424   auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
425                            builder.CreateVector(subgraphs_as_vector),
426                            description, buffers);
427   ::tflite::FinishModelBuffer(builder, model);
428   const uint8_t* buffer = builder.GetBufferPointer();
429   *size = builder.GetSize();
430   (*out).reset(new uint8_t[*size]);
431   memcpy(out->get(), buffer, *size);
432   return kTfLiteOk;
433 }
434 
Write(const std::string & filename)435 TfLiteStatus ModelWriter::Write(const std::string& filename) {
436   std::unique_ptr<uint8_t[]> buffer;
437   size_t size;
438   TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
439   return WriteImpl(filename, buffer.get(), size);
440 }
441 
442 }  // namespace tflite
443