/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_VALIDATOR_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_VALIDATOR_H_

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/mutable_op_resolver.h"

namespace tflite {
namespace acceleration {

// Class to run the validation subgraph of a tflite model with embedded
// validation.
//
// The API is split into multiple steps so that callers can construct detailed
// telemetry from it.
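//
// A minimal usage sketch (illustrative only; how the ModelLoader and
// ComputeSettings are obtained, and the exact success status name, are
// assumptions rather than requirements of this header):
//
//   std::unique_ptr<ModelLoader> loader = ...;   // e.g. from a model path/fd
//   const ComputeSettings* settings = ...;       // must outlive the Validator
//   Validator validator(std::move(loader), settings);
//   Validator::Results results;
//   MinibenchmarkStatus status = validator.RunValidation(&results);
//   if (status == kMinibenchmarkSuccess && results.ok) {
//     // Delegated output matched the embedded golden output.
//   }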
class Validator {
 public:
  // Construct Validator for the given model and compute settings. The
  // compute_settings must be valid for the lifetime of the Validator instance.
  Validator(std::unique_ptr<ModelLoader> model_loader,
            const ComputeSettings* compute_settings)
      : model_loader_(std::move(model_loader)),
        compute_settings_(compute_settings) {}

  // Results from validation.
  struct Results {
    // Whether the results are correct (metrics below threshold).
    bool ok = false;
    // The metric results, for telemetry.
    std::map<std::string, std::vector<float>> metrics;
    // How long loading the delegate and creating the interpreter took, in
    // microseconds. -1 if it failed.
    int64_t delegate_prep_time_us = 0;
    // How long each execution (Invoke) took, in microseconds. (Empty in rare
    // cases when reading the system clock fails.)
    std::vector<int64_t> execution_time_us;
    // Any possible error from the delegate.
    int delegate_error = 0;
    // Number of delegated kernels.
    int delegated_kernels = 0;
    // Model output without the delegate.
    // key: output tensor name;
    // value: output tensor data in byte format.
    std::map<std::string, std::vector<char>> golden_inference_output;
    // Model output with the delegate.
    // key: output tensor name;
    // value: output tensor data in byte format.
    std::map<std::string, std::vector<char>> actual_inference_output;
  };
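
  // Illustrative sketch (not part of the original header) of turning Results
  // into telemetry; how values are reported is entirely up to the caller:
  //
  //   for (const auto& metric : results.metrics) {
  //     // metric.first is the metric name, metric.second its values.
  //   }
  //   // results.ok, results.delegate_prep_time_us, results.execution_time_us
  //   // and results.delegated_kernels can be reported alongside.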

  // Run the validation graph and return validation results.
  MinibenchmarkStatus RunValidation(Results* results_out);

  // Get timestamps, in microseconds.
  static int64_t BootTimeMicros();
  static int64_t WallTimeMicros();
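  //
  // For example, a caller assembling its own telemetry might bracket a step
  // with these (sketch, not part of the original header):
  //
  //   int64_t start_us = Validator::WallTimeMicros();
  //   // ... perform a step ...
  //   int64_t elapsed_us = Validator::WallTimeMicros() - start_us;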

  Validator(Validator&) = delete;
  Validator& operator=(Validator&) = delete;
  Validator(Validator&&) = delete;
  Validator& operator=(Validator&&) = delete;

 private:
  // Load the delegate plugin and create the delegate.
  MinibenchmarkStatus LoadDelegate();

  // Create the interpreter with the delegate. Must be called after
  // LoadDelegate().
  MinibenchmarkStatus CreateInterpreter(int* delegate_error_out,
                                        int* delegated_kernels_out);

  // Check if the golden output exists. If not, run the model on CPU and add
  // the golden output to model_. Also fills results_out with the golden
  // output.
  MinibenchmarkStatus CheckGoldenOutput(Results* results_out);

  std::unique_ptr<ModelLoader> model_loader_;
  const ComputeSettings* compute_settings_;
  // Interpreter that runs on CPU.
  std::unique_ptr<Interpreter> golden_interpreter_;
  // Interpreter that runs with delegate enabled, using the compute settings
  // passed to the Validator constructor.
  std::unique_ptr<Interpreter> interpreter_;
  // Op resolver used to create the interpreters. Depending on the
  // compute_settings_, it may or may not include the default delegate.
  std::unique_ptr<::tflite::MutableOpResolver> resolver_;
  std::unique_ptr<FlatBufferModel> model_;
  ::tflite::delegates::TfLiteDelegatePtr delegate_ =
      delegates::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
  std::unique_ptr<tflite::delegates::DelegatePluginInterface> delegate_plugin_;
  Subgraph* validation_entrypoint_ = nullptr;
  Subgraph* main_model_ = nullptr;
};

}  // namespace acceleration
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_VALIDATOR_H_