• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PERFORMANCE_OPTIONS_H_
17 #define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PERFORMANCE_OPTIONS_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "tensorflow/lite/tools/benchmark/benchmark_model.h"
25 #include "tensorflow/lite/tools/benchmark/benchmark_params.h"
26 
27 namespace tflite {
28 namespace benchmark {
29 
30 class MultiRunStatsRecorder : public BenchmarkListener {
31  public:
32   // BenchmarkListener::OnBenchmarkStart is invoked after each run's
33   // BenchmarkModel::Init. However, some run could fail during Init, e.g.
34   // delegate fails to be created etc. To still record such run, we will call
35   // the following function right before a run starts.
MarkBenchmarkStart(const BenchmarkParams & params)36   void MarkBenchmarkStart(const BenchmarkParams& params) {
37     results_.emplace_back(EachRunResult());
38     auto& current = results_.back();
39     current.completed = false;
40     current.params = absl::make_unique<BenchmarkParams>();
41     current.params->Merge(params, true /* overwrite*/);
42   }
43 
OnBenchmarkEnd(const BenchmarkResults & results)44   void OnBenchmarkEnd(const BenchmarkResults& results) final {
45     auto& current = results_.back();
46     current.completed = true;
47     current.metrics = results;
48   }
49 
50   virtual void OutputStats();
51 
52  protected:
53   struct EachRunResult {
54     bool completed = false;
55     std::unique_ptr<BenchmarkParams> params;
56     BenchmarkResults metrics;
57   };
58   std::vector<EachRunResult> results_;
59 
60   // Use this to order the runs by the average inference time in increasing
61   // order (i.e. the fastest run ranks first.). If the run didn't complete,
62   // we consider it to be slowest.
63   struct EachRunStatsEntryComparator {
operatorEachRunStatsEntryComparator64     bool operator()(const EachRunResult& i, const EachRunResult& j) {
65       if (!i.completed) return false;
66       if (!j.completed) return true;
67       return i.metrics.inference_time_us().avg() <
68              j.metrics.inference_time_us().avg();
69     }
70   };
71 
72   virtual std::string PerfOptionName(const BenchmarkParams& params) const;
73 };
74 
75 // Benchmarks all performance options on a model by repeatedly invoking the
76 // single-performance-option run on a passed-in 'BenchmarkModel' object.
77 class BenchmarkPerformanceOptions {
78  public:
79   // Doesn't own the memory of 'single_option_run'.
80   explicit BenchmarkPerformanceOptions(
81       BenchmarkModel* single_option_run,
82       std::unique_ptr<MultiRunStatsRecorder> all_run_stats =
83           absl::make_unique<MultiRunStatsRecorder>());
84 
~BenchmarkPerformanceOptions()85   virtual ~BenchmarkPerformanceOptions() {}
86 
87   // Just run the benchmark just w/ default parameter values.
88   void Run();
89   void Run(int argc, char** argv);
90 
91  protected:
92   static BenchmarkParams DefaultParams();
93 
94   BenchmarkPerformanceOptions(
95       BenchmarkParams params, BenchmarkModel* single_option_run,
96       std::unique_ptr<MultiRunStatsRecorder> all_run_stats);
97 
98   // Unparsable flags will remain in 'argv' in the original order and 'argc'
99   // will be updated accordingly.
100   bool ParseFlags(int* argc, char** argv);
101   virtual std::vector<Flag> GetFlags();
102 
103   bool ParsePerfOptions();
104   virtual std::vector<std::string> GetValidPerfOptions() const;
105   bool HasOption(const std::string& option) const;
106 
107   virtual void ResetPerformanceOptions();
108   virtual void CreatePerformanceOptions();
109 
110   BenchmarkParams params_;
111   std::vector<std::string> perf_options_;
112 
113   // The object that drives a single-performance-option run.
114   BenchmarkModel* const single_option_run_;          // Doesn't own the memory.
115   BenchmarkParams* const single_option_run_params_;  // Doesn't own the memory.
116 
117   // Each element is a set of performance-affecting benchmark parameters to be
118   // all set for a particular benchmark run.
119   std::vector<BenchmarkParams> all_run_params_;
120 
121   std::unique_ptr<MultiRunStatsRecorder> all_run_stats_;
122 };
123 
124 }  // namespace benchmark
125 }  // namespace tflite
126 
127 #endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PERFORMANCE_OPTIONS_H_
128