• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/import.h"
16 
17 #include "flatbuffers/flexbuffers.h"
18 #include "tensorflow/lite/model.h"
19 #include "tensorflow/lite/schema/schema_generated.h"
20 #include "tensorflow/lite/toco/tflite/operator.h"
21 #include "tensorflow/lite/toco/tflite/types.h"
22 #include "tensorflow/lite/toco/tooling_util.h"
23 #include "tensorflow/lite/tools/verifier.h"
24 
25 namespace toco {
26 
27 namespace tflite {
28 
29 namespace details {
LoadTensorsTable(const::tflite::Model & input_model,TensorsTable * tensors_table)30 void LoadTensorsTable(const ::tflite::Model& input_model,
31                       TensorsTable* tensors_table) {
32   // TODO(aselle): add support to toco for multiple subgraphs.
33   auto tensors = (*input_model.subgraphs())[0]->tensors();
34   if (!tensors) return;
35   for (const auto* tensor : *tensors) {
36     tensors_table->push_back(tensor->name()->c_str());
37   }
38 }
39 
LoadOperatorsTable(const::tflite::Model & input_model,OperatorsTable * operators_table)40 void LoadOperatorsTable(const ::tflite::Model& input_model,
41                         OperatorsTable* operators_table) {
42   auto opcodes = input_model.operator_codes();
43   if (!opcodes) return;
44   for (const auto* opcode : *opcodes) {
45     if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
46       operators_table->push_back(
47           EnumNameBuiltinOperator(opcode->builtin_code()));
48     } else {
49       operators_table->push_back(opcode->custom_code()->c_str());
50     }
51   }
52 }
53 }  // namespace details
54 
ImportTensors(const::tflite::Model & input_model,Model * model)55 void ImportTensors(const ::tflite::Model& input_model, Model* model) {
56   auto tensors = (*input_model.subgraphs())[0]->tensors();
57   auto* buffers = input_model.buffers();
58   // auto tensors = input_model.tensors();
59   if (!tensors) return;
60   for (const auto* input_tensor : *tensors) {
61     Array& array = model->GetOrCreateArray(input_tensor->name()->c_str());
62     array.data_type = DataType::Deserialize(input_tensor->type());
63     int buffer_index = input_tensor->buffer();
64     auto* buffer = buffers->Get(buffer_index);
65     DataBuffer::Deserialize(*input_tensor, *buffer, &array);
66 
67     auto shape = input_tensor->shape();
68     if (shape) {
69       // If the shape is 0-dimensional, make sure to record it as such,
70       // as oppose to leaving the array without a shape.
71       array.mutable_shape()->mutable_dims()->clear();
72       for (uint32_t i = 0; i < shape->Length(); ++i) {
73         auto d = shape->Get(i);
74         array.mutable_shape()->mutable_dims()->push_back(d);
75       }
76     }
77 
78     auto quantization = input_tensor->quantization();
79     if (quantization) {
80       // Note that tf.mini only supports a single quantization parameters for
81       // the whole array.
82       if (quantization->min() && quantization->max()) {
83         CHECK_EQ(1, quantization->min()->Length());
84         CHECK_EQ(1, quantization->max()->Length());
85         MinMax& minmax = array.GetOrCreateMinMax();
86         minmax.min = quantization->min()->Get(0);
87         minmax.max = quantization->max()->Get(0);
88       }
89       if (quantization->scale() && quantization->zero_point()) {
90         CHECK_EQ(1, quantization->scale()->Length());
91         CHECK_EQ(1, quantization->zero_point()->Length());
92         QuantizationParams& q = array.GetOrCreateQuantizationParams();
93         q.scale = quantization->scale()->Get(0);
94         q.zero_point = quantization->zero_point()->Get(0);
95       }
96     }
97   }
98 }
99 
ImportOperators(const::tflite::Model & input_model,const std::map<string,std::unique_ptr<BaseOperator>> & ops_by_name,const details::TensorsTable & tensors_table,const details::OperatorsTable & operators_table,Model * model)100 void ImportOperators(
101     const ::tflite::Model& input_model,
102     const std::map<string, std::unique_ptr<BaseOperator>>& ops_by_name,
103     const details::TensorsTable& tensors_table,
104     const details::OperatorsTable& operators_table, Model* model) {
105   // TODO(aselle): add support for multiple subgraphs.
106   auto ops = (*input_model.subgraphs())[0]->operators();
107 
108   if (!ops) return;
109   for (const auto* input_op : *ops) {
110     uint32_t index = input_op->opcode_index();
111     if (index > operators_table.size()) {
112       LOG(FATAL) << "Index " << index << " must be between zero and "
113                  << operators_table.size();
114     }
115     string opname = operators_table.at(index);
116 
117     // Find and use the appropriate operator deserialization factory.
118     std::unique_ptr<Operator> new_op = nullptr;
119     if (ops_by_name.count(opname) == 0) {
120       string effective_opname = "TENSORFLOW_UNSUPPORTED";
121       if (ops_by_name.count(effective_opname) == 0) {
122         LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found.";
123       }
124       new_op = ops_by_name.at(effective_opname)
125                    ->Deserialize(input_op->builtin_options(),
126                                  input_op->custom_options());
127       if (new_op->type == OperatorType::kUnsupported) {
128         auto* unsupported_op =
129             static_cast<TensorFlowUnsupportedOperator*>(new_op.get());
130         unsupported_op->tensorflow_op = opname;
131         // TODO(b/109932940): Remove this when quantized is removed.
132         // For now, we assume all ops are quantized.
133         unsupported_op->quantized = true;
134       } else {
135         LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator";
136       }
137     } else {
138       new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(),
139                                                    input_op->custom_options());
140     }
141     model->operators.emplace_back(new_op.release());
142     auto* op = model->operators.back().get();
143 
144     // Make sure all the inputs and outputs are hooked up.
145     auto inputs = input_op->inputs();
146     for (uint32_t i = 0; i < inputs->Length(); i++) {
147       auto input_index = inputs->Get(i);
148       // input_index == -1 indicates optional tensor.
149       if (input_index != -1) {
150         const string& input_name = tensors_table.at(input_index);
151         op->inputs.push_back(input_name);
152       } else {
153         const string& tensor_name =
154             toco::AvailableArrayName(*model, "OptionalTensor");
155         model->CreateOptionalArray(tensor_name);
156         op->inputs.push_back(tensor_name);
157       }
158     }
159     auto outputs = input_op->outputs();
160     for (int i = 0; i < outputs->Length(); i++) {
161       auto output_index = outputs->Get(i);
162       const string& output_name = tensors_table.at(output_index);
163       op->outputs.push_back(output_name);
164     }
165   }
166 }
167 
ImportIOTensors(const ModelFlags & model_flags,const::tflite::Model & input_model,const details::TensorsTable & tensors_table,Model * model)168 void ImportIOTensors(const ModelFlags& model_flags,
169                      const ::tflite::Model& input_model,
170                      const details::TensorsTable& tensors_table, Model* model) {
171   // Import from the first subgraph if input arrays have not been specified.
172   if (model_flags.input_arrays().empty()) {
173     auto inputs = (*input_model.subgraphs())[0]->inputs();
174     if (inputs) {
175       for (int input : *inputs) {
176         const string& input_name = tensors_table.at(input);
177         model->flags.add_input_arrays()->set_name(input_name);
178       }
179     }
180   }
181 
182   // Import from the first subgraph if output arrays have not been specified.
183   if (model_flags.output_arrays().empty()) {
184     auto outputs = (*input_model.subgraphs())[0]->outputs();
185     if (outputs) {
186       for (int output : *outputs) {
187         const string& output_name = tensors_table.at(output);
188         model->flags.add_output_arrays(output_name);
189       }
190     }
191   }
192 }
193 
194 namespace {
Verify(const void * buf,size_t len)195 bool Verify(const void* buf, size_t len) {
196   ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
197   return ::tflite::VerifyModelBuffer(verifier);
198 }
199 }  // namespace
200 
Import(const ModelFlags & model_flags,const string & input_file_contents)201 std::unique_ptr<Model> Import(const ModelFlags& model_flags,
202                               const string& input_file_contents) {
203   ::tflite::AlwaysTrueResolver r;
204   if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(),
205                         r, ::tflite::DefaultErrorReporter())) {
206     LOG(FATAL) << "Invalid flatbuffer.";
207   }
208   const ::tflite::Model* input_model =
209       ::tflite::GetModel(input_file_contents.data());
210 
211   // Full list of all known operators.
212   const auto ops_by_name = BuildOperatorByNameMap();
213 
214   if (!input_model->subgraphs() || input_model->subgraphs()->size() != 1) {
215     LOG(FATAL) << "Number of subgraphs in tflite should be exactly 1.";
216   }
217   std::unique_ptr<Model> model;
218   model.reset(new Model);
219 
220   details::TensorsTable tensors_table;
221   details::LoadTensorsTable(*input_model, &tensors_table);
222 
223   details::OperatorsTable operators_table;
224   details::LoadOperatorsTable(*input_model, &operators_table);
225 
226   ImportTensors(*input_model, model.get());
227   ImportOperators(*input_model, ops_by_name, tensors_table, operators_table,
228                   model.get());
229 
230   ImportIOTensors(model_flags, *input_model, tensors_table, model.get());
231 
232   UndoWeightsShuffling(model.get());
233 
234   return model;
235 }
236 
237 }  // namespace tflite
238 
239 }  // namespace toco
240