/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/writer/writer_lib.h"
#include <cstdlib>
#include <cstring>
#include <unordered_map>
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/experimental/writer/enum_mapping.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/schema/reflection/schema_generated.h"
#include "tensorflow/lite/version.h"

namespace tflite {
template <class T>
using Offset = flatbuffers::Offset<T>;
template <class T>
using Vector = flatbuffers::Vector<T>;
using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;

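// Builds the flatbuffer options union for a builtin operator. The switch body
// is generated from the schema (option_writer_generated.h); ops with no
// options fall through to BuiltinOptions_NONE.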
std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
    FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
  switch (op) {
#include "tensorflow/lite/experimental/writer/option_writer_generated.h"
  }
  return std::make_pair(BuiltinOptions_NONE, Offset<void>());
}

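// Copies an arbitrary integer container (e.g. a TfLiteIntArrayView or a
// std::vector<int>) into a flatbuffer vector of the requested output type.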
template <class T_OUTPUT, class T_INPUT>
Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
                                                         const T_INPUT& v) {
  std::vector<T_OUTPUT> inputs(v.begin(), v.end());
  return fbb->template CreateVector<T_OUTPUT>(inputs);
}

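// Serializes every node in the interpreter's execution plan as a flatbuffer
// Operator. The first pass assigns an opcode-table index to each node; the
// second pass writes the operators themselves, including their builtin or
// custom options.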
Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
    FlatBufferBuilder* fbb) {
  std::vector<Offset<Operator>> operators;

  // First pass: assign an opcode-table index to every node in the plan.
  std::vector<int> operator_to_opcode;
  // TODO(aselle): Augment this once we put execution plan in schema.
  operator_to_opcode.resize(interpreter_->nodes_size(), -1);
  for (int op_index : interpreter_->execution_plan()) {
    const auto* node_and_registration =
        interpreter_->node_and_registration(op_index);
    const TfLiteRegistration* registration = &node_and_registration->second;
    if (!registration->custom_name) {
      operator_to_opcode[op_index] =
          GetOpCodeForBuiltin(registration->builtin_code);
    } else {
      operator_to_opcode[op_index] =
          GetOpCodeForCustom(registration->custom_name);
    }
  }
  // Second pass: serialize the operators.
  for (int op_index : interpreter_->execution_plan()) {
    const auto* node_and_registration =
        interpreter_->node_and_registration(op_index);
    const TfLiteNode& node = node_and_registration->first;
    const TfLiteRegistration& registration = node_and_registration->second;
    Offset<void> builtin_options;
    BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
    // Custom data.
    // TODO(aselle): The custom options format is not known by default. Assume
    // flexbuffers for now.
    auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
    Offset<Vector<uint8_t>> custom_options = 0;

    if (!registration.custom_name) {
      // Builtin op: serialize its options union.
      auto builtin_options_and_type = CreateBuiltinUnion(
          fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
          node.builtin_data);
      builtin_options = builtin_options_and_type.second;
      builtin_options_type = builtin_options_and_type.first;
    } else {
      auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
      if (custom_writer != custom_op_to_writer_.end() &&
          custom_writer->second) {
        // Delegate to the registered custom writer if one exists.
        custom_writer->second(fbb, interpreter_, op_index, &custom_options,
                              &custom_options_format);
      } else {
        // Otherwise, copy the node's custom initial data verbatim.
        custom_options = fbb->CreateVector(
            reinterpret_cast<const uint8_t*>(node.custom_initial_data),
            node.custom_initial_data_size);
      }
    }

    int opcode_index = operator_to_opcode[op_index];
    std::vector<int> written_inputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
    std::vector<int> written_outputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
    auto inputs = ExportVector<int32_t>(fbb, written_inputs);
    auto outputs = ExportVector<int32_t>(fbb, written_outputs);
    operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
                                       builtin_options_type, builtin_options,
                                       custom_options, custom_options_format));
  }

  return fbb->template CreateVector<Offset<Operator>>(operators);
}

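// Serializes the interpreter's tensors. Temporary tensors and tensors listed
// in unused_tensors_ are skipped; the surviving tensors are renumbered
// densely, and that mapping is recorded in tensor_to_written_tensor_ so that
// operator inputs and outputs can be remapped later.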
Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
    FlatBufferBuilder* fbb) {
  // Initialized to -1; a value of -1 means the tensor will not be exported.
  tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);

  std::vector<Offset<Tensor>> tensors;

  // Make a map from tensor index to whether the tensor is a temporary.
  std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
  for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
    const auto* node_and_registration =
        interpreter_->node_and_registration(op_index);
    for (auto tensor_index :
         TfLiteIntArrayView(node_and_registration->first.temporaries))
      tensor_is_temporary[tensor_index] = true;
  }

  // Now remap all used tensor indices onto a dense range.
  int curr_output_index = 0;
  for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
       tensor_index++) {
    // Temporary tensors and unused tensors will not be written.
    if (!tensor_is_temporary[tensor_index] &&
        unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
      tensor_to_written_tensor_[tensor_index] = curr_output_index++;
    }
  }

  for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
       ++tensor_index) {
    // Tensor not exported.
    if (tensor_to_written_tensor_[tensor_index] == -1) continue;

    if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
      // We only need to convert non-temporaries.
      if (tensor->allocation_type != kTfLiteArenaRw &&
          tensor->allocation_type != kTfLiteMmapRo &&
          tensor->allocation_type != kTfLiteArenaRwPersistent)
        continue;
      // Allocate a buffer index; 0 is the null (empty) buffer.
      int buffer_index = 0;
      if (tensor->allocation_type == kTfLiteMmapRo) {
        buffer_index = buffers_.size();
        buffers_.push_back(std::make_pair(
            reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
      }
      // Primitive type.
      TensorType type = TfLiteTypeToSchemaType(tensor->type);
      // Handle quantization.
      const Offset<Vector<float>> null_array;
      Offset<Vector<float>> scale_array;
      Offset<Vector<int64_t>> zero_point_array;
      if (tensor->params.scale != 0.f) {
        // We have quantization; make a single-argument array (multi-channel
        // quantization needs updating here).
        scale_array = fbb->CreateVector<float>({tensor->params.scale});
        zero_point_array =
            fbb->CreateVector<int64_t>({tensor->params.zero_point});
      }
      Offset<QuantizationParameters> quantization_params =
          CreateQuantizationParameters(*fbb, null_array, null_array,
                                       scale_array, zero_point_array);
      // Shape.
      TfLiteIntArrayView shape_view(tensor->dims);
      std::vector<int> shape =
          std::vector<int>(shape_view.begin(), shape_view.end());

      tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
                                     type, buffer_index,
                                     fbb->CreateString(tensor->name),
                                     quantization_params, tensor->is_variable));
    }
  }
  return fbb->template CreateVector<Offset<Tensor>>(tensors);
}

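// Serializes the raw data buffers collected while exporting tensors (one per
// kTfLiteMmapRo tensor) into the model's buffer table.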
Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
    FlatBufferBuilder* fbb) {
  std::vector<Offset<Buffer>> buffer_vector;
  for (auto buffer : buffers_) {
    auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
    buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
  }
  return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
}

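// Serializes the opcode table accumulated in opcodes_, emitting a custom name
// only for custom ops.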
Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
    FlatBufferBuilder* fbb) {
  std::vector<Offset<OperatorCode>> codes;
  for (auto it : opcodes_) {
    const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
    codes.push_back(CreateOperatorCodeDirect(
        *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
  }
  return fbb->template CreateVector<Offset<OperatorCode>>(codes);
}

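// Translates interpreter tensor indices into written-tensor indices using the
// mapping built by ExportTensors(). The optional-tensor sentinel (-1) is
// passed through unchanged; indices of tensors that were not written are
// dropped.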
template <class T>
std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
    const T& input) {
  std::vector<int> output;
  output.reserve(input.size());
  for (int x : input) {
    // Special value representing an optional tensor which is not present.
    if (x == -1) {
      output.push_back(x);
      continue;
    }
    if (tensor_to_written_tensor_[x] != -1) {
      output.push_back(tensor_to_written_tensor_[x]);
    }
  }
  return output;
}

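// Serializes the interpreter into a freshly allocated flatbuffer model and
// hands ownership of the bytes to the caller via |out| and |size|.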
TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
                                          size_t* size) {
  if (!out || !size) return kTfLiteError;
  FlatBufferBuilder builder(/*initial_size=*/10240);

  std::vector<Offset<SubGraph>> subgraphs_as_vector;
  {  // Subgraph-specific data.
    auto tensors = ExportTensors(&builder);
    std::vector<int> written_inputs =
        RemapTensorIndicesToWritten(interpreter_->inputs());
    std::vector<int> written_outputs =
        RemapTensorIndicesToWritten(interpreter_->outputs());
    auto inputs = ExportVector<int32_t>(&builder, written_inputs);
    auto outputs = ExportVector<int32_t>(&builder, written_outputs);

    auto ops = ExportOperators(&builder);
    subgraphs_as_vector.push_back(
        CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
  }
  Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);

  auto description = builder.CreateString("Exported from Interpreter.");

  auto op_codes = CreateOpCodeTable(&builder);
  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
                           builder.CreateVector(subgraphs_as_vector),
                           description, buffers);
  ::tflite::FinishModelBuffer(builder, model);
  // Copy the finished buffer out of the builder.
  const uint8_t* buffer = builder.GetBufferPointer();
  *size = builder.GetSize();
  out->reset(new uint8_t[*size]);
  memcpy(out->get(), buffer, *size);
  return kTfLiteOk;
}

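// Serializes the interpreter and writes the resulting model to |filename|.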
TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
  std::unique_ptr<uint8_t[]> buffer;
  size_t size;
  TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));

  FILE* fp = fopen(filename.c_str(), "wb");
  if (!fp) return kTfLiteError;

  if (fwrite(buffer.get(), 1, size, fp) != size) {
    fclose(fp);  // Avoid leaking the handle on a short write.
    return kTfLiteError;
  }
  if (fclose(fp)) return kTfLiteError;

  return kTfLiteOk;
}

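// Registers a serialization callback for a custom op. Fails if a writer is
// already registered under the same name.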
TfLiteStatus InterpreterWriter::RegisterCustomWriter(
    const std::string& custom_name, CustomWriter custom_writer) {
  if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
    return kTfLiteError;
  }
  custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
  return kTfLiteOk;
}

}  // namespace tflite
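
// A minimal usage sketch (illustrative, not part of the library): assuming an
// Interpreter has already been built and that InterpreterWriter is constructed
// from a raw Interpreter pointer, a model can be serialized back to disk:
//
//   std::unique_ptr<tflite::Interpreter> interpreter = BuildMyInterpreter();
//   tflite::InterpreterWriter writer(interpreter.get());
//   if (writer.Write("/tmp/model.tflite") != kTfLiteOk) {
//     // Handle the serialization failure.
//   }
//
// BuildMyInterpreter() is a hypothetical helper standing in for however the
// interpreter is constructed (e.g. via tflite::InterpreterBuilder).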