1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
17 #include <algorithm>
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/str_format.h"
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/map_util.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25
26 namespace xla {
27 namespace gpu {
28
AddDependenciesOnTransitiveOperands(const Thunk & thunk,const HloInstruction & operand,const absl::flat_hash_map<const HloInstruction *,Thunk * > & hlo_to_thunk)29 void ThunkSchedule::AddDependenciesOnTransitiveOperands(
30 const Thunk& thunk, const HloInstruction& operand,
31 const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
32 if (hlo_to_thunk.contains(&operand)) {
33 // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency
34 // list if `operand` is assigned to a different stream. As an optimization,
35 // we skip `operand`'s operands because `operand` depends on them already.
36 if (stream_assignment_->StreamNumberForHlo(operand) !=
37 stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(&thunk))) {
38 depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand));
39 }
40 } else {
41 // If `operand` doesn't need a thunk (e.g. bitcast), continue with its
42 // operands.
43 for (const auto* operand_of_operand : operand.operands()) {
44 AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand,
45 hlo_to_thunk);
46 }
47 }
48 }
49
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,std::unique_ptr<StreamAssignment> stream_assignment,absl::flat_hash_map<const Thunk *,const HloInstruction * > thunk_to_hlo)50 ThunkSchedule::ThunkSchedule(
51 std::unique_ptr<ThunkSequence> thunks,
52 std::unique_ptr<StreamAssignment> stream_assignment,
53 absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo)
54 : thunks_(std::move(thunks)),
55 stream_assignment_(std::move(stream_assignment)),
56 thunk_to_hlo_(std::move(thunk_to_hlo)) {
57 for (auto& thunk : *thunks_) {
58 thunk_total_order_.push_back(thunk.get());
59 }
60
61 absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
62 for (const auto& thunk : *thunks_) {
63 InsertOrDie(&hlo_to_thunk, thunk_to_hlo_.at(thunk.get()), thunk.get());
64 }
65
66 for (const Thunk* thunk : thunk_total_order_) {
67 const auto* dst = thunk_to_hlo_.at(thunk);
68 CHECK(stream_assignment_->HasStreamAssigned(*dst));
69 for (const auto* src : dst->operands()) {
70 AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk);
71 }
72 }
73
74 RemoveRedundantDependencyEdges();
75
76 // Compute `depended_by_`, the inverse of `depends_on_`.
77 for (const auto& dependency : depends_on_) {
78 for (const auto* depended : dependency.second) {
79 depended_by_.insert(depended);
80 }
81 }
82 }
83
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks)84 ThunkSchedule::ThunkSchedule(std::unique_ptr<ThunkSequence> thunks)
85 : thunks_(std::move(thunks)) {
86 for (auto& thunk : *thunks_) {
87 thunk_total_order_.push_back(thunk.get());
88 }
89 }
90
RemoveRedundantDependencyEdges()91 void ThunkSchedule::RemoveRedundantDependencyEdges() {
92 std::unordered_map<const Thunk*, int> thunk_to_total_order;
93 for (int i = 0; i < thunk_total_order_.size(); ++i) {
94 InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
95 }
96
97 int stream_count = stream_assignment_->StreamCount();
98 // S1 S2
99 //
100 // T1<----+
101 // |
102 // T3<--+ |
103 // | | depends on
104 // T4 |
105 // |
106 // T2-+
107 //
108 // Suppose thunk T1 and T3 are scheduled on stream S1, and T2 and T4 are on
109 // stream S2. If T2 depends on T1 and T4 depends on T3, and
110 // order(T1)<order(T3)<order(T4)<order(T2), the dependency of T2 on T1 is
111 // redundant.
112 //
113 // To efficiently detect such redundancy, we leverage array `last_dependency`.
114 // last_dependency[S1][S2] indicates the last thunk (with the maximum order
115 // number) on stream S2 that thunks on S1 depends on. Therefore, if a future
116 // S1 thunk depends on a S2 thunk ordered <=last_dependency[S1][S2], that is a
117 // redundant dependency edge.
118 Array2D<int> last_dependency(stream_count, stream_count, -1);
119 for (const Thunk* dst : thunk_total_order_) {
120 if (!depends_on_.contains(dst)) {
121 continue;
122 }
123
124 int dst_stream =
125 stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(dst));
126 std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst);
127 for (auto iter = sources.begin(); iter != sources.end();) {
128 const Thunk* src = *iter;
129 // `dst` depends on `src`.
130 int src_stream =
131 stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(src));
132 int src_order = FindOrDie(thunk_to_total_order, src);
133 if (src_order <= last_dependency(dst_stream, src_stream)) {
134 iter = sources.erase(iter);
135 } else {
136 last_dependency(dst_stream, src_stream) = src_order;
137 ++iter;
138 }
139 }
140 if (sources.empty()) {
141 depends_on_.erase(dst);
142 }
143 }
144 }
145
DependsOn(const Thunk * thunk) const146 const std::list<const Thunk*>& ThunkSchedule::DependsOn(
147 const Thunk* thunk) const {
148 if (depends_on_.contains(thunk)) {
149 return FindOrDie(depends_on_, thunk);
150 } else {
151 return empty_thunk_list_;
152 }
153 }
154
ToString() const155 string ThunkSchedule::ToString() const {
156 if (thunk_total_order_.empty()) {
157 return "No thunks.";
158 }
159
160 const Thunk* thunk_with_longest_kind = *absl::c_max_element(
161 thunk_total_order_, [](const Thunk* a, const Thunk* b) {
162 return ThunkKindToString(a->kind()).length() <
163 ThunkKindToString(b->kind()).length();
164 });
165 int64 max_thunk_kind_len =
166 ThunkKindToString(thunk_with_longest_kind->kind()).length();
167
168 string result = "Total order:\n";
169 for (Thunk* thunk : thunk_total_order_) {
170 // Write out the thunk kind, padded out to max_thunk_kind_len.
171 absl::string_view kind_str = ThunkKindToString(thunk->kind());
172 absl::StrAppend(&result, kind_str,
173 string(max_thunk_kind_len - kind_str.length(), ' '), "\t");
174 if (thunk_to_hlo_.at(thunk) != nullptr) {
175 absl::StrAppend(&result, thunk_to_hlo_.at(thunk)->ToString());
176 } else {
177 absl::StrAppend(&result, "(no HloInstruction)");
178 }
179 absl::StrAppend(&result, "\n");
180 }
181 absl::StrAppend(&result, "\nDependencies:\n");
182 for (const auto& entry : depends_on_) {
183 const Thunk* dependent = entry.first;
184 for (const Thunk* dependency : entry.second) {
185 absl::StrAppend(&result, "\t", thunk_to_hlo_.at(dependent)->name(),
186 " depends on ", thunk_to_hlo_.at(dependency)->name(),
187 "\n");
188 }
189 }
190 return result;
191 }
192
193 } // namespace gpu
194 } // namespace xla
195