1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/compilation_stats.h"
17
18 #include <iostream>
19 #include <memory>
20 #include <string>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/platform/env.h"
27
28 namespace xla {
29
30 class NoopStats : public CompilationStats {
31 public:
32 NoopStats() = default;
33
StartPass(absl::string_view pass_name)34 void StartPass(absl::string_view pass_name) override {}
35
EndPass(absl::string_view pass_name)36 void EndPass(absl::string_view pass_name) override {}
37
CompilationReport()38 void CompilationReport() override {}
39 };
40
41 class Stats : public CompilationStats {
42 public:
43 Stats() = default;
44
45 void StartPass(absl::string_view pass_name) override;
46
47 void EndPass(absl::string_view pass_name) override;
48
49 void CompilationReport() override;
50
51 private:
52 struct PassInfo {
PassInfoxla::Stats::PassInfo53 PassInfo(absl::string_view name, double duration)
54 : name(name), duration_ms(duration) {}
55
56 absl::string_view name;
57 int num_runs = 1;
58 double duration_ms;
59 };
60
61 // Info about the passes that have been run so far.
62 std::vector<PassInfo> passes_;
63 // Used to avoid nested calls to StartPass.
64 bool pass_running_ = false;
65 absl::string_view current_pass_;
66 // The start time of the currently running pass.
67 uint64 start_micros_;
68 };
69
70 /* static */
MakeNoopStats()71 std::unique_ptr<CompilationStats> CompilationStats::MakeNoopStats() {
72 return absl::make_unique<NoopStats>();
73 }
74
75 /* static */
MakeStats()76 std::unique_ptr<CompilationStats> CompilationStats::MakeStats() {
77 return absl::make_unique<Stats>();
78 }
79
StartPass(absl::string_view pass_name)80 void Stats::StartPass(absl::string_view pass_name) {
81 CHECK(!pass_running_) << "Can't start " << pass_name << " while running "
82 << current_pass_;
83 pass_running_ = true;
84 current_pass_ = pass_name;
85 start_micros_ = tensorflow::Env::Default()->NowMicros();
86 }
87
EndPass(absl::string_view pass_name)88 void Stats::EndPass(absl::string_view pass_name) {
89 CHECK(pass_running_);
90 CHECK_EQ(current_pass_, pass_name);
91 pass_running_ = false;
92 uint64 end_micros = tensorflow::Env::Default()->NowMicros();
93 double duration_ms = (end_micros - start_micros_) / 1000.0;
94 passes_.push_back(PassInfo(current_pass_, duration_ms));
95 }
96
CompilationReport()97 void Stats::CompilationReport() {
98 CHECK(!pass_running_) << "EndPass never called for " << current_pass_;
99 absl::flat_hash_map<absl::string_view, PassInfo> summary;
100 double total_duration = 0;
101
102 for (auto& pass_run : passes_) {
103 auto pass_name = pass_run.name;
104 total_duration += pass_run.duration_ms;
105 auto it = summary.find(pass_name);
106 if (it == summary.end()) {
107 summary.insert(std::make_pair(pass_name, pass_run));
108 } else {
109 ++summary.at(pass_name).num_runs;
110 summary.at(pass_name).duration_ms += pass_run.duration_ms;
111 }
112 }
113
114 std::vector<PassInfo> sorted_summary;
115 sorted_summary.reserve(summary.size());
116 for (auto& it : summary) {
117 sorted_summary.push_back(it.second);
118 }
119 absl::c_sort(sorted_summary, [](const PassInfo& a, const PassInfo& b) {
120 // Sort passes that take the longest first, break ties using pass names.
121 return std::make_pair(b.duration_ms, a.name) <
122 std::make_pair(a.duration_ms, b.name);
123 });
124 LOG(INFO) << "Total runtime (ms) of HLO passes: " << total_duration;
125 LOG(INFO) << "Pass name, num runs, time (ms)";
126 for (auto& pass_info : sorted_summary) {
127 LOG(INFO) << pass_info.name << ", " << pass_info.num_runs << ", "
128 << pass_info.duration_ms;
129 }
130 }
131
132 } // namespace xla
133