1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17
18 #include <utility>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
ExecutesBefore(const HloInstruction * a,const HloInstruction * b) const35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
36 const HloInstruction* b) const {
37 // 'a' and 'b' may be in different computations. In this case, find the
38 // callgraph ancestor instructions which call (potentially transitively) the
39 // computations containing 'a' and 'b' and use these ancestor instructions to
40 // compare order.
41 const HloInstruction* a_ancestor;
42 const HloInstruction* b_ancestor;
43 std::tie(a_ancestor, b_ancestor) =
44 call_graph_->NearestAncestorsInSameComputation(
45 const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
46
47 if (a_ancestor == nullptr) {
48 // Ancestors in a common computation could not be found so consider the
49 // instructions 'a' and 'b' to be unordered.
50 return false;
51 }
52 // a_ancestor and b_ancestor must be either both null or both non-null.
53 CHECK_NE(b_ancestor, nullptr);
54 CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
55
56 // If the common ancestor is a while instruction there is an additional
57 // ordering criteria which may apply. The condition computation is considered
58 // to execute before the body computation so if 'a' is in the condition and
59 // 'b' is in the body, then 'a' executes before 'b'.
60 if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
61 const HloComputation* body = a_ancestor->while_body();
62 const HloComputation* condition = a_ancestor->while_condition();
63 if (call_graph_->InstructionIsNestedIn(a, condition) &&
64 call_graph_->InstructionIsNestedIn(b, body)) {
65 return true;
66 }
67 }
68
69 // If the common ancestor is a conditional instruction, even though the branch
70 // computations are not really ordered per-se, we define the 0th branch
71 // computation to be ordered before the 1st one, before the 2nd and so forth.
72 // This ensures that buffers can still be shared among branch computations
73 // as they will forcibly have disjoint liveness.
74 if (a_ancestor == b_ancestor &&
75 (a_ancestor->opcode() == HloOpcode::kConditional)) {
76 int a_branch = -1;
77 int b_branch = -1;
78 for (int j = 0; j < a_ancestor->branch_count(); ++j) {
79 if (call_graph_->InstructionIsNestedIn(
80 a, a_ancestor->branch_computation(j))) {
81 a_branch = j;
82 }
83 if (call_graph_->InstructionIsNestedIn(
84 b, a_ancestor->branch_computation(j))) {
85 b_branch = j;
86 }
87 }
88 if (a_branch != -1 && a_branch < b_branch) {
89 return true;
90 }
91 // If 'b' is the conditional ancestor, and 'a' is within a branch
92 // computation, 'a' executes before 'b'.
93 if (b == a_ancestor && a_branch != -1) {
94 return true;
95 }
96 }
97
98 return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
99 }
100
IsDefinedBefore(const HloValue & a,const HloValue & b) const101 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
102 // Entry parameter should always be defined before other instructions.
103 const HloModule* module = b.defining_instruction()->parent()->parent();
104 if (b.defining_instruction()->parent() == module->entry_computation() &&
105 b.defining_instruction()->opcode() == HloOpcode::kParameter) {
106 return false;
107 }
108
109 if (a.defining_instruction()->parent() == module->entry_computation() &&
110 a.defining_instruction()->opcode() == HloOpcode::kParameter) {
111 return true;
112 }
113
114 // Phi values require special handling. Because XLA does not have a phi
115 // instruction, the definition instruction of the phis values are
116 // placeholders: either the subcomputation parameter (body or condition) or
117 // the while instruction. However, the program point where these values are
118 // logically defined does not necessarily coincide exactly with program point
119 // of these place-holder instructions. So we explicitly define the following
120 // order for phi values:
121 //
122 // body/condition parameter phi:
123 // Defined before all values defined in its computation excepting other
124 // phis.
125 //
126 // while phi:
127 // defined after all values defined in the condition or body.
128 //
129 auto is_body_or_condition_phi = [](const HloValue& v) {
130 return v.is_phi() &&
131 v.defining_instruction()->opcode() == HloOpcode::kParameter;
132 };
133 if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
134 call_graph_->InstructionIsNestedIn(b.defining_instruction(),
135 a.defining_instruction()->parent())) {
136 return true;
137 }
138 if (is_body_or_condition_phi(b) &&
139 call_graph_->InstructionIsNestedIn(a.defining_instruction(),
140 b.defining_instruction()->parent())) {
141 return false;
142 }
143
144 // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
145 // executes before 'b'.
146 if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
147 (call_graph_->InstructionIsNestedIn(
148 a.defining_instruction(), b.defining_instruction()->while_body()) ||
149 call_graph_->InstructionIsNestedIn(
150 a.defining_instruction(),
151 b.defining_instruction()->while_condition()))) {
152 return true;
153 }
154 // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
155 // executes before 'b'.
156 if (b.is_phi() &&
157 b.defining_instruction()->opcode() == HloOpcode::kConditional) {
158 for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
159 if (call_graph_->InstructionIsNestedIn(
160 a.defining_instruction(),
161 b.defining_instruction()->branch_computation(j))) {
162 return true;
163 }
164 }
165 }
166 return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
167 }
168
169 /* static */
UseIsBeforeValueDefinition(const HloUse & use,const HloValue & value,const HloDataflowAnalysis & dataflow) const170 bool HloOrdering::UseIsBeforeValueDefinition(
171 const HloUse& use, const HloValue& value,
172 const HloDataflowAnalysis& dataflow) const {
173 VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
174 << ", value=" << value.ToShortString() << ")";
175 if (ExecutesBefore(use.instruction, value.defining_instruction())) {
176 VLOG(4) << " use instruction executes before value-defining instruction";
177 return true;
178 }
179
180 // If the use is at the instruction where the value is defined, then the use
181 // is before the def if the instruction allows buffer sharing (in place
182 // computation).
183 if (use.instruction == value.defining_instruction() &&
184 dataflow.CanShareOperandBufferWithUser(
185 use.instruction->mutable_operand(use.operand_number),
186 use.operand_index, value.defining_instruction(),
187 value.defining_index())) {
188 VLOG(4) << " use is value def, and instruction can share use buffer";
189 return true;
190 }
191
192 // The use at a while is an input to a phi, and logically occurs before values
193 // are defined in the body or condition computations.
194 if (use.instruction->opcode() == HloOpcode::kWhile) {
195 const HloInstruction* xla_while = use.instruction;
196 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
197 xla_while->while_body()) ||
198 call_graph_->InstructionIsNestedIn(value.defining_instruction(),
199 xla_while->while_condition())) {
200 VLOG(4) << " use is while " << use.instruction->name()
201 << " and def is in condition or body";
202 return true;
203 }
204 }
205
206 // Similarly if the value is defined at a while, it logically occurs after any
207 // uses in the body or condition computations.
208 if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
209 CHECK(value.is_phi());
210 const HloInstruction* xla_while = value.defining_instruction();
211 if (call_graph_->InstructionIsNestedIn(use.instruction,
212 xla_while->while_body()) ||
213 call_graph_->InstructionIsNestedIn(use.instruction,
214 xla_while->while_condition())) {
215 VLOG(4) << " value is while " << value.defining_instruction()->name()
216 << " and use is in condition or body";
217 return true;
218 }
219 }
220
221 // The use at a call occurs before values that are defined in the called
222 // computation.
223 if (use.instruction->opcode() == HloOpcode::kCall) {
224 const HloInstruction* call = use.instruction;
225 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
226 call->to_apply())) {
227 VLOG(4) << " use is call " << use.instruction->name()
228 << " and def is in called computation";
229 return true;
230 }
231 }
232
233 if (use.instruction->opcode() == HloOpcode::kConditional) {
234 const HloInstruction* conditional = use.instruction;
235 for (int j = 0; j < conditional->branch_count(); ++j) {
236 if (call_graph_->InstructionIsNestedIn(
237 value.defining_instruction(),
238 conditional->branch_computation(j))) {
239 VLOG(4) << " use is conditional " << use.instruction->name()
240 << " and def is in " << j << "th branch computation";
241 return true;
242 }
243 }
244 if (value.defining_instruction() == use.instruction) {
245 VLOG(4) << " use is conditional " << use << " and def is "
246 << value.ToShortString();
247 return true;
248 }
249 }
250
251 VLOG(4) << " use is not before value";
252 return false;
253 }
254
LiveRangeStrictlyBefore(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const255 bool HloOrdering::LiveRangeStrictlyBefore(
256 const HloValue& a, const HloValue& b,
257 const HloDataflowAnalysis& dataflow) const {
258 VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
259 << ", b = " << b.ToShortString() << ")";
260 if (!IsDefinedBefore(a, b)) {
261 VLOG(4) << a << " not defined before " << b;
262 return false;
263 }
264
265 if (a.live_out_of_module()) {
266 VLOG(4) << a << " is live out of module and defined before " << b;
267 return false;
268 }
269
270 // All uses of 'a' must be before 'b' is defined.
271 for (const HloUse& use : a.uses()) {
272 if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
273 use.instruction)) {
274 continue;
275 }
276 if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
277 VLOG(4) << "use of " << a << " (" << use << ") not before " << b
278 << " is defined";
279 return false;
280 }
281 }
282
283 if (a.instruction()->parent() == b.instruction()->parent()) {
284 for (const HloPosition& position : a.positions()) {
285 if (position.instruction ==
286 a.instruction()->parent()->root_instruction()) {
287 VLOG(4) << a << " is live out of computation and defined before " << b
288 << " which is in same computation";
289 return false;
290 }
291 }
292 }
293
294 return true;
295 }
296
MayInterfere(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const297 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
298 const HloDataflowAnalysis& dataflow) const {
299 // Buffers without disjoint liveness may interfere.
300 return !LiveRangeStrictlyBefore(a, b, dataflow) &&
301 !LiveRangeStrictlyBefore(b, a, dataflow);
302 }
303
PredecessorHloOrdering(const HloModule * module)304 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
305 : HloOrdering(module) {}
306
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const307 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
308 const HloInstruction* a, const HloInstruction* b) const {
309 CHECK_EQ(a->parent(), b->parent());
310
311 // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
312 return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
313 }
314
// Builds a newline-joined description of the ordering. The first line is
// 'name'; then, for every non-fusion computation of the module, a header line
// followed by each instruction (in post order) and the instructions of that
// computation the reachability map reports as reaching it.
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
  std::vector<string> lines(1, name);
  for (auto* computation : module_->MakeNonfusionComputations()) {
    lines.push_back(absl::StrFormat("computation %s:", computation->name()));
    const auto post_order = computation->MakeInstructionPostOrder();
    for (auto* instruction : post_order) {
      lines.push_back(
          absl::StrFormat(" %s predecessors:", instruction->name()));
      const auto& reachability = predecessors_.at(computation);
      for (auto* candidate : post_order) {
        if (!reachability->IsReachable(candidate, instruction)) {
          continue;
        }
        lines.push_back(absl::StrFormat(" %s", candidate->name()));
      }
    }
  }
  return absl::StrJoin(lines, "\n");
}
334
DependencyHloOrdering(const HloModule * module)335 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
336 : PredecessorHloOrdering(module) {
337 // Compute predecessor relationships between all instructions to determine
338 // ordering based on dependencies. ExecutesBefore will return true iff there
339 // exists a path in the HLO computation graph from 'a' to 'b'.
340 for (auto* computation : module->MakeNonfusionComputations()) {
341 predecessors_.emplace(computation, HloReachabilityMap::Build(computation));
342 }
343 }
344
ToString() const345 string DependencyHloOrdering::ToString() const {
346 return ToStringHelper("DependencyHloOrdering");
347 }
348
SequentialHloOrdering(const HloSchedule & schedule)349 SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
350 : HloOrdering(schedule.module()), schedule_(schedule) {
351 Initialize();
352 }
353
SequentialHloOrdering(HloSchedule && schedule)354 SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
355 : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
356 Initialize();
357 }
358
Initialize()359 void SequentialHloOrdering::Initialize() {
360 // Create a map from instruction to its order position.
361 TF_DCHECK_OK(schedule_.Verify());
362 for (const auto& computation_sequence : schedule_.sequences()) {
363 const auto& order = computation_sequence.second.instructions();
364 for (int i = 0; i < order.size(); ++i) {
365 InsertOrDie(&order_position_, order[i], i);
366 }
367 }
368 }
369
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const370 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
371 const HloInstruction* a, const HloInstruction* b) const {
372 CHECK_EQ(a->parent(), b->parent());
373 // If either instruction is not in the order, then 'a' and 'b' are unordered.
374 if (!order_position_.contains(a) || !order_position_.contains(b)) {
375 return false;
376 }
377 return order_position_.at(a) < order_position_.at(b);
378 }
379
SequentialOrder(const HloComputation & computation) const380 const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
381 const HloComputation& computation) const {
382 return schedule_.is_computation_scheduled(&computation)
383 ? &schedule_.sequence(&computation)
384 : nullptr;
385 }
386
ToString() const387 string SequentialHloOrdering::ToString() const {
388 return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
389 }
390
391 } // namespace xla
392