• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/compilation_stats.h"
17 
18 #include <iostream>
19 #include <memory>
20 #include <string>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/platform/env.h"
27 
28 namespace xla {
29 
30 class NoopStats : public CompilationStats {
31  public:
32   NoopStats() = default;
33 
StartPass(absl::string_view pass_name)34   void StartPass(absl::string_view pass_name) override {}
35 
EndPass(absl::string_view pass_name)36   void EndPass(absl::string_view pass_name) override {}
37 
CompilationReport()38   void CompilationReport() override {}
39 };
40 
41 class Stats : public CompilationStats {
42  public:
43   Stats() = default;
44 
45   void StartPass(absl::string_view pass_name) override;
46 
47   void EndPass(absl::string_view pass_name) override;
48 
49   void CompilationReport() override;
50 
51  private:
52   struct PassInfo {
PassInfoxla::Stats::PassInfo53     PassInfo(absl::string_view name, double duration)
54         : name(name), duration_ms(duration) {}
55 
56     absl::string_view name;
57     int num_runs = 1;
58     double duration_ms;
59   };
60 
61   // Info about the passes that have been run so far.
62   std::vector<PassInfo> passes_;
63   // Used to avoid nested calls to StartPass.
64   bool pass_running_ = false;
65   absl::string_view current_pass_;
66   // The start time of the currently running pass.
67   uint64 start_micros_;
68 };
69 
70 /* static */
MakeNoopStats()71 std::unique_ptr<CompilationStats> CompilationStats::MakeNoopStats() {
72   return absl::make_unique<NoopStats>();
73 }
74 
75 /* static */
MakeStats()76 std::unique_ptr<CompilationStats> CompilationStats::MakeStats() {
77   return absl::make_unique<Stats>();
78 }
79 
StartPass(absl::string_view pass_name)80 void Stats::StartPass(absl::string_view pass_name) {
81   CHECK(!pass_running_) << "Can't start " << pass_name << " while running "
82                         << current_pass_;
83   pass_running_ = true;
84   current_pass_ = pass_name;
85   start_micros_ = tensorflow::Env::Default()->NowMicros();
86 }
87 
EndPass(absl::string_view pass_name)88 void Stats::EndPass(absl::string_view pass_name) {
89   CHECK(pass_running_);
90   CHECK_EQ(current_pass_, pass_name);
91   pass_running_ = false;
92   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
93   double duration_ms = (end_micros - start_micros_) / 1000.0;
94   passes_.push_back(PassInfo(current_pass_, duration_ms));
95 }
96 
CompilationReport()97 void Stats::CompilationReport() {
98   CHECK(!pass_running_) << "EndPass never called for " << current_pass_;
99   absl::flat_hash_map<absl::string_view, PassInfo> summary;
100   double total_duration = 0;
101 
102   for (auto& pass_run : passes_) {
103     auto pass_name = pass_run.name;
104     total_duration += pass_run.duration_ms;
105     auto it = summary.find(pass_name);
106     if (it == summary.end()) {
107       summary.insert(std::make_pair(pass_name, pass_run));
108     } else {
109       ++summary.at(pass_name).num_runs;
110       summary.at(pass_name).duration_ms += pass_run.duration_ms;
111     }
112   }
113 
114   std::vector<PassInfo> sorted_summary;
115   sorted_summary.reserve(summary.size());
116   for (auto& it : summary) {
117     sorted_summary.push_back(it.second);
118   }
119   absl::c_sort(sorted_summary, [](const PassInfo& a, const PassInfo& b) {
120     // Sort passes that take the longest first, break ties using pass names.
121     return std::make_pair(b.duration_ms, a.name) <
122            std::make_pair(a.duration_ms, b.name);
123   });
124   LOG(INFO) << "Total runtime (ms) of HLO passes: " << total_duration;
125   LOG(INFO) << "Pass name, num runs, time (ms)";
126   for (auto& pass_info : sorted_summary) {
127     LOG(INFO) << pass_info.name << ", " << pass_info.num_runs << ", "
128               << pass_info.duration_ms;
129   }
130 }
131 
132 }  // namespace xla
133