1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17
18 #include <utility>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
ExecutesBefore(const HloInstruction * a,const HloInstruction * b) const35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
36 const HloInstruction* b) const {
37 switch (GetExecutionConstraint(a, b)) {
38 case ExecutionConstraint::kIsSame: // a and b are the same instruction;
39 return false;
40 case ExecutionConstraint::kRunBeforeStart:
41 case ExecutionConstraint::kRunBeforeEnd:
42 case ExecutionConstraint::kRunExclusiveBefore:
43 return true;
44 case ExecutionConstraint::kRunExclusiveAfter:
45 case ExecutionConstraint::kRunAfter:
46 case ExecutionConstraint::kUnordered:
47 return false;
48 }
49 }
50
GetExecutionConstraint(const HloInstruction * a,const HloInstruction * b) const51 HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
52 const HloInstruction* a, const HloInstruction* b) const {
53 // 'a' and 'b' may be in different computations. In this case, find the
54 // callgraph ancestor instructions which call (potentially transitively) the
55 // computations containing 'a' and 'b' and use these ancestor instructions to
56 // compare order.
57 if (a == b) {
58 return ExecutionConstraint::kIsSame;
59 }
60 const HloInstruction* a_ancestor;
61 const HloInstruction* b_ancestor;
62 std::tie(a_ancestor, b_ancestor) =
63 call_graph_->NearestAncestorsInSameComputation(
64 const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
65
66 if (a_ancestor == nullptr) {
67 VLOG(4) << "Ancestors in a common computation could not be found between"
68 << a->ToString() << "\n and \n"
69 << b->ToString() << "\n so consider them to be unordered.\n";
70 return ExecutionConstraint::kUnordered;
71 }
72 // a_ancestor and b_ancestor must be either both null or both non-null.
73 CHECK_NE(b_ancestor, nullptr);
74 CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
75
76 // If the common ancestor is a while instruction there is an additional
77 // ordering criteria which may apply. The condition computation is considered
78 // to execute before the body computation so if 'a' is in the condition and
79 // 'b' is in the body, then 'a' executes before 'b'.
80 if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
81 const HloComputation* body = a_ancestor->while_body();
82 const HloComputation* condition = a_ancestor->while_condition();
83 if (call_graph_->InstructionIsNestedIn(a, condition) &&
84 call_graph_->InstructionIsNestedIn(b, body)) {
85 return ExecutionConstraint::kRunBeforeEnd;
86 }
87 }
88
89 // If the common ancestor is a conditional instruction, even though the branch
90 // computations are not really ordered per-se, we define the 0th branch
91 // computation to be ordered before the 1st one, before the 2nd and so forth.
92 // This ensures that buffers can still be shared among branch computations
93 // as they will forcibly have disjoint liveness.
94 if (a_ancestor == b_ancestor &&
95 (a_ancestor->opcode() == HloOpcode::kConditional)) {
96 int a_branch = -1;
97 int b_branch = -1;
98 for (int j = 0; j < a_ancestor->branch_count(); ++j) {
99 if (call_graph_->InstructionIsNestedIn(
100 a, a_ancestor->branch_computation(j))) {
101 a_branch = j;
102 }
103 if (call_graph_->InstructionIsNestedIn(
104 b, a_ancestor->branch_computation(j))) {
105 b_branch = j;
106 }
107 }
108 // If neither a nor b is inside the branches they both are the ancestor.
109 if (a_branch == -1 && b_branch == -1) {
110 CHECK_EQ(a, a_ancestor);
111 CHECK_EQ(b, b_ancestor);
112 CHECK_EQ(a, b);
113 return ExecutionConstraint::kIsSame;
114 }
115 // If 'b' is the conditional ancestor, and 'a' is within a branch
116 // computation, 'a' executes before 'b'.
117 if (b_branch == -1) {
118 CHECK_EQ(b, a_ancestor);
119 return ExecutionConstraint::kRunBeforeEnd;
120 }
121 if (a_branch == -1) {
122 CHECK_EQ(a, a_ancestor);
123 return ExecutionConstraint::kRunAfter;
124 }
125 if (a_branch < b_branch) {
126 return ExecutionConstraint::kRunExclusiveBefore;
127 }
128 if (b_branch < a_branch) {
129 return ExecutionConstraint::kRunExclusiveAfter;
130 }
131 }
132
133 if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
134 return ExecutionConstraint::kRunBeforeStart;
135 }
136 if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
137 return ExecutionConstraint::kRunAfter;
138 }
139 VLOG(1) << "Cannot determine order between:" << a->ToString() << "\n"
140 << "and " << b->ToString() << " which are in the same computation\n";
141 return ExecutionConstraint::kUnordered;
142 }
143
// Returns true if value 'a' is considered to be defined before value 'b'.
//
// Entry-computation parameters are treated as defined before everything else.
// Phi values, whose defining instructions are only placeholders, get the
// explicit ordering described below. Otherwise the answer falls back to
// ExecutesBefore() on the two defining instructions.
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
  // Entry parameter should always be defined before other instructions.
  // The 'b' check comes first: if 'b' is an entry parameter nothing is defined
  // before it, even when 'a' is an entry parameter too.
  const HloModule* module = b.defining_instruction()->parent()->parent();
  if (b.defining_instruction()->parent() == module->entry_computation() &&
      b.defining_instruction()->opcode() == HloOpcode::kParameter) {
    return false;
  }

  if (a.defining_instruction()->parent() == module->entry_computation() &&
      a.defining_instruction()->opcode() == HloOpcode::kParameter) {
    return true;
  }

  // Phi values require special handling. Because XLA does not have a phi
  // instruction, the defining instructions of phi values are placeholders:
  // either the subcomputation parameter (body or condition) or the while
  // instruction. However, the program point where these values are logically
  // defined does not necessarily coincide exactly with the program point of
  // these placeholder instructions. So we explicitly define the following
  // order for phi values:
  //
  //  body/condition parameter phi:
  //    Defined before all values defined in its computation excepting other
  //    phis.
  //
  //  while phi:
  //    defined after all values defined in the condition or body.
  //
  auto is_body_or_condition_phi = [](const HloValue& v) {
    return v.is_phi() &&
           v.defining_instruction()->opcode() == HloOpcode::kParameter;
  };
  // A parameter phi precedes every non-parameter-phi value nested in the
  // computation that owns that parameter.
  if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
      call_graph_->InstructionIsNestedIn(b.defining_instruction(),
                                         a.defining_instruction()->parent())) {
    return true;
  }
  // Conversely, nothing nested in that computation is defined before such a
  // parameter phi.
  if (is_body_or_condition_phi(b) &&
      call_graph_->InstructionIsNestedIn(a.defining_instruction(),
                                         b.defining_instruction()->parent())) {
    return false;
  }

  // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
  // executes before 'b'.
  if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
      (call_graph_->InstructionIsNestedIn(
           a.defining_instruction(), b.defining_instruction()->while_body()) ||
       call_graph_->InstructionIsNestedIn(
           a.defining_instruction(),
           b.defining_instruction()->while_condition()))) {
    return true;
  }
  // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
  // executes before 'b'.
  if (b.is_phi() &&
      b.defining_instruction()->opcode() == HloOpcode::kConditional) {
    for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
      if (call_graph_->InstructionIsNestedIn(
              a.defining_instruction(),
              b.defining_instruction()->branch_computation(j))) {
        return true;
      }
    }
  }
  // No phi/entry-parameter rule applied: order by the defining instructions.
  return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
}
211
// Returns true if every use in 'uses' is ordered before the definition of
// 'value', i.e. a live range that ends at those uses does not overlap
// 'value'. This is a const member function, not a static one: it reads
// call_graph_ and calls GetExecutionConstraint().
//
// Parameters:
//   uses: the uses to check.
//   value: the value whose defining instruction the uses are compared against.
//   dataflow: consulted for in-place buffer sharing
//     (CanShareOperandBufferWithUser) and for ValueIsDefinedAt checks.
//   use_is_always_before_def_in_same_instr: if true, a use located at the
//     defining instruction of 'value' is always treated as before the def.
//     It also enables accepting a CollectivePermuteDone use whose operand 0
//     defines 'value'.
//
// Note: the two flags captured by the lambda below persist across the loop
// over 'uses', so the verdict for one use can depend on uses examined before
// it.
bool HloOrdering::UsesBeforeValueDefinition(
    absl::Span<const HloUse* const> uses, const HloValue& value,
    const HloDataflowAnalysis& dataflow,
    bool use_is_always_before_def_in_same_instr) const {
  // Set once a use/def pair in exclusive conditional branches has been
  // accepted.
  bool has_use_in_exclusive_branches = false;
  // Set once the def (inside a branch) is found to escape via a computation
  // root while the conditional's operand does not define a new value.
  bool has_escaped_use_in_conditional = false;
  auto UseIsBeforeValueDefinition = [&](const HloUse& use) {
    VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
            << ", value=" << value.ToShortString() << ")";
    switch (
        GetExecutionConstraint(use.instruction, value.defining_instruction())) {
      case HloOrdering::ExecutionConstraint::kIsSame:
        // If the use is at the instruction where the value is defined, then the
        // use is before the def if the instruction allows buffer sharing (in
        // place computation).
        if (use_is_always_before_def_in_same_instr ||
            dataflow.CanShareOperandBufferWithUser(
                use.instruction->mutable_operand(use.operand_number),
                use.operand_index, value.defining_instruction(),
                value.defining_index())) {
          VLOG(4)
              << " use is value def, and instruction can share use buffer.";
          return true;
        }
        break;
      case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
        // If the use is located in a branch that is exclusive to the branch
        // where value is located, in order for them to interfere, there must be
        // an execution path where the value's definition can reach the use, so
        // that the wrong value would reach use if their live ranges are merged.
        // If there is such a path, it would have to pass through the point
        // where the two exclusive branches are joined --- specifically the end
        // of the conditional operation. For the join point to reach back to the
        // use at the other exclusive branch, there has to be a surrounding
        // loop, where the result of the conditional is passed back inside the
        // conditional through one of its parameters. This use-def conflict
        // between the parameter of a conditional and one of its branches is
        // caught in the has_escaped_use_in_conditional variable.
        VLOG(4) << " use and value def are in exclusive branches.";
        if (!has_escaped_use_in_conditional) {
          has_use_in_exclusive_branches = true;
          VLOG(4) << "Allowing them to share buffer.\n";
          return true;
        }
        VLOG(4) << "value def has escaped use in conditional. \n";
        break;
      case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
      case HloOrdering::ExecutionConstraint::kRunBeforeStart:
      case HloOrdering::ExecutionConstraint::kRunBeforeEnd:
        VLOG(4)
            << " use instruction executes before value-defining instruction";
        return true;
      case HloOrdering::ExecutionConstraint::kRunAfter:
        // Treat CollectivePermuteDone as a special case as it shares the buffer
        // from its operand (CollectivePermuteStart).
        if (use_is_always_before_def_in_same_instr &&
            use.instruction->opcode() == HloOpcode::kCollectivePermuteDone &&
            use.instruction->operand(0) == value.instruction()) {
          return true;
        }
        break;
      case HloOrdering::ExecutionConstraint::kUnordered:
        break;
    }

    // The cases below refine the answer for control-flow instructions whose
    // generic execution constraint did not already decide it.

    // The use at a while is an input to a phi, and logically occurs before
    // values are defined in the body. Note that the use is *not* before the
    // value if the value is defined in the condition and is not the condition
    // parameter, since the input of a while's live range is only ended at the
    // start of the body.
    if (use.instruction->opcode() == HloOpcode::kWhile) {
      const HloInstruction* xla_while = use.instruction;
      if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                             xla_while->while_body())) {
        VLOG(4) << " use is while " << use.instruction->name()
                << " and def is in body";
        return true;
      }
      if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                             xla_while->while_condition())) {
        if (value.defining_instruction() !=
            xla_while->while_condition()->parameter_instruction(0)) {
          VLOG(4) << " use is while " << use.instruction->name()
                  << " and def is in condition and is not the parameter";
          return false;
        } else {
          VLOG(4) << " use is while " << use.instruction->name()
                  << " and def is in condition and is the parameter";
          return true;
        }
      }
    }
    // Similarly if the value is defined at a while, it logically occurs after
    // any uses in the body or condition computations.
    if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
      CHECK(value.is_phi());
      const HloInstruction* xla_while = value.defining_instruction();
      if (call_graph_->InstructionIsNestedIn(use.instruction,
                                             xla_while->while_body()) ||
          call_graph_->InstructionIsNestedIn(use.instruction,
                                             xla_while->while_condition())) {
        VLOG(4) << " value is while " << value.defining_instruction()->name()
                << " and use is in condition or body";
        return true;
      }
    }
    // The use at a call occurs before values that are defined in the called
    // computation.
    if (use.instruction->opcode() == HloOpcode::kCall) {
      const HloInstruction* call = use.instruction;
      if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                             call->to_apply())) {
        VLOG(4) << " use is call " << use.instruction->name()
                << " and def is in called computation";
        return true;
      }
    }
    if (use.instruction->opcode() == HloOpcode::kConditional) {
      const HloInstruction* conditional = use.instruction;
      // In general the use of a value in the conditional parameter should be
      // considered to be before a definition in one of its branches, and
      // therefore allowed in live range merging, if there is no
      // surrounding loop that creates a backward control flow path that
      // allows the definition in the branch to have its value flow backward
      // into the conditional and then flow into another branch in the
      // conditional that uses the value. This is reflected by checking that
      // the use-def in exclusive branches has not been already allowed.
      // Further, if the def value escapes its branch, we conservatively
      // assume a backward control flow path could exist, and set
      // has_escaped_use_in_conditional to disallow any later uses in
      // exclusive branches.
      for (int j = 0; j < conditional->branch_count(); ++j) {
        if (call_graph_->InstructionIsNestedIn(
                value.defining_instruction(),
                conditional->branch_computation(j))) {
          // If the use operand does not create a new value, and the value def
          // is returned as part of the result of the conditional, it
          // is possible for the branch definition to flow backward through a
          // surrounding loop and then back into the conditional parameter.
          if (!dataflow.ValueIsDefinedAt(
                  use.instruction->operand(use.operand_number), {})) {
            for (auto value_use : value.GetUses()) {
              VLOG(4) << "def have use:" << value_use << "\n";
              // A use of the def by the root of its own computation is
              // treated as the def escaping.
              if (value_use.instruction ==
                  value_use.instruction->parent()->root_instruction()) {
                VLOG(4) << "def use is conditional root \n";
                has_escaped_use_in_conditional = true;
                break;
              }
            }
          }
          if (!has_use_in_exclusive_branches) {
            VLOG(4) << " use is conditional " << use.instruction->name()
                    << " and def is in " << j << "th branch computation";
            return true;
          }
        }
      }
      if (value.defining_instruction() == use.instruction) {
        VLOG(4) << " use is conditional " << use << " and def is "
                << value.ToShortString();
        return true;
      }
    }

    VLOG(4) << " use is not before value definition";
    return false;
  };
  // Every use must individually be before the definition.
  for (auto* use : uses) {
    if (!UseIsBeforeValueDefinition(*use)) {
      return false;
    }
  }
  return true;
}
388
LiveRangeStrictlyBefore(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow,bool use_is_always_before_def_in_same_instr) const389 bool HloOrdering::LiveRangeStrictlyBefore(
390 const HloValue& a, const HloValue& b, const HloDataflowAnalysis& dataflow,
391 bool use_is_always_before_def_in_same_instr) const {
392 VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
393 << ", b = " << b.ToShortString() << ")";
394 VLOG(4) << "Parent:" << a.instruction()->parent()->ToString() << "\n";
395 if (!IsDefinedBefore(a, b)) {
396 VLOG(4) << a << " not defined before " << b;
397 return false;
398 }
399
400 if (a.live_out_of_module()) {
401 VLOG(4) << a << " is live out of module and not defined before " << b;
402 return false;
403 }
404
405 // If the root instruction aliases the buffer 'a', the live range of 'a' is
406 // until the end of the computation and can never be strictly before another
407 // buffer nested in the same computation. This is needed to prevent the root
408 // instruction's buffers from being reused by later instructions even when
409 // the root is not the last instruction in the schedule.
410 for (const HloPosition& pos : a.positions()) {
411 if (pos.instruction->parent()->root_instruction() == pos.instruction &&
412 call_graph().InstructionIsNestedIn(b.instruction(),
413 pos.instruction->parent())) {
414 return false;
415 }
416 }
417
418 // All uses of 'a' must be before 'b' is defined.
419 std::vector<const HloUse*> uses;
420 for (const HloUse& use : a.GetUses()) {
421 if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
422 use.instruction)) {
423 continue;
424 }
425 uses.push_back(&use);
426 }
427 if (!UsesBeforeValueDefinition(uses, b, dataflow,
428 use_is_always_before_def_in_same_instr)) {
429 VLOG(4) << "uses of " << a << "not before " << b << " is defined";
430 return false;
431 }
432
433 if (a.IsRootOf(b.instruction()->parent())) {
434 VLOG(4) << a << " is live out of computation and defined before " << b
435 << " which is in same computation";
436 return false;
437 }
438
439 return true;
440 }
441
MayInterfere(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const442 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
443 const HloDataflowAnalysis& dataflow) const {
444 // Buffers without disjoint liveness may interfere.
445 return !LiveRangeStrictlyBefore(a, b, dataflow) &&
446 !LiveRangeStrictlyBefore(b, a, dataflow);
447 }
448
PredecessorHloOrdering(const HloModule * module)449 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
450 : HloOrdering(module) {}
451
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const452 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
453 const HloInstruction* a, const HloInstruction* b) const {
454 CHECK_EQ(a->parent(), b->parent());
455
456 // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
457 return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
458 }
459
ToStringHelper(const std::string & name) const460 std::string PredecessorHloOrdering::ToStringHelper(
461 const std::string& name) const {
462 std::vector<std::string> pieces;
463 pieces.push_back(name);
464 for (auto* computation : module_->MakeNonfusionComputations()) {
465 pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
466 const auto all = computation->MakeInstructionPostOrder();
467 for (auto instruction : all) {
468 pieces.push_back(
469 absl::StrFormat(" %s predecessors:", instruction->name()));
470 for (auto predecessor : all) {
471 if (predecessors_.at(computation)
472 ->IsReachable(predecessor, instruction)) {
473 pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
474 }
475 }
476 }
477 }
478 return absl::StrJoin(pieces, "\n");
479 }
480
DependencyHloOrdering(const HloModule * module)481 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
482 : PredecessorHloOrdering(module) {
483 // Compute predecessor relationships between all instructions to determine
484 // ordering based on dependencies. ExecutesBefore will return true iff there
485 // exists a path in the HLO computation graph from 'a' to 'b'.
486 for (auto* computation : module->MakeNonfusionComputations()) {
487 predecessors_.emplace(computation, HloReachabilityMap::Build(computation));
488 }
489 }
490
ToString() const491 std::string DependencyHloOrdering::ToString() const {
492 return ToStringHelper("DependencyHloOrdering");
493 }
494
SequentialHloOrdering(const HloSchedule & schedule)495 SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
496 : HloOrdering(schedule.module()), schedule_(schedule) {
497 Initialize();
498 }
499
SequentialHloOrdering(HloSchedule && schedule)500 SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
501 : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
502 Initialize();
503 }
504
Initialize()505 void SequentialHloOrdering::Initialize() {
506 // Create a map from instruction to its order position.
507 TF_DCHECK_OK(schedule_.Verify());
508 for (const auto& computation_sequence : schedule_.sequences()) {
509 const auto& order = computation_sequence.second.instructions();
510 for (int i = 0; i < order.size(); ++i) {
511 InsertOrDie(&order_position_, order[i], i);
512 }
513 }
514 }
515
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const516 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
517 const HloInstruction* a, const HloInstruction* b) const {
518 CHECK_EQ(a->parent(), b->parent());
519 // If either instruction is not in the order, then 'a' and 'b' are unordered.
520 if (!order_position_.contains(a) || !order_position_.contains(b)) {
521 return false;
522 }
523 if (a->parent()->root_instruction() == a) {
524 // 'a' is the root instruction of the computation, which lives out. So
525 // 'a' cannot execute before 'b'.
526 return false;
527 }
528 return order_position_.at(a) < order_position_.at(b);
529 }
530
SequentialOrder(const HloComputation & computation) const531 const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
532 const HloComputation& computation) const {
533 return schedule_.is_computation_scheduled(&computation)
534 ? &schedule_.sequence(&computation)
535 : nullptr;
536 }
537
ToString() const538 std::string SequentialHloOrdering::ToString() const {
539 return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
540 }
541
542 } // namespace xla
543