• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <fcntl.h>
16 #include <stdint.h>
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include <sys/mman.h>
20 #include <sys/stat.h>
21 #include <sys/types.h>
22 #include <unistd.h>
23 
24 #include "tensorflow/contrib/lite/allocation.h"
25 #include "tensorflow/contrib/lite/builtin_op_data.h"
26 #include "tensorflow/contrib/lite/error_reporter.h"
27 #include "tensorflow/contrib/lite/model.h"
28 #include "tensorflow/contrib/lite/nnapi_delegate.h"
29 #include "tensorflow/contrib/lite/version.h"
30 
31 namespace tflite {
32 
// Name given to tensors whose flatbuffer entry carries no name. It points at a
// string literal (static storage), so it outlives every Interpreter built here.
const char *kEmptyTensorName = "";
34 
BuildFromFile(const char * filename,ErrorReporter * error_reporter)35 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
36     const char* filename, ErrorReporter* error_reporter) {
37   std::unique_ptr<FlatBufferModel> model;
38   model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
39                                   /*use_nnapi=*/true));
40   if (!model->initialized()) model.reset();
41   return model;
42 }
43 
BuildFromBuffer(const char * buffer,size_t buffer_size,ErrorReporter * error_reporter)44 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
45     const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
46   std::unique_ptr<FlatBufferModel> model;
47   model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
48   if (!model->initialized()) model.reset();
49   return model;
50 }
51 
BuildFromModel(const tflite::Model * model_spec,ErrorReporter * error_reporter)52 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
53     const tflite::Model* model_spec, ErrorReporter* error_reporter) {
54   std::unique_ptr<FlatBufferModel> model;
55   model.reset(new FlatBufferModel(model_spec, error_reporter));
56   if (!model->initialized()) model.reset();
57   return model;
58 }
59 
FlatBufferModel(const char * filename,bool mmap_file,ErrorReporter * error_reporter,bool use_nnapi)60 FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
61                                  ErrorReporter* error_reporter, bool use_nnapi)
62     : error_reporter_(error_reporter ? error_reporter
63                                      : DefaultErrorReporter()) {
64   if (mmap_file) {
65     if (use_nnapi && NNAPIExists())
66       allocation_ = new NNAPIAllocation(filename, error_reporter);
67     else
68       allocation_ = new MMAPAllocation(filename, error_reporter);
69   } else {
70     allocation_ = new FileCopyAllocation(filename, error_reporter);
71   }
72   if (!allocation_->valid() || !CheckModelIdentifier()) return;
73 
74   model_ = ::tflite::GetModel(allocation_->base());
75 }
76 
CheckModelIdentifier() const77 bool FlatBufferModel::CheckModelIdentifier() const {
78   if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
79     const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
80     error_reporter_->Report(
81         "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
82         ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
83     return false;
84   }
85   return true;
86 }
87 
FlatBufferModel(const char * ptr,size_t num_bytes,ErrorReporter * error_reporter)88 FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
89                                  ErrorReporter* error_reporter)
90     : error_reporter_(error_reporter ? error_reporter
91                                      : DefaultErrorReporter()) {
92   allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
93   if (!allocation_->valid()) return;
94 
95   model_ = ::tflite::GetModel(allocation_->base());
96 }
97 
FlatBufferModel(const Model * model,ErrorReporter * error_reporter)98 FlatBufferModel::FlatBufferModel(const Model* model,
99                                  ErrorReporter* error_reporter)
100     : error_reporter_(error_reporter ? error_reporter
101                                      : DefaultErrorReporter()) {
102   model_ = model;
103 }
104 
~FlatBufferModel()105 FlatBufferModel::~FlatBufferModel() { delete allocation_; }
106 
InterpreterBuilder(const FlatBufferModel & model,const OpResolver & op_resolver)107 InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
108                                        const OpResolver& op_resolver)
109     : model_(model.GetModel()),
110       op_resolver_(op_resolver),
111       error_reporter_(model.error_reporter()),
112       allocation_(model.allocation()) {}
113 
InterpreterBuilder(const::tflite::Model * model,const OpResolver & op_resolver,ErrorReporter * error_reporter)114 InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
115                                        const OpResolver& op_resolver,
116                                        ErrorReporter* error_reporter)
117     : model_(model),
118       op_resolver_(op_resolver),
119       error_reporter_(error_reporter ? error_reporter
120                                      : DefaultErrorReporter()) {}
121 
BuildLocalIndexToRegistrationMapping()122 TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
123   TfLiteStatus status = kTfLiteOk;
124   auto opcodes = model_->operator_codes();
125   for (const OperatorCode* opcode : *opcodes) {
126     TfLiteRegistration* registration = nullptr;
127 
128     if (opcode->builtin_code() != BuiltinOperator_CUSTOM) {
129       auto x = opcode->builtin_code();
130       flatbuffer_op_index_to_registration_types_.push_back(x);
131       registration = op_resolver_.FindOp(x);
132       if (registration == nullptr) {
133         error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
134                                 EnumNameBuiltinOperator(x));
135         status = kTfLiteError;
136       }
137     } else if (!opcode->custom_code()) {
138       error_reporter_->Report(
139           "Operator with CUSTOM builtin_code has no custom_code.\n");
140       status = kTfLiteError;
141     } else {
142       const char* name = opcode->custom_code()->c_str();
143       registration = op_resolver_.FindOp(name);
144       flatbuffer_op_index_to_registration_types_.push_back(
145           BuiltinOperator_CUSTOM);
146       if (registration == nullptr) {
147         error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
148         status = kTfLiteError;
149       }
150     }
151     flatbuffer_op_index_to_registration_.push_back(registration);
152   }
153   return status;
154 }
155 
156 namespace {
// Copies a flatbuffer integer vector (anything exposing Length() and Get(i))
// into a std::vector<int>. A null `flat_array` yields an empty vector.
// Optional flatbuffer vectors such as a tensor's shape or an operator's inputs
// come back as nullptr when absent, so null must not be dereferenced here.
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  if (flat_array == nullptr) return {};
  // Iterate with the vector's own (unsigned) length type to avoid a
  // signed/unsigned comparison.
  using LengthType = decltype(flat_array->Length());
  const LengthType length = flat_array->Length();
  std::vector<int> ret;
  ret.reserve(length);
  for (LengthType i = 0; i < length; ++i) {
    ret.push_back(flat_array->Get(i));
  }
  return ret;
}
165 
166 // Copies the contents from the flatbuffer int vector `flatbuffer` into the
167 // int array `buffer`. `flat_vector` and `buffer` represent the same
168 // configuration operation for a given operation.
FlatBufferIntVectorToArray(int max_size_of_buffer,const flatbuffers::Vector<int32_t> * flat_vector,int * buffer,ErrorReporter * error_reporter)169 void FlatBufferIntVectorToArray(int max_size_of_buffer,
170                                 const flatbuffers::Vector<int32_t>* flat_vector,
171                                 int* buffer, ErrorReporter* error_reporter) {
172   if (!flat_vector) {
173     error_reporter->Report("Input array not provided for operation.\n");
174   } else {
175     int num_dimensions = flat_vector->Length();
176     if (num_dimensions > max_size_of_buffer / sizeof(int)) {
177       error_reporter->Report(
178           "Found too many dimensions in the operation's input array.\n");
179     } else {
180       for (int i = 0; i < num_dimensions; ++i) {
181         buffer[i] = flat_vector->Get(i);
182       }
183     }
184   }
185 }
186 
// Allocates a zero-filled T with C calloc(). T must be a POD structure, so no
// constructor needs to run. C allocation is used because Interpreter's C
// extension part takes ownership of the result and releases it with free().
// Zero-filling means any field the caller does not set, for example because an
// op's builtin options are absent from the model, reads as 0 rather than as an
// indeterminate value. Returns nullptr if the allocation fails.
template <class T>
T* MallocPOD() {
  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
  return static_cast<T*>(calloc(1, sizeof(T)));
}
196 
197 // Parse the appropriate data out of the op.
198 //
199 // This handles builtin data explicitly as there are flatbuffer schemas.
200 //
201 // Returns memory that must be feed.
202 //
203 // TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program
204 // crashes if error reporter is called.
ParseOpData(const Operator * op,BuiltinOperator op_type,ErrorReporter * error_reporter)205 void* ParseOpData(const Operator* op, BuiltinOperator op_type,
206                   ErrorReporter* error_reporter) {
207   auto parse_padding = [](Padding padding) {
208     switch (padding) {
209       case Padding_SAME:
210         return kTfLitePaddingSame;
211       case Padding_VALID:
212         return kTfLitePaddingValid;
213     }
214     return kTfLitePaddingUnknown;
215   };
216   auto parse_activation = [](ActivationFunctionType activation) {
217     switch (activation) {
218       case ActivationFunctionType_NONE:
219         return kTfLiteActNone;
220       case ActivationFunctionType_RELU:
221         return kTfLiteActRelu;
222       case ActivationFunctionType_RELU_N1_TO_1:
223         return kTfLiteActRelu1;
224       case ActivationFunctionType_RELU6:
225         return kTfLiteActRelu6;
226       case ActivationFunctionType_TANH:
227         return kTfLiteActTanh;
228       case ActivationFunctionType_SIGN_BIT:
229         return kTfLiteActSignBit;
230     }
231     return kTfLiteActNone;
232   };
233   auto parseLSHProjectionType = [](LSHProjectionType type) {
234     switch (type) {
235       case LSHProjectionType_SPARSE:
236         return kTfLiteLshProjectionSparse;
237       case LSHProjectionType_DENSE:
238         return kTfLiteLshProjectionDense;
239       default:
240         return kTfLiteLshProjectionUnknown;
241     }
242   };
243   auto parseCombinerType = [](CombinerType type) {
244     switch (type) {
245       case CombinerType_MEAN:
246         return kTfLiteCombinerTypeMean;
247       case CombinerType_SQRTN:
248         return kTfLiteCombinerTypeSqrtn;
249       case CombinerType_SUM:
250       default:
251         return kTfLiteCombinerTypeSum;
252     }
253   };
254 
255   void* builtin_data = nullptr;
256   switch (op_type) {
257     case BuiltinOperator_CALL:
258       // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
259       // ok for now, since there is no call implementation either.
260       break;
261     case BuiltinOperator_CUSTOM:
262       break;
263     case BuiltinOperator_CONV_2D: {
264       TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
265       if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
266         params->padding = parse_padding(conv_params->padding());
267         params->stride_width = conv_params->stride_w();
268         params->stride_height = conv_params->stride_h();
269         params->activation =
270             parse_activation(conv_params->fused_activation_function());
271       }
272       builtin_data = reinterpret_cast<void*>(params);
273       break;
274     }
275     case BuiltinOperator_TANH:
276     case BuiltinOperator_LOGISTIC:
277     case BuiltinOperator_RELU:
278     case BuiltinOperator_RELU_N1_TO_1:
279     case BuiltinOperator_RELU6:
280     case BuiltinOperator_CONCAT_EMBEDDINGS:
281     case BuiltinOperator_EXP:
282     case BuiltinOperator_TOPK_V2:
283       break;
284     case BuiltinOperator_LSH_PROJECTION: {
285       TfLiteLSHProjectionParams* params =
286           MallocPOD<TfLiteLSHProjectionParams>();
287       if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
288         params->type = parseLSHProjectionType(lshParams->type());
289       }
290       builtin_data = reinterpret_cast<void*>(params);
291       break;
292     }
293     case BuiltinOperator_AVERAGE_POOL_2D:
294     case BuiltinOperator_MAX_POOL_2D:
295     case BuiltinOperator_L2_POOL_2D: {
296       TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
297       if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
298         params->padding = parse_padding(pool_params->padding());
299         params->stride_width = pool_params->stride_w();
300         params->stride_height = pool_params->stride_h();
301         params->filter_width = pool_params->filter_width();
302         params->filter_height = pool_params->filter_height();
303         params->activation =
304             parse_activation(pool_params->fused_activation_function());
305       }
306       builtin_data = reinterpret_cast<void*>(params);
307       break;
308     }
309     case BuiltinOperator_DEPTHWISE_CONV_2D: {
310       TfLiteDepthwiseConvParams* params =
311           MallocPOD<TfLiteDepthwiseConvParams>();
312       if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
313         params->padding = parse_padding(conv_params->padding());
314         params->stride_width = conv_params->stride_w();
315         params->stride_height = conv_params->stride_h();
316         params->depth_multiplier = conv_params->depth_multiplier();
317         params->activation =
318             parse_activation(conv_params->fused_activation_function());
319       }
320       builtin_data = reinterpret_cast<void*>(params);
321       break;
322     }
323     case BuiltinOperator_SVDF: {
324       TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
325       if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
326         params->rank = svdf_params->rank();
327         params->activation =
328             parse_activation(svdf_params->fused_activation_function());
329       }
330       builtin_data = reinterpret_cast<void*>(params);
331       break;
332     }
333     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
334     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
335       TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
336       if (auto* sequence_rnn_params =
337               op->builtin_options_as_SequenceRNNOptions()) {
338         params->activation =
339             parse_activation(sequence_rnn_params->fused_activation_function());
340         params->time_major = sequence_rnn_params->time_major();
341       }
342       builtin_data = reinterpret_cast<void*>(params);
343       break;
344     }
345     case BuiltinOperator_RNN: {
346       TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
347       if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
348         params->activation =
349             parse_activation(rnn_params->fused_activation_function());
350       }
351       builtin_data = reinterpret_cast<void*>(params);
352       break;
353     }
354     case BuiltinOperator_EMBEDDING_LOOKUP:
355       // no-op.
356       break;
357     case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
358       TfLiteEmbeddingLookupSparseParams* params =
359           MallocPOD<TfLiteEmbeddingLookupSparseParams>();
360       if (auto* embedding_params =
361               op->builtin_options_as_EmbeddingLookupSparseOptions()) {
362         params->combiner = parseCombinerType(embedding_params->combiner());
363       }
364       builtin_data = reinterpret_cast<void*>(params);
365       break;
366     }
367     case BuiltinOperator_FULLY_CONNECTED: {
368       TfLiteFullyConnectedParams* params =
369           MallocPOD<TfLiteFullyConnectedParams>();
370       if (auto* fully_connected_params =
371               op->builtin_options_as_FullyConnectedOptions()) {
372         params->activation = parse_activation(
373             fully_connected_params->fused_activation_function());
374       }
375       builtin_data = reinterpret_cast<void*>(params);
376       break;
377     }
378     case BuiltinOperator_HASHTABLE_LOOKUP:
379       // no-op.
380       break;
381     case BuiltinOperator_SOFTMAX: {
382       TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
383       if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
384         params->beta = softmax_params->beta();
385       }
386       builtin_data = reinterpret_cast<void*>(params);
387       break;
388     }
389     case BuiltinOperator_CONCATENATION: {
390       TfLiteConcatenationParams* params =
391           MallocPOD<TfLiteConcatenationParams>();
392       if (auto* concatenation_params =
393               op->builtin_options_as_ConcatenationOptions()) {
394         params->activation =
395             parse_activation(concatenation_params->fused_activation_function());
396         params->axis = concatenation_params->axis();
397       }
398       builtin_data = reinterpret_cast<void*>(params);
399       break;
400     }
401     case BuiltinOperator_MUL: {
402       auto* params = MallocPOD<TfLiteMulParams>();
403       if (auto* schema_params = op->builtin_options_as_MulOptions()) {
404         params->activation =
405             parse_activation(schema_params->fused_activation_function());
406       }
407       builtin_data = reinterpret_cast<void*>(params);
408       break;
409     }
410     case BuiltinOperator_ADD: {
411       auto* params = MallocPOD<TfLiteAddParams>();
412       if (auto* schema_params = op->builtin_options_as_AddOptions()) {
413         params->activation =
414             parse_activation(schema_params->fused_activation_function());
415       }
416       builtin_data = reinterpret_cast<void*>(params);
417       break;
418     }
419     case BuiltinOperator_DIV: {
420       auto* params = MallocPOD<TfLiteDivParams>();
421       if (auto* schema_params = op->builtin_options_as_DivOptions()) {
422         params->activation =
423             parse_activation(schema_params->fused_activation_function());
424       }
425       builtin_data = reinterpret_cast<void*>(params);
426       break;
427     }
428     case BuiltinOperator_SUB: {
429       auto* params = MallocPOD<TfLiteSubParams>();
430       if (auto* schema_params = op->builtin_options_as_SubOptions()) {
431         params->activation =
432             parse_activation(schema_params->fused_activation_function());
433       }
434       builtin_data = reinterpret_cast<void*>(params);
435       break;
436     }
437     case BuiltinOperator_L2_NORMALIZATION: {
438       auto* params = MallocPOD<TfLiteL2NormParams>();
439       if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
440         params->activation =
441             parse_activation(schema_params->fused_activation_function());
442       }
443       builtin_data = reinterpret_cast<void*>(params);
444       break;
445     }
446     case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
447       auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
448       if (auto* schema_params =
449               op->builtin_options_as_LocalResponseNormalizationOptions()) {
450         params->radius = schema_params->radius();
451         params->bias = schema_params->bias();
452         params->alpha = schema_params->alpha();
453         params->beta = schema_params->beta();
454       }
455       builtin_data = reinterpret_cast<void*>(params);
456       break;
457     }
458     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
459     case BuiltinOperator_LSTM: {
460       TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
461       if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
462         params->activation =
463             parse_activation(lstm_params->fused_activation_function());
464         params->cell_clip = lstm_params->cell_clip();
465         params->proj_clip = lstm_params->proj_clip();
466       }
467       builtin_data = reinterpret_cast<void*>(params);
468       break;
469     }
470     case BuiltinOperator_RESIZE_BILINEAR: {
471       auto* params = MallocPOD<TfLiteResizeBilinearParams>();
472       if (auto* schema_params =
473               op->builtin_options_as_ResizeBilinearOptions()) {
474         params->align_corners = schema_params->align_corners();
475       }
476       builtin_data = reinterpret_cast<void*>(params);
477       break;
478     }
479     case BuiltinOperator_PAD: {
480       break;
481     }
482     case BuiltinOperator_RESHAPE: {
483       auto* params = MallocPOD<TfLiteReshapeParams>();
484       if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
485         auto* new_shape = schema_params->new_shape();
486         FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
487                                    params->shape, error_reporter);
488         params->num_dimensions = new_shape->Length();
489       }
490       builtin_data = reinterpret_cast<void*>(params);
491       break;
492     }
493     case BuiltinOperator_SKIP_GRAM: {
494       TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
495       if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
496         params->ngram_size = skip_gram_params->ngram_size();
497         params->max_skip_size = skip_gram_params->max_skip_size();
498         params->include_all_ngrams = skip_gram_params->include_all_ngrams();
499       }
500       builtin_data = reinterpret_cast<void*>(params);
501       break;
502     }
503     case BuiltinOperator_SPACE_TO_DEPTH: {
504       auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
505       if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
506         params->block_size = schema_params->block_size();
507       }
508       builtin_data = reinterpret_cast<void*>(params);
509       break;
510     }
511     case BuiltinOperator_GATHER: {
512       TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
513       params->axis = 0;
514       if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
515         params->axis = gather_params->axis();
516       }
517 
518       builtin_data = reinterpret_cast<void*>(params);
519       break;
520     }
521     case BuiltinOperator_SPACE_TO_BATCH_ND: {
522       break;
523     }
524     case BuiltinOperator_BATCH_TO_SPACE_ND: {
525       break;
526     }
527     case BuiltinOperator_TRANSPOSE: {
528       break;
529     }
530     case BuiltinOperator_MEAN: {
531       auto* params = MallocPOD<TfLiteMeanParams>();
532       if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
533         params->keep_dims = schema_params->keep_dims();
534       }
535       builtin_data = reinterpret_cast<void*>(params);
536       break;
537     }
538     case BuiltinOperator_SPLIT: {
539       auto* params = MallocPOD<TfLiteSplitParams>();
540       if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
541         params->num_splits = schema_params->num_splits();
542       }
543       builtin_data = reinterpret_cast<void*>(params);
544       break;
545     }
546     case BuiltinOperator_SQUEEZE: {
547       auto* params = MallocPOD<TfLiteSqueezeParams>();
548       if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
549         const auto& squeeze_dims = schema_params->squeeze_dims();
550         FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
551                                    params->squeeze_dims, error_reporter);
552         params->num_squeeze_dims = squeeze_dims->Length();
553       }
554       builtin_data = reinterpret_cast<void*>(params);
555       break;
556     }
557     case BuiltinOperator_STRIDED_SLICE: {
558       auto* params = MallocPOD<TfLiteStridedSliceParams>();
559       if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
560         params->begin_mask = schema_params->begin_mask();
561         params->end_mask = schema_params->end_mask();
562         params->ellipsis_mask = schema_params->ellipsis_mask();
563         params->new_axis_mask = schema_params->new_axis_mask();
564         params->shrink_axis_mask = schema_params->shrink_axis_mask();
565       }
566       builtin_data = reinterpret_cast<void*>(params);
567       break;
568     }
569   }
570   return builtin_data;
571 }
572 
573 }  // namespace
574 
ParseNodes(const flatbuffers::Vector<flatbuffers::Offset<Operator>> * operators,Interpreter * interpreter)575 TfLiteStatus InterpreterBuilder::ParseNodes(
576     const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
577     Interpreter* interpreter) {
578   TfLiteStatus status = kTfLiteOk;
579   for (int i = 0; i < operators->Length(); ++i) {
580     const auto* op = operators->Get(i);
581     int index = op->opcode_index();
582     if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
583       error_reporter_->Report("Missing registration for opcode_index %d\n",
584                               index);
585       status = kTfLiteError;
586       continue;
587     }
588     const TfLiteRegistration* reg =
589         flatbuffer_op_index_to_registration_[op->opcode_index()];
590     if (reg == nullptr) {
591       error_reporter_->Report("Skipping op for opcode_index %d\n", index);
592       status = kTfLiteError;
593       continue;
594     }
595 
596     auto op_type =
597         flatbuffer_op_index_to_registration_types_[op->opcode_index()];
598     if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
599       error_reporter_->Report(
600           "Found builtin operator %s with custom options.\n",
601           EnumNameBuiltinOperator(op_type));
602     }
603     if (op->custom_options()) {
604       interpreter->AddNodeWithParameters(
605           FlatBufferIntArrayToVector(op->inputs()),
606           FlatBufferIntArrayToVector(op->outputs()),
607           reinterpret_cast<const char*>(op->custom_options()->data()),
608           op->custom_options()->size(), nullptr, reg);
609     } else {
610       interpreter->AddNodeWithParameters(
611           FlatBufferIntArrayToVector(op->inputs()),
612           FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
613           ParseOpData(op, op_type, error_reporter_), reg);
614     }
615   }
616 
617   return status;
618 }
619 
ParseTensors(const flatbuffers::Vector<flatbuffers::Offset<Buffer>> * buffers,const flatbuffers::Vector<flatbuffers::Offset<Tensor>> * tensors,Interpreter * interpreter)620 TfLiteStatus InterpreterBuilder::ParseTensors(
621     const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
622     const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
623     Interpreter* interpreter) {
624   TfLiteStatus status = kTfLiteOk;
625 
626   // A little helper to get the names of inputs and outputs. Note that they
627   // must outlive the interpreter.
628   auto get_name = [](const tflite::Tensor* t) -> const char* {
629     auto name = t->name();
630     if (name) return name->c_str();
631     return kEmptyTensorName;
632   };
633 
634   for (int i = 0; i < tensors->Length(); ++i) {
635     const auto* tensor = tensors->Get(i);
636     std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
637 
638     TfLiteQuantizationParams quantization;
639     quantization.scale = 0;
640     quantization.zero_point = 0;
641     auto* q_params = tensor->quantization();
642     if (q_params) {
643       // Note that the schema could hold per-channel quantization parameters
644       // but we really only support one value for the whole tensor.
645       // TODO(aselle): This breaks as well if these are nullptr's.
646       // TODO(aselle): This assumes non per-channel quantization.
647       if (q_params->scale()) quantization.scale = q_params->scale()->Get(0);
648       if (q_params->zero_point())
649         quantization.zero_point = q_params->zero_point()->Get(0);
650     }
651 
652     TfLiteType type;
653     switch (tensor->type()) {
654       case TensorType_FLOAT32:
655         type = kTfLiteFloat32;
656         break;
657       case TensorType_INT32:
658         type = kTfLiteInt32;
659         break;
660       case TensorType_UINT8:
661         type = kTfLiteUInt8;
662         break;
663       case TensorType_INT64:
664         type = kTfLiteInt64;
665         break;
666       case TensorType_STRING:
667         type = kTfLiteString;
668         break;
669       default:
670         // tensorType = ArrayType::NONE;
671         error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n",
672                                 EnumNameTensorType(tensor->type()),
673                                 tensor->type());
674         status = kTfLiteError;
675         continue;
676     }
677     auto get_readonly_data = [&](const char** buffer_data,
678                                  size_t* buffer_size) {
679       // TODO(aselle): Check what happens if we have an unspecified size
680       // constant.
681       *buffer_data = nullptr;
682       if (tensor->buffer() == 0) return kTfLiteOk;
683       if (tensor->buffer() >= buffers->size()) {
684         error_reporter_->Report(
685             "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
686             i, tensor->buffer(), buffers->size());
687         return kTfLiteError;
688       }
689       if (auto* buffer = (*buffers)[tensor->buffer()]) {
690         if (auto* array = buffer->data()) {
691           if (size_t size = array->size()) {
692             *buffer_size = size;
693             *buffer_data = reinterpret_cast<const char*>(array->data());
694             return kTfLiteOk;
695           }
696         }
697       }
698       return kTfLiteOk;
699     };
700     size_t buffer_size = 0;
701     const char* buffer_ptr;
702     TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
703 
704     if (buffer_ptr) {
705       if (interpreter->SetTensorParametersReadOnly(
706               i, type, get_name(tensor), dims, quantization, buffer_ptr,
707               buffer_size, allocation_) != kTfLiteOk) {
708         error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
709                                 i);
710         status = kTfLiteError;
711       }
712     } else {
713       if (interpreter->SetTensorParametersReadWrite(
714               i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
715         error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
716                                 i);
717         status = kTfLiteError;
718       }
719     }
720   }
721 
722   return status;
723 }
724 
operator ()(std::unique_ptr<Interpreter> * interpreter)725 TfLiteStatus InterpreterBuilder::operator()(
726     std::unique_ptr<Interpreter>* interpreter) {
727   if (!interpreter) {
728     error_reporter_->Report(
729         "Null output pointer passed to InterpreterBuilder.");
730     return kTfLiteError;
731   }
732 
733   // Safe exit by deleting partially created interpreter, to reduce verbosity
734   // on error conditions. Use by return cleanup_on_error();
735   auto cleanup_and_error = [&interpreter]() {
736     interpreter->reset();
737     return kTfLiteError;
738   };
739 
740   if (!model_) {
741     error_reporter_->Report("Null pointer passed in as model.");
742     return cleanup_and_error();
743   }
744 
745   if (model_->version() != TFLITE_SCHEMA_VERSION) {
746     error_reporter_->Report(
747         "Model provided is schema version %d not equal "
748         "to supported version %d.\n",
749         model_->version(), TFLITE_SCHEMA_VERSION);
750     return cleanup_and_error();
751   }
752 
753   if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
754     error_reporter_->Report("Registration failed.\n");
755     return cleanup_and_error();
756   }
757 
758   // Flatbuffer model schemas define a list of opcodes independent of the graph.
759   // We first map those to registrations. This reduces string lookups for custom
760   // ops since we only do it once per custom op rather than once per custom op
761   // invocation in the model graph.
762   // Construct interpreter with correct number of tensors and operators.
763   auto* subgraphs = model_->subgraphs();
764   auto* buffers = model_->buffers();
765   if (subgraphs->size() != 1) {
766     error_reporter_->Report("Only 1 subgraph is currently supported.\n");
767     return cleanup_and_error();
768   }
769   const tflite::SubGraph* subgraph = (*subgraphs)[0];
770   auto operators = subgraph->operators();
771   auto tensors = subgraph->tensors();
772   if (!operators || !tensors || !buffers) {
773     error_reporter_->Report(
774         "Did not get operators, tensors, or buffers in input flat buffer.\n");
775     return cleanup_and_error();
776   }
777   interpreter->reset(new Interpreter(error_reporter_));
778   if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
779     return cleanup_and_error();
780   }
781 
782   // Parse inputs/outputs
783   (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
784   (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
785 
786   // Finally setup nodes and tensors
787   if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
788     return cleanup_and_error();
789   if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
790     return cleanup_and_error();
791 
792   return kTfLiteOk;
793 }
794 
795 }  // namespace tflite
796