/**
 * Copyright 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
#define COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"

#include <unistd.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

struct InferenceOutput {
  uint8_t* ptr;
  size_t size;
};

// Inputs and expected outputs for inference
struct InferenceInOut {
  // Input can be specified either directly as a pointer or indirectly via
  // the createInput callback. The callback is needed for large datasets where
  // allocating memory for all inputs at once is not feasible.
  uint8_t* input;
  size_t input_size;

  std::vector<InferenceOutput> outputs;
  std::function<bool(uint8_t*, size_t)> createInput;
};
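
// Illustrative sketch only (not part of the original header): an entry can
// either carry its input buffer inline or defer generation to createInput
// when the dataset is too large to hold in memory at once. kInputBytes below
// is a hypothetical size constant.
//
//   InferenceInOut entry{};
//   entry.input = nullptr;           // no preallocated buffer
//   entry.input_size = kInputBytes;  // bytes the callback must produce
//   entry.createInput = [](uint8_t* buf, size_t len) {
//     for (size_t i = 0; i < len; ++i) buf[i] = 0;  // write input in place
//     return true;                                  // report success
//   };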

// Inputs and expected outputs for an inference sequence.
using InferenceInOutSequence = std::vector<InferenceInOut>;

// Result of a single inference
struct InferenceResult {
  float computeTimeSec;
  // MSE for each output
  std::vector<float> meanSquareErrors;
  // Max single error for each output
  std::vector<float> maxSingleErrors;
  // Outputs
  std::vector<std::vector<uint8_t>> inferenceOutputs;
  int inputOutputSequenceIndex;
  int inputOutputIndex;
};

/** Discard inference output in inference results. */
const int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
/** Do not expect golden output for inference inputs. */
const int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;

class BenchmarkModel {
 public:
  ~BenchmarkModel();

  static BenchmarkModel* create(const char* modelfile, bool use_nnapi,
                                bool enable_intermediate_tensors_dump,
                                const char* nnapi_device_name = nullptr);

  bool resizeInputTensors(std::vector<int> shape);
  bool setInput(const uint8_t* dataPtr, size_t length);
  bool runInference();
  // Resets TFLite states (RNN/LSTM states, etc.).
  bool resetStates();

  bool benchmark(const std::vector<InferenceInOutSequence>& inOutData,
                 int seqInferencesMaxCount, float timeout, int flags,
                 std::vector<InferenceResult>* result);

  bool dumpAllLayers(const char* path,
                     const std::vector<InferenceInOutSequence>& inOutData);

 private:
  BenchmarkModel();
  bool init(const char* modelfile, bool use_nnapi,
            bool enable_intermediate_tensors_dump,
            const char* nnapi_device_name);

  void getOutputError(const uint8_t* dataPtr, size_t length,
                      InferenceResult* result, int output_index);
  void saveInferenceOutput(InferenceResult* result, int output_index);

  std::unique_ptr<tflite::FlatBufferModel> mTfliteModel;
  std::unique_ptr<tflite::Interpreter> mTfliteInterpreter;
};
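
// Illustrative usage sketch, not part of the original header. The model path
// and the input buffer variables below are hypothetical; error handling is
// elided.
//
//   std::unique_ptr<BenchmarkModel> model(
//       BenchmarkModel::create("/data/local/tmp/model.tflite",
//                              /*use_nnapi=*/true,
//                              /*enable_intermediate_tensors_dump=*/false));
//   if (model != nullptr) {
//     model->setInput(inputPtr, inputLen);  // caller-provided input buffer
//     model->runInference();
//   }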

#endif  // COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H