/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"

#include <algorithm>
#include <list>
#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {
namespace gpu {

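// Records a dependency edge from `thunk` to the thunk that produces `operand`
// when the two are assigned to different streams; if `operand` has no thunk of
// its own, recurses into `operand`'s operands instead.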
void ThunkSchedule::AddDependenciesOnTransitiveOperands(
    const Thunk& thunk, const HloInstruction& operand,
    const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
  if (hlo_to_thunk.contains(&operand)) {
    // If `operand` is mapped to a thunk, add that thunk to `thunk`'s dependency
    // list if `operand` is assigned to a different stream. As an optimization,
    // we skip `operand`'s operands because `operand` already depends on them.
    if (stream_assignment_->StreamNumberForHlo(operand) !=
        stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(&thunk))) {
      depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand));
    }
  } else {
    // If `operand` doesn't need a thunk (e.g. bitcast), recurse into its
    // operands.
    for (const auto* operand_of_operand : operand.operands()) {
      AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand,
                                          hlo_to_thunk);
    }
  }
}

ThunkSchedule::ThunkSchedule(
    std::unique_ptr<ThunkSequence> thunks,
    std::unique_ptr<StreamAssignment> stream_assignment,
    absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo)
    : thunks_(std::move(thunks)),
      stream_assignment_(std::move(stream_assignment)),
      thunk_to_hlo_(std::move(thunk_to_hlo)) {
  for (auto& thunk : *thunks_) {
    thunk_total_order_.push_back(thunk.get());
  }

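  // Build the reverse map from HLO instruction to its thunk; it is used below
  // to resolve operands to thunks when adding dependency edges.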
  absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
  for (const auto& thunk : *thunks_) {
    InsertOrDie(&hlo_to_thunk, thunk_to_hlo_.at(thunk.get()), thunk.get());
  }

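  // For each thunk, walk the operands of its HLO and record dependencies on
  // thunks that run on a different stream.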
  for (const Thunk* thunk : thunk_total_order_) {
    const auto* dst = thunk_to_hlo_.at(thunk);
    CHECK(stream_assignment_->HasStreamAssigned(*dst));
    for (const auto* src : dst->operands()) {
      AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk);
    }
  }

  RemoveRedundantDependencyEdges();

  // Compute `depended_by_`, the inverse of `depends_on_`.
  for (const auto& dependency : depends_on_) {
    for (const auto* depended : dependency.second) {
      depended_by_.insert(depended);
    }
  }
}

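// Constructor used when no stream assignment is provided. Only the total order
// is recorded; `depends_on_` stays empty, so no inter-thunk dependencies exist.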
ThunkSchedule::ThunkSchedule(std::unique_ptr<ThunkSequence> thunks)
    : thunks_(std::move(thunks)) {
  for (auto& thunk : *thunks_) {
    thunk_total_order_.push_back(thunk.get());
  }
}

void ThunkSchedule::RemoveRedundantDependencyEdges() {
  std::unordered_map<const Thunk*, int> thunk_to_total_order;
  for (int i = 0; i < thunk_total_order_.size(); ++i) {
    InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
  }

  int stream_count = stream_assignment_->StreamCount();
  // S1  S2
  //
  // T1<----+
  //        |
  // T3<--+ |
  //      | | depends on
  //     T4 |
  //        |
  //     T2-+
  //
  // Suppose thunks T1 and T3 are scheduled on stream S1, and T2 and T4 are on
  // stream S2. If T2 depends on T1 and T4 depends on T3, and
  // order(T1)<order(T3)<order(T4)<order(T2), the dependency of T2 on T1 is
  // redundant.
  //
  // To detect such redundancy efficiently, we use the 2D array
  // `last_dependency`. last_dependency[S1][S2] records the last thunk (the one
  // with the maximum order number) on stream S2 that thunks on S1 depend on.
  // Therefore, if a later S1 thunk depends on an S2 thunk whose order is
  // <= last_dependency[S1][S2], that dependency edge is redundant.
  Array2D<int> last_dependency(stream_count, stream_count, -1);
  for (const Thunk* dst : thunk_total_order_) {
    if (!depends_on_.contains(dst)) {
      continue;
    }

    int dst_stream =
        stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(dst));
    std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst);
    for (auto iter = sources.begin(); iter != sources.end();) {
      const Thunk* src = *iter;
      // `dst` depends on `src`.
      int src_stream =
          stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(src));
      int src_order = FindOrDie(thunk_to_total_order, src);
      if (src_order <= last_dependency(dst_stream, src_stream)) {
        iter = sources.erase(iter);
      } else {
        last_dependency(dst_stream, src_stream) = src_order;
        ++iter;
      }
    }
    if (sources.empty()) {
      depends_on_.erase(dst);
    }
  }
}

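// Returns the thunks that `thunk` must wait for, or an empty list if it has no
// cross-stream dependencies.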
const std::list<const Thunk*>& ThunkSchedule::DependsOn(
    const Thunk* thunk) const {
  if (depends_on_.contains(thunk)) {
    return FindOrDie(depends_on_, thunk);
  } else {
    return empty_thunk_list_;
  }
}

string ThunkSchedule::ToString() const {
  if (thunk_total_order_.empty()) {
    return "No thunks.";
  }

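  // Find the longest thunk-kind name so the kind column can be padded to a
  // uniform width.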
  const Thunk* thunk_with_longest_kind = *absl::c_max_element(
      thunk_total_order_, [](const Thunk* a, const Thunk* b) {
        return ThunkKindToString(a->kind()).length() <
               ThunkKindToString(b->kind()).length();
      });
  int64 max_thunk_kind_len =
      ThunkKindToString(thunk_with_longest_kind->kind()).length();

  string result = "Total order:\n";
  for (Thunk* thunk : thunk_total_order_) {
    // Write out the thunk kind, padded out to max_thunk_kind_len.
    absl::string_view kind_str = ThunkKindToString(thunk->kind());
    absl::StrAppend(&result, kind_str,
                    string(max_thunk_kind_len - kind_str.length(), ' '), "\t");
    if (thunk_to_hlo_.at(thunk) != nullptr) {
      absl::StrAppend(&result, thunk_to_hlo_.at(thunk)->ToString());
    } else {
      absl::StrAppend(&result, "(no HloInstruction)");
    }
    absl::StrAppend(&result, "\n");
  }
  absl::StrAppend(&result, "\nDependencies:\n");
  for (const auto& entry : depends_on_) {
    const Thunk* dependent = entry.first;
    for (const Thunk* dependency : entry.second) {
      absl::StrAppend(&result, "\t", thunk_to_hlo_.at(dependent)->name(),
                      " depends on ", thunk_to_hlo_.at(dependency)->name(),
                      "\n");
    }
  }
  return result;
}

}  // namespace gpu
}  // namespace xla