• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/benchmark/profiling_listener.h"
17 
18 #include <fstream>
19 #include <string>
20 
21 #include "tensorflow/lite/tools/logging.h"
22 
23 namespace tflite {
24 namespace benchmark {
25 
ProfilingListener(Interpreter * interpreter,uint32_t max_num_initial_entries,bool allow_dynamic_buffer_increase,const std::string & csv_file_path,std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter)26 ProfilingListener::ProfilingListener(
27     Interpreter* interpreter, uint32_t max_num_initial_entries,
28     bool allow_dynamic_buffer_increase, const std::string& csv_file_path,
29     std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter)
30     : run_summarizer_(summarizer_formatter),
31       init_summarizer_(summarizer_formatter),
32       csv_file_path_(csv_file_path),
33       interpreter_(interpreter),
34       profiler_(max_num_initial_entries, allow_dynamic_buffer_increase) {
35   TFLITE_TOOLS_CHECK(interpreter);
36   interpreter_->SetProfiler(&profiler_);
37 
38   // We start profiling here in order to catch events that are recorded during
39   // the benchmark run preparation stage where TFLite interpreter is
40   // initialized and model graph is prepared.
41   profiler_.Reset();
42   profiler_.StartProfiling();
43 }
44 
OnBenchmarkStart(const BenchmarkParams & params)45 void ProfilingListener::OnBenchmarkStart(const BenchmarkParams& params) {
46   // At this point, we have completed the preparation for benchmark runs
47   // including TFLite interpreter initialization etc. So we are going to process
48   // profiling events recorded during this stage.
49   profiler_.StopProfiling();
50   auto profile_events = profiler_.GetProfileEvents();
51   init_summarizer_.ProcessProfiles(profile_events, *interpreter_);
52   profiler_.Reset();
53 }
54 
OnSingleRunStart(RunType run_type)55 void ProfilingListener::OnSingleRunStart(RunType run_type) {
56   if (run_type == REGULAR) {
57     profiler_.Reset();
58     profiler_.StartProfiling();
59   }
60 }
61 
OnSingleRunEnd()62 void ProfilingListener::OnSingleRunEnd() {
63   profiler_.StopProfiling();
64   auto profile_events = profiler_.GetProfileEvents();
65   run_summarizer_.ProcessProfiles(profile_events, *interpreter_);
66 }
67 
OnBenchmarkEnd(const BenchmarkResults & results)68 void ProfilingListener::OnBenchmarkEnd(const BenchmarkResults& results) {
69   std::ofstream output_file(csv_file_path_);
70   std::ostream* output_stream = nullptr;
71   if (output_file.good()) {
72     output_stream = &output_file;
73   }
74   if (init_summarizer_.HasProfiles()) {
75     WriteOutput("Profiling Info for Benchmark Initialization:",
76                 init_summarizer_.GetOutputString(),
77                 output_stream == nullptr ? &TFLITE_LOG(INFO) : output_stream);
78   }
79   if (run_summarizer_.HasProfiles()) {
80     WriteOutput("Operator-wise Profiling Info for Regular Benchmark Runs:",
81                 run_summarizer_.GetOutputString(),
82                 output_stream == nullptr ? &TFLITE_LOG(INFO) : output_stream);
83   }
84 }
85 
WriteOutput(const std::string & header,const string & data,std::ostream * stream)86 void ProfilingListener::WriteOutput(const std::string& header,
87                                     const string& data, std::ostream* stream) {
88   (*stream) << header << std::endl;
89   (*stream) << data << std::endl;
90 }
91 
92 }  // namespace benchmark
93 }  // namespace tflite
94