1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17
18 #include <utility>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
ExecutesBefore(const HloInstruction * a,const HloInstruction * b) const35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
36 const HloInstruction* b) const {
37 switch (GetExecutionConstraint(a, b)) {
38 case ExecutionConstraint::kIsSame: // a and b are the same instruction;
39 return false;
40 case ExecutionConstraint::kRunBefore:
41 case ExecutionConstraint::kRunExclusiveBefore:
42 return true;
43 case ExecutionConstraint::kRunExclusiveAfter:
44 case ExecutionConstraint::kRunAfter:
45 case ExecutionConstraint::kUnordered:
46 return false;
47 }
48 }
49
GetExecutionConstraint(const HloInstruction * a,const HloInstruction * b) const50 HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
51 const HloInstruction* a, const HloInstruction* b) const {
52 // 'a' and 'b' may be in different computations. In this case, find the
53 // callgraph ancestor instructions which call (potentially transitively) the
54 // computations containing 'a' and 'b' and use these ancestor instructions to
55 // compare order.
56 if (a == b) {
57 return ExecutionConstraint::kIsSame;
58 }
59 const HloInstruction* a_ancestor;
60 const HloInstruction* b_ancestor;
61 std::tie(a_ancestor, b_ancestor) =
62 call_graph_->NearestAncestorsInSameComputation(
63 const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
64
65 if (a_ancestor == nullptr) {
66 VLOG(4) << "Ancestors in a common computation could not be found between"
67 << a->ToString() << "\n and \n"
68 << b->ToString() << "\n so consider them to be unordered.\n";
69 return ExecutionConstraint::kUnordered;
70 }
71 // a_ancestor and b_ancestor must be either both null or both non-null.
72 CHECK_NE(b_ancestor, nullptr);
73 CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
74
75 // If the common ancestor is a while instruction there is an additional
76 // ordering criteria which may apply. The condition computation is considered
77 // to execute before the body computation so if 'a' is in the condition and
78 // 'b' is in the body, then 'a' executes before 'b'.
79 if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
80 const HloComputation* body = a_ancestor->while_body();
81 const HloComputation* condition = a_ancestor->while_condition();
82 if (call_graph_->InstructionIsNestedIn(a, condition) &&
83 call_graph_->InstructionIsNestedIn(b, body)) {
84 return ExecutionConstraint::kRunBefore;
85 }
86 }
87
88 // If the common ancestor is a conditional instruction, even though the branch
89 // computations are not really ordered per-se, we define the 0th branch
90 // computation to be ordered before the 1st one, before the 2nd and so forth.
91 // This ensures that buffers can still be shared among branch computations
92 // as they will forcibly have disjoint liveness.
93 if (a_ancestor == b_ancestor &&
94 (a_ancestor->opcode() == HloOpcode::kConditional)) {
95 int a_branch = -1;
96 int b_branch = -1;
97 for (int j = 0; j < a_ancestor->branch_count(); ++j) {
98 if (call_graph_->InstructionIsNestedIn(
99 a, a_ancestor->branch_computation(j))) {
100 a_branch = j;
101 }
102 if (call_graph_->InstructionIsNestedIn(
103 b, a_ancestor->branch_computation(j))) {
104 b_branch = j;
105 }
106 }
107 // If neither a nor b is inside the branches they both are the ancestor.
108 if (a_branch == -1 && b_branch == -1) {
109 CHECK_EQ(a, a_ancestor);
110 CHECK_EQ(b, b_ancestor);
111 CHECK_EQ(a, b);
112 return ExecutionConstraint::kIsSame;
113 }
114 // If 'b' is the conditional ancestor, and 'a' is within a branch
115 // computation, 'a' executes before 'b'.
116 if (b_branch == -1) {
117 CHECK_EQ(b, a_ancestor);
118 return ExecutionConstraint::kRunBefore;
119 }
120 if (a_branch == -1) {
121 CHECK_EQ(a, a_ancestor);
122 return ExecutionConstraint::kRunAfter;
123 }
124 if (a_branch < b_branch) {
125 return ExecutionConstraint::kRunExclusiveBefore;
126 }
127 if (b_branch < a_branch) {
128 return ExecutionConstraint::kRunExclusiveAfter;
129 }
130 }
131
132 if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
133 return ExecutionConstraint::kRunBefore;
134 }
135 if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
136 return ExecutionConstraint::kRunAfter;
137 }
138 VLOG(1) << "Cannot determine order between:" << a->ToString() << "\n"
139 << "and " << b->ToString() << " which are in the same computation\n";
140 return ExecutionConstraint::kUnordered;
141 }
142
IsDefinedBefore(const HloValue & a,const HloValue & b) const143 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
144 // Entry parameter should always be defined before other instructions.
145 const HloModule* module = b.defining_instruction()->parent()->parent();
146 if (b.defining_instruction()->parent() == module->entry_computation() &&
147 b.defining_instruction()->opcode() == HloOpcode::kParameter) {
148 return false;
149 }
150
151 if (a.defining_instruction()->parent() == module->entry_computation() &&
152 a.defining_instruction()->opcode() == HloOpcode::kParameter) {
153 return true;
154 }
155
156 // Phi values require special handling. Because XLA does not have a phi
157 // instruction, the definition instruction of the phis values are
158 // placeholders: either the subcomputation parameter (body or condition) or
159 // the while instruction. However, the program point where these values are
160 // logically defined does not necessarily coincide exactly with program point
161 // of these place-holder instructions. So we explicitly define the following
162 // order for phi values:
163 //
164 // body/condition parameter phi:
165 // Defined before all values defined in its computation excepting other
166 // phis.
167 //
168 // while phi:
169 // defined after all values defined in the condition or body.
170 //
171 auto is_body_or_condition_phi = [](const HloValue& v) {
172 return v.is_phi() &&
173 v.defining_instruction()->opcode() == HloOpcode::kParameter;
174 };
175 if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
176 call_graph_->InstructionIsNestedIn(b.defining_instruction(),
177 a.defining_instruction()->parent())) {
178 return true;
179 }
180 if (is_body_or_condition_phi(b) &&
181 call_graph_->InstructionIsNestedIn(a.defining_instruction(),
182 b.defining_instruction()->parent())) {
183 return false;
184 }
185
186 // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
187 // executes before 'b'.
188 if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
189 (call_graph_->InstructionIsNestedIn(
190 a.defining_instruction(), b.defining_instruction()->while_body()) ||
191 call_graph_->InstructionIsNestedIn(
192 a.defining_instruction(),
193 b.defining_instruction()->while_condition()))) {
194 return true;
195 }
196 // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
197 // executes before 'b'.
198 if (b.is_phi() &&
199 b.defining_instruction()->opcode() == HloOpcode::kConditional) {
200 for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
201 if (call_graph_->InstructionIsNestedIn(
202 a.defining_instruction(),
203 b.defining_instruction()->branch_computation(j))) {
204 return true;
205 }
206 }
207 }
208 return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
209 }
210
211 /* static */
UsesBeforeValueDefinition(absl::Span<const HloUse * const> uses,const HloValue & value,const HloDataflowAnalysis & dataflow) const212 bool HloOrdering::UsesBeforeValueDefinition(
213 absl::Span<const HloUse* const> uses, const HloValue& value,
214 const HloDataflowAnalysis& dataflow) const {
215 bool has_use_in_exclusive_branches = false;
216 bool has_escaped_use_in_conditional = false;
217 auto UseIsBeforeValueDefinition = [&](const HloUse& use) {
218 VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
219 << ", value=" << value.ToShortString() << ")";
220 switch (
221 GetExecutionConstraint(use.instruction, value.defining_instruction())) {
222 case HloOrdering::ExecutionConstraint::kIsSame:
223 // If the use is at the instruction where the value is defined, then the
224 // use is before the def if the instruction allows buffer sharing (in
225 // place computation).
226 if (dataflow.CanShareOperandBufferWithUser(
227 use.instruction->mutable_operand(use.operand_number),
228 use.operand_index, value.defining_instruction(),
229 value.defining_index())) {
230 VLOG(4)
231 << " use is value def, and instruction can share use buffer.";
232 return true;
233 }
234 break;
235 case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
236 // If the use is located in a branch that is exclusive to the branch
237 // where value is located, in order for them to interfere, there must be
238 // an execution path where the value's definition can reach the use, so
239 // that the wrong value would reach use if their live ranges are merged.
240 // If there is such a path, it would have to pass through the point
241 // where the two exclusive branches are joined --- specifically the end
242 // of the conditional operation. For the join point to reach back to the
243 // use at the other exclusive branch, there has to be a be a surrounding
244 // loop, where the result of the conditional is passed back inside the
245 // conditional through one of its parameters. This use-def conflict
246 // between the parameter of a conditional and one of its branches is
247 // caught in the has_escaped_use_in_conditinoal variable.
248 VLOG(4) << " use and value def are in exclusive branches.";
249 if (!has_escaped_use_in_conditional) {
250 has_use_in_exclusive_branches = true;
251 VLOG(4) << "Allowing them to share buffer.\n";
252 return true;
253 }
254 VLOG(4) << "value def has escaped use in conditional. \n";
255 break;
256 case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
257 case HloOrdering::ExecutionConstraint::kRunBefore:
258 VLOG(4)
259 << " use instruction executes before value-defining instruction";
260 return true;
261 case HloOrdering::ExecutionConstraint::kRunAfter:
262 case HloOrdering::ExecutionConstraint::kUnordered:
263 break;
264 }
265
266 // The use at a while is an input to a phi, and logically occurs before
267 // values are defined in the body. Note that the use is *not* before the
268 // value if the value is defined in the condition and is not the condition
269 // parameter, since the input of a while's live range is only ended at the
270 // start the body.
271 if (use.instruction->opcode() == HloOpcode::kWhile) {
272 const HloInstruction* xla_while = use.instruction;
273 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
274 xla_while->while_body())) {
275 VLOG(4) << " use is while " << use.instruction->name()
276 << " and def is in body";
277 return true;
278 }
279 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
280 xla_while->while_condition())) {
281 if (value.defining_instruction() !=
282 xla_while->while_condition()->parameter_instruction(0)) {
283 VLOG(4) << " use is while " << use.instruction->name()
284 << " and def is in condition and is not the parameter";
285 return false;
286 } else {
287 VLOG(4) << " use is while " << use.instruction->name()
288 << " and def is in condition and is the parameter";
289 return true;
290 }
291 }
292 }
293 // Similarly if the value is defined at a while, it logically occurs after
294 // any uses in the body or condition computations.
295 if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
296 CHECK(value.is_phi());
297 const HloInstruction* xla_while = value.defining_instruction();
298 if (call_graph_->InstructionIsNestedIn(use.instruction,
299 xla_while->while_body()) ||
300 call_graph_->InstructionIsNestedIn(use.instruction,
301 xla_while->while_condition())) {
302 VLOG(4) << " value is while " << value.defining_instruction()->name()
303 << " and use is in condition or body";
304 return true;
305 }
306 }
307 // The use at a call occurs before values that are defined in the called
308 // computation.
309 if (use.instruction->opcode() == HloOpcode::kCall) {
310 const HloInstruction* call = use.instruction;
311 if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
312 call->to_apply())) {
313 VLOG(4) << " use is call " << use.instruction->name()
314 << " and def is in called computation";
315 return true;
316 }
317 }
318 if (use.instruction->opcode() == HloOpcode::kConditional) {
319 const HloInstruction* conditional = use.instruction;
320 // In general the use of a value in the conditional parameter should be
321 // considered to be before a definition in one of its branches, and
322 // therefore allowed in live range merging, if there is no
323 // surrounding loop that creates a backward control flow path that
324 // allows the definition in the branch to have its value flow backward
325 // into the conditional and then flow into another branch in the
326 // conditional that uses the value. This is reflected by checking that
327 // the use-def in exclusive branches has not been already allowed.
328 // Further, if the def value escapes its branch, we conservatively
329 // assume a backward control flow path could exist, and set
330 // has_escaped_use_in_conditinoal to disallow any later uses in
331 // exclusive branches.
332 for (int j = 0; j < conditional->branch_count(); ++j) {
333 if (call_graph_->InstructionIsNestedIn(
334 value.defining_instruction(),
335 conditional->branch_computation(j))) {
336 // If the use operand does not create a new value, and the value def
337 // is returned by as part of the result of the conditional, it
338 // is possible for the branch definition to flow backward through a
339 // surrounding loop and then back into the conditional parameter.
340 if (!dataflow.ValueIsDefinedAt(
341 use.instruction->operand(use.operand_number), {})) {
342 for (auto value_use : value.uses()) {
343 VLOG(4) << "def have use:" << value_use << "\n";
344 if (value_use.instruction ==
345 value_use.instruction->parent()->root_instruction()) {
346 VLOG(4) << "def use is conditional root \n";
347 has_escaped_use_in_conditional = true;
348 break;
349 }
350 }
351 }
352 if (!has_use_in_exclusive_branches) {
353 VLOG(4) << " use is conditional " << use.instruction->name()
354 << " and def is in " << j << "th branch computation";
355 return true;
356 }
357 }
358 }
359 if (value.defining_instruction() == use.instruction) {
360 VLOG(4) << " use is conditional " << use << " and def is "
361 << value.ToShortString();
362 return true;
363 }
364 }
365
366 VLOG(4) << " use is not before value definition";
367 return false;
368 };
369 for (auto* use : uses) {
370 if (!UseIsBeforeValueDefinition(*use)) {
371 return false;
372 }
373 }
374 return true;
375 }
376
LiveRangeStrictlyBefore(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const377 bool HloOrdering::LiveRangeStrictlyBefore(
378 const HloValue& a, const HloValue& b,
379 const HloDataflowAnalysis& dataflow) const {
380 VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
381 << ", b = " << b.ToShortString() << ")";
382 VLOG(4) << "Parent:" << a.instruction()->parent()->ToString() << "\n";
383 if (!IsDefinedBefore(a, b)) {
384 VLOG(4) << a << " not defined before " << b;
385 return false;
386 }
387
388 if (a.live_out_of_module()) {
389 VLOG(4) << a << " is live out of module and not defined before " << b;
390 return false;
391 }
392
393 // If the root instruction aliases the buffer 'a', the live range of 'a' is
394 // until the end of the computation and can never be strictly before another
395 // buffer nested in the same computation. This is needed to prevent the root
396 // instruction's buffers from being reused by later instructions even when
397 // the root is not the last instruction in the schedule.
398 for (const HloPosition& pos : a.positions()) {
399 if (pos.instruction->parent()->root_instruction() == pos.instruction &&
400 call_graph().InstructionIsNestedIn(b.instruction(),
401 pos.instruction->parent())) {
402 return false;
403 }
404 }
405
406 // All uses of 'a' must be before 'b' is defined.
407 std::vector<const HloUse*> uses;
408 for (const HloUse& use : a.uses()) {
409 if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
410 use.instruction)) {
411 continue;
412 }
413 uses.push_back(&use);
414 }
415 if (!UsesBeforeValueDefinition(uses, b, dataflow)) {
416 VLOG(4) << "uses of " << a << "not before " << b << " is defined";
417 return false;
418 }
419
420 if (a.instruction()->parent() == b.instruction()->parent()) {
421 for (const HloPosition& position : a.positions()) {
422 if (position.instruction ==
423 a.instruction()->parent()->root_instruction()) {
424 VLOG(4) << a << " is live out of computation and defined before " << b
425 << " which is in same computation";
426 return false;
427 }
428 }
429 }
430
431 return true;
432 }
433
MayInterfere(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const434 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
435 const HloDataflowAnalysis& dataflow) const {
436 // Buffers without disjoint liveness may interfere.
437 return !LiveRangeStrictlyBefore(a, b, dataflow) &&
438 !LiveRangeStrictlyBefore(b, a, dataflow);
439 }
440
PredecessorHloOrdering(const HloModule * module)441 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
442 : HloOrdering(module) {}
443
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const444 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
445 const HloInstruction* a, const HloInstruction* b) const {
446 CHECK_EQ(a->parent(), b->parent());
447
448 // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
449 return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
450 }
451
// Renders the predecessor relation for debugging. The first line is `name`.
// Then, for each non-fusion computation, each instruction (in post order) is
// listed, followed by every instruction of the same computation for which
// the reachability map reports IsReachable(predecessor, instruction).
// Lines are joined with "\n".
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
  std::vector<string> pieces;
  pieces.push_back(name);
  for (auto* computation : module_->MakeNonfusionComputations()) {
    pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
    // The reachability map depends only on the computation, so look it up
    // once here instead of once per (instruction, predecessor) pair.
    const auto& reachability = predecessors_.at(computation);
    const auto all = computation->MakeInstructionPostOrder();
    for (auto instruction : all) {
      pieces.push_back(
          absl::StrFormat(" %s predecessors:", instruction->name()));
      for (auto predecessor : all) {
        if (reachability->IsReachable(predecessor, instruction)) {
          pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
        }
      }
    }
  }
  return absl::StrJoin(pieces, "\n");
}
471
DependencyHloOrdering(const HloModule * module)472 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
473 : PredecessorHloOrdering(module) {
474 // Compute predecessor relationships between all instructions to determine
475 // ordering based on dependencies. ExecutesBefore will return true iff there
476 // exists a path in the HLO computation graph from 'a' to 'b'.
477 for (auto* computation : module->MakeNonfusionComputations()) {
478 predecessors_.emplace(computation, HloReachabilityMap::Build(computation));
479 }
480 }
481
ToString() const482 string DependencyHloOrdering::ToString() const {
483 return ToStringHelper("DependencyHloOrdering");
484 }
485
SequentialHloOrdering(const HloSchedule & schedule)486 SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
487 : HloOrdering(schedule.module()), schedule_(schedule) {
488 Initialize();
489 }
490
SequentialHloOrdering(HloSchedule && schedule)491 SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
492 : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
493 Initialize();
494 }
495
Initialize()496 void SequentialHloOrdering::Initialize() {
497 // Create a map from instruction to its order position.
498 TF_DCHECK_OK(schedule_.Verify());
499 for (const auto& computation_sequence : schedule_.sequences()) {
500 const auto& order = computation_sequence.second.instructions();
501 for (int i = 0; i < order.size(); ++i) {
502 InsertOrDie(&order_position_, order[i], i);
503 }
504 }
505 }
506
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const507 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
508 const HloInstruction* a, const HloInstruction* b) const {
509 CHECK_EQ(a->parent(), b->parent());
510 // If either instruction is not in the order, then 'a' and 'b' are unordered.
511 if (!order_position_.contains(a) || !order_position_.contains(b)) {
512 return false;
513 }
514 return order_position_.at(a) < order_position_.at(b);
515 }
516
SequentialOrder(const HloComputation & computation) const517 const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
518 const HloComputation& computation) const {
519 return schedule_.is_computation_scheduled(&computation)
520 ? &schedule_.sequence(&computation)
521 : nullptr;
522 }
523
ToString() const524 string SequentialHloOrdering::ToString() const {
525 return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
526 }
527
528 } // namespace xla
529