• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17 
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
ExecutesBefore(const HloInstruction * a,const HloInstruction * b) const35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
36                                  const HloInstruction* b) const {
37   switch (GetExecutionConstraint(a, b)) {
38     case ExecutionConstraint::kIsSame:  // a and b are the same instruction;
39       return false;
40     case ExecutionConstraint::kRunBeforeStart:
41     case ExecutionConstraint::kRunBeforeEnd:
42     case ExecutionConstraint::kRunExclusiveBefore:
43       return true;
44     case ExecutionConstraint::kRunExclusiveAfter:
45     case ExecutionConstraint::kRunAfter:
46     case ExecutionConstraint::kUnordered:
47       return false;
48   }
49 }
50 
GetExecutionConstraint(const HloInstruction * a,const HloInstruction * b) const51 HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
52     const HloInstruction* a, const HloInstruction* b) const {
53   // 'a' and 'b' may be in different computations. In this case, find the
54   // callgraph ancestor instructions which call (potentially transitively) the
55   // computations containing 'a' and 'b' and use these ancestor instructions to
56   // compare order.
57   if (a == b) {
58     return ExecutionConstraint::kIsSame;
59   }
60   const HloInstruction* a_ancestor;
61   const HloInstruction* b_ancestor;
62   std::tie(a_ancestor, b_ancestor) =
63       call_graph_->NearestAncestorsInSameComputation(
64           const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
65 
66   if (a_ancestor == nullptr) {
67     VLOG(4) << "Ancestors in a common computation could not be found between"
68             << a->ToString() << "\n and \n"
69             << b->ToString() << "\n so consider them to be unordered.\n";
70     return ExecutionConstraint::kUnordered;
71   }
72   // a_ancestor and b_ancestor must be either both null or both non-null.
73   CHECK_NE(b_ancestor, nullptr);
74   CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
75 
76   // If the common ancestor is a while instruction there is an additional
77   // ordering criteria which may apply. The condition computation is considered
78   // to execute before the body computation so if 'a' is in the condition and
79   // 'b' is in the body, then 'a' executes before 'b'.
80   if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
81     const HloComputation* body = a_ancestor->while_body();
82     const HloComputation* condition = a_ancestor->while_condition();
83     if (call_graph_->InstructionIsNestedIn(a, condition) &&
84         call_graph_->InstructionIsNestedIn(b, body)) {
85       return ExecutionConstraint::kRunBeforeEnd;
86     }
87   }
88 
89   // If the common ancestor is a conditional instruction, even though the branch
90   // computations are not really ordered per-se, we define the 0th branch
91   // computation to be ordered before the 1st one, before the 2nd and so forth.
92   // This ensures that buffers can still be shared among branch computations
93   // as they will forcibly have disjoint liveness.
94   if (a_ancestor == b_ancestor &&
95       (a_ancestor->opcode() == HloOpcode::kConditional)) {
96     int a_branch = -1;
97     int b_branch = -1;
98     for (int j = 0; j < a_ancestor->branch_count(); ++j) {
99       if (call_graph_->InstructionIsNestedIn(
100               a, a_ancestor->branch_computation(j))) {
101         a_branch = j;
102       }
103       if (call_graph_->InstructionIsNestedIn(
104               b, a_ancestor->branch_computation(j))) {
105         b_branch = j;
106       }
107     }
108     // If neither a nor b is inside the branches they both are the ancestor.
109     if (a_branch == -1 && b_branch == -1) {
110       CHECK_EQ(a, a_ancestor);
111       CHECK_EQ(b, b_ancestor);
112       CHECK_EQ(a, b);
113       return ExecutionConstraint::kIsSame;
114     }
115     // If 'b' is the conditional ancestor, and 'a' is within a branch
116     // computation, 'a' executes before 'b'.
117     if (b_branch == -1) {
118       CHECK_EQ(b, a_ancestor);
119       return ExecutionConstraint::kRunBeforeEnd;
120     }
121     if (a_branch == -1) {
122       CHECK_EQ(a, a_ancestor);
123       return ExecutionConstraint::kRunAfter;
124     }
125     if (a_branch < b_branch) {
126       return ExecutionConstraint::kRunExclusiveBefore;
127     }
128     if (b_branch < a_branch) {
129       return ExecutionConstraint::kRunExclusiveAfter;
130     }
131   }
132 
133   if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
134     return ExecutionConstraint::kRunBeforeStart;
135   }
136   if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
137     return ExecutionConstraint::kRunAfter;
138   }
139   VLOG(1) << "Cannot determine order between:" << a->ToString() << "\n"
140           << "and " << b->ToString() << " which are in the same computation\n";
141   return ExecutionConstraint::kUnordered;
142 }
143 
IsDefinedBefore(const HloValue & a,const HloValue & b) const144 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
145   // Entry parameter should always be defined before other instructions.
146   const HloModule* module = b.defining_instruction()->parent()->parent();
147   if (b.defining_instruction()->parent() == module->entry_computation() &&
148       b.defining_instruction()->opcode() == HloOpcode::kParameter) {
149     return false;
150   }
151 
152   if (a.defining_instruction()->parent() == module->entry_computation() &&
153       a.defining_instruction()->opcode() == HloOpcode::kParameter) {
154     return true;
155   }
156 
157   // Phi values require special handling. Because XLA does not have a phi
158   // instruction, the definition instruction of the phis values are
159   // placeholders: either the subcomputation parameter (body or condition) or
160   // the while instruction. However, the program point where these values are
161   // logically defined does not necessarily coincide exactly with program point
162   // of these place-holder instructions. So we explicitly define the following
163   // order for phi values:
164   //
165   //   body/condition parameter phi:
166   //     Defined before all values defined in its computation excepting other
167   //     phis.
168   //
169   //   while phi:
170   //     defined after all values defined in the condition or body.
171   //
172   auto is_body_or_condition_phi = [](const HloValue& v) {
173     return v.is_phi() &&
174            v.defining_instruction()->opcode() == HloOpcode::kParameter;
175   };
176   if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
177       call_graph_->InstructionIsNestedIn(b.defining_instruction(),
178                                          a.defining_instruction()->parent())) {
179     return true;
180   }
181   if (is_body_or_condition_phi(b) &&
182       call_graph_->InstructionIsNestedIn(a.defining_instruction(),
183                                          b.defining_instruction()->parent())) {
184     return false;
185   }
186 
187   // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
188   // executes before 'b'.
189   if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
190       (call_graph_->InstructionIsNestedIn(
191            a.defining_instruction(), b.defining_instruction()->while_body()) ||
192        call_graph_->InstructionIsNestedIn(
193            a.defining_instruction(),
194            b.defining_instruction()->while_condition()))) {
195     return true;
196   }
197   // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
198   // executes before 'b'.
199   if (b.is_phi() &&
200       b.defining_instruction()->opcode() == HloOpcode::kConditional) {
201     for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
202       if (call_graph_->InstructionIsNestedIn(
203               a.defining_instruction(),
204               b.defining_instruction()->branch_computation(j))) {
205         return true;
206       }
207     }
208   }
209   return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
210 }
211 
212 /* static */
UsesBeforeValueDefinition(absl::Span<const HloUse * const> uses,const HloValue & value,const HloDataflowAnalysis & dataflow,bool use_is_always_before_def_in_same_instr) const213 bool HloOrdering::UsesBeforeValueDefinition(
214     absl::Span<const HloUse* const> uses, const HloValue& value,
215     const HloDataflowAnalysis& dataflow,
216     bool use_is_always_before_def_in_same_instr) const {
217   bool has_use_in_exclusive_branches = false;
218   bool has_escaped_use_in_conditional = false;
219   auto UseIsBeforeValueDefinition = [&](const HloUse& use) {
220     VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
221             << ", value=" << value.ToShortString() << ")";
222     switch (
223         GetExecutionConstraint(use.instruction, value.defining_instruction())) {
224       case HloOrdering::ExecutionConstraint::kIsSame:
225         // If the use is at the instruction where the value is defined, then the
226         // use is before the def if the instruction allows buffer sharing (in
227         // place computation).
228         if (use_is_always_before_def_in_same_instr ||
229             dataflow.CanShareOperandBufferWithUser(
230                 use.instruction->mutable_operand(use.operand_number),
231                 use.operand_index, value.defining_instruction(),
232                 value.defining_index())) {
233           VLOG(4)
234               << "  use is value def, and instruction can share use buffer.";
235           return true;
236         }
237         break;
238       case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
239         // If the use is located in a branch that is exclusive to the branch
240         // where value is located, in order for them to interfere, there must be
241         // an execution path where the value's definition can reach the use, so
242         // that the wrong value would reach use if their live ranges are merged.
243         // If there is such a path, it would have to pass through the point
244         // where the two exclusive branches are joined --- specifically the end
245         // of the conditional operation. For the join point to reach back to the
246         // use at the other exclusive branch, there has to be a be a surrounding
247         // loop, where the result of the conditional is passed back inside the
248         // conditional through one of its parameters. This use-def conflict
249         // between the parameter of a conditional and one of its branches is
250         // caught in the has_escaped_use_in_conditinoal variable.
251         VLOG(4) << " use and value def are in exclusive branches.";
252         if (!has_escaped_use_in_conditional) {
253           has_use_in_exclusive_branches = true;
254           VLOG(4) << "Allowing them to share buffer.\n";
255           return true;
256         }
257         VLOG(4) << "value def has escaped use in conditional. \n";
258         break;
259       case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
260       case HloOrdering::ExecutionConstraint::kRunBeforeStart:
261       case HloOrdering::ExecutionConstraint::kRunBeforeEnd:
262         VLOG(4)
263             << "  use instruction executes before value-defining instruction";
264         return true;
265       case HloOrdering::ExecutionConstraint::kRunAfter:
266         // Treat CollectivePermuteDone as a special case as it shares the buffer
267         // from its operand (CollectivePermuteStart).
268         if (use_is_always_before_def_in_same_instr &&
269             use.instruction->opcode() == HloOpcode::kCollectivePermuteDone &&
270             use.instruction->operand(0) == value.instruction()) {
271           return true;
272         }
273         break;
274       case HloOrdering::ExecutionConstraint::kUnordered:
275         break;
276     }
277 
278     // The use at a while is an input to a phi, and logically occurs before
279     // values are defined in the body. Note that the use is *not* before the
280     // value if the value is defined in the condition and is not the condition
281     // parameter, since the input of a while's live range is only ended at the
282     // start the body.
283     if (use.instruction->opcode() == HloOpcode::kWhile) {
284       const HloInstruction* xla_while = use.instruction;
285       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
286                                              xla_while->while_body())) {
287         VLOG(4) << "  use is while " << use.instruction->name()
288                 << " and def is in body";
289         return true;
290       }
291       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
292                                              xla_while->while_condition())) {
293         if (value.defining_instruction() !=
294             xla_while->while_condition()->parameter_instruction(0)) {
295           VLOG(4) << "  use is while " << use.instruction->name()
296                   << " and def is in condition and is not the parameter";
297           return false;
298         } else {
299           VLOG(4) << "  use is while " << use.instruction->name()
300                   << " and def is in condition and is the parameter";
301           return true;
302         }
303       }
304     }
305     // Similarly if the value is defined at a while, it logically occurs after
306     // any uses in the body or condition computations.
307     if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
308       CHECK(value.is_phi());
309       const HloInstruction* xla_while = value.defining_instruction();
310       if (call_graph_->InstructionIsNestedIn(use.instruction,
311                                              xla_while->while_body()) ||
312           call_graph_->InstructionIsNestedIn(use.instruction,
313                                              xla_while->while_condition())) {
314         VLOG(4) << "  value is while " << value.defining_instruction()->name()
315                 << " and use is in condition or body";
316         return true;
317       }
318     }
319     // The use at a call occurs before values that are defined in the called
320     // computation.
321     if (use.instruction->opcode() == HloOpcode::kCall) {
322       const HloInstruction* call = use.instruction;
323       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
324                                              call->to_apply())) {
325         VLOG(4) << "  use is call " << use.instruction->name()
326                 << " and def is in called computation";
327         return true;
328       }
329     }
330     if (use.instruction->opcode() == HloOpcode::kConditional) {
331       const HloInstruction* conditional = use.instruction;
332       // In general the use of a value in the conditional parameter should be
333       // considered to be before a definition in one of its branches, and
334       // therefore allowed in live range merging, if there is no
335       // surrounding loop that creates a backward control flow path that
336       // allows the definition in the branch to have its value flow backward
337       // into the conditional and then flow into another branch in the
338       // conditional that uses the value. This is reflected by checking that
339       // the use-def in exclusive branches has not been already allowed.
340       // Further, if the def value escapes its branch, we conservatively
341       // assume a backward control flow path could exist, and set
342       // has_escaped_use_in_conditinoal to disallow any later uses in
343       // exclusive branches.
344       for (int j = 0; j < conditional->branch_count(); ++j) {
345         if (call_graph_->InstructionIsNestedIn(
346                 value.defining_instruction(),
347                 conditional->branch_computation(j))) {
348           // If the use operand does not create a new value, and the value def
349           // is returned by as part of the result of the conditional, it
350           // is possible for the branch definition to flow backward through a
351           // surrounding loop and then back into the conditional parameter.
352           if (!dataflow.ValueIsDefinedAt(
353                   use.instruction->operand(use.operand_number), {})) {
354             for (auto value_use : value.GetUses()) {
355               VLOG(4) << "def have use:" << value_use << "\n";
356               if (value_use.instruction ==
357                   value_use.instruction->parent()->root_instruction()) {
358                 VLOG(4) << "def use is conditional root \n";
359                 has_escaped_use_in_conditional = true;
360                 break;
361               }
362             }
363           }
364           if (!has_use_in_exclusive_branches) {
365             VLOG(4) << "  use is conditional " << use.instruction->name()
366                     << " and def is in " << j << "th branch computation";
367             return true;
368           }
369         }
370       }
371       if (value.defining_instruction() == use.instruction) {
372         VLOG(4) << "  use is conditional " << use << " and def is "
373                 << value.ToShortString();
374         return true;
375       }
376     }
377 
378     VLOG(4) << "  use is not before value definition";
379     return false;
380   };
381   for (auto* use : uses) {
382     if (!UseIsBeforeValueDefinition(*use)) {
383       return false;
384     }
385   }
386   return true;
387 }
388 
LiveRangeStrictlyBefore(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow,bool use_is_always_before_def_in_same_instr) const389 bool HloOrdering::LiveRangeStrictlyBefore(
390     const HloValue& a, const HloValue& b, const HloDataflowAnalysis& dataflow,
391     bool use_is_always_before_def_in_same_instr) const {
392   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
393           << ", b = " << b.ToShortString() << ")";
394   VLOG(4) << "Parent:" << a.instruction()->parent()->ToString() << "\n";
395   if (!IsDefinedBefore(a, b)) {
396     VLOG(4) << a << " not defined before " << b;
397     return false;
398   }
399 
400   if (a.live_out_of_module()) {
401     VLOG(4) << a << " is live out of module and not defined before " << b;
402     return false;
403   }
404 
405   // If the root instruction aliases the buffer 'a', the live range of 'a' is
406   // until the end of the computation and can never be strictly before another
407   // buffer nested in the same computation. This is needed to prevent the root
408   // instruction's buffers from being reused by later instructions even when
409   // the root is not the last instruction in the schedule.
410   for (const HloPosition& pos : a.positions()) {
411     if (pos.instruction->parent()->root_instruction() == pos.instruction &&
412         call_graph().InstructionIsNestedIn(b.instruction(),
413                                            pos.instruction->parent())) {
414       return false;
415     }
416   }
417 
418   // All uses of 'a' must be before 'b' is defined.
419   std::vector<const HloUse*> uses;
420   for (const HloUse& use : a.GetUses()) {
421     if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
422                                          use.instruction)) {
423       continue;
424     }
425     uses.push_back(&use);
426   }
427   if (!UsesBeforeValueDefinition(uses, b, dataflow,
428                                  use_is_always_before_def_in_same_instr)) {
429     VLOG(4) << "uses of " << a << "not before " << b << " is defined";
430     return false;
431   }
432 
433   if (a.IsRootOf(b.instruction()->parent())) {
434     VLOG(4) << a << " is live out of computation and defined before " << b
435             << " which is in same computation";
436     return false;
437   }
438 
439   return true;
440 }
441 
MayInterfere(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const442 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
443                                const HloDataflowAnalysis& dataflow) const {
444   // Buffers without disjoint liveness may interfere.
445   return !LiveRangeStrictlyBefore(a, b, dataflow) &&
446          !LiveRangeStrictlyBefore(b, a, dataflow);
447 }
448 
PredecessorHloOrdering(const HloModule * module)449 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
450     : HloOrdering(module) {}
451 
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const452 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
453     const HloInstruction* a, const HloInstruction* b) const {
454   CHECK_EQ(a->parent(), b->parent());
455 
456   // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
457   return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
458 }
459 
ToStringHelper(const std::string & name) const460 std::string PredecessorHloOrdering::ToStringHelper(
461     const std::string& name) const {
462   std::vector<std::string> pieces;
463   pieces.push_back(name);
464   for (auto* computation : module_->MakeNonfusionComputations()) {
465     pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
466     const auto all = computation->MakeInstructionPostOrder();
467     for (auto instruction : all) {
468       pieces.push_back(
469           absl::StrFormat("  %s predecessors:", instruction->name()));
470       for (auto predecessor : all) {
471         if (predecessors_.at(computation)
472                 ->IsReachable(predecessor, instruction)) {
473           pieces.push_back(absl::StrFormat("    %s", predecessor->name()));
474         }
475       }
476     }
477   }
478   return absl::StrJoin(pieces, "\n");
479 }
480 
DependencyHloOrdering(const HloModule * module)481 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
482     : PredecessorHloOrdering(module) {
483   // Compute predecessor relationships between all instructions to determine
484   // ordering based on dependencies. ExecutesBefore will return true iff there
485   // exists a path in the HLO computation graph from 'a' to 'b'.
486   for (auto* computation : module->MakeNonfusionComputations()) {
487     predecessors_.emplace(computation, HloReachabilityMap::Build(computation));
488   }
489 }
490 
ToString() const491 std::string DependencyHloOrdering::ToString() const {
492   return ToStringHelper("DependencyHloOrdering");
493 }
494 
SequentialHloOrdering(const HloSchedule & schedule)495 SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
496     : HloOrdering(schedule.module()), schedule_(schedule) {
497   Initialize();
498 }
499 
SequentialHloOrdering(HloSchedule && schedule)500 SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
501     : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
502   Initialize();
503 }
504 
Initialize()505 void SequentialHloOrdering::Initialize() {
506   // Create a map from instruction to its order position.
507   TF_DCHECK_OK(schedule_.Verify());
508   for (const auto& computation_sequence : schedule_.sequences()) {
509     const auto& order = computation_sequence.second.instructions();
510     for (int i = 0; i < order.size(); ++i) {
511       InsertOrDie(&order_position_, order[i], i);
512     }
513   }
514 }
515 
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const516 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
517     const HloInstruction* a, const HloInstruction* b) const {
518   CHECK_EQ(a->parent(), b->parent());
519   // If either instruction is not in the order, then 'a' and 'b' are unordered.
520   if (!order_position_.contains(a) || !order_position_.contains(b)) {
521     return false;
522   }
523   if (a->parent()->root_instruction() == a) {
524     // 'a' is the root instruction of the computation, which lives out. So
525     // 'a' cannot execute before 'b'.
526     return false;
527   }
528   return order_position_.at(a) < order_position_.at(b);
529 }
530 
SequentialOrder(const HloComputation & computation) const531 const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
532     const HloComputation& computation) const {
533   return schedule_.is_computation_scheduled(&computation)
534              ? &schedule_.sequence(&computation)
535              : nullptr;
536 }
537 
ToString() const538 std::string SequentialHloOrdering::ToString() const {
539   return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
540 }
541 
542 }  // namespace xla
543