1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 
18 #include <optional>
19 #include <sstream>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/any.h"
27 #include "tensorflow/compiler/xla/service/dump.h"
28 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_dce.h"
31 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
35 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
36 #include "tensorflow/compiler/xla/service/logical_buffer.h"
37 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/statusor.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/platform/logging.h"
43 
44 namespace xla {
45 namespace {
46 
47 using absl::StrAppend;
48 
49 bool IsReadonlyEntryParameterValue(const HloValue& value) {
50   const HloComputation* computation = value.defining_instruction()->parent();
51   return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
52          computation == computation->parent()->entry_computation() &&
53          !computation->parent()->input_output_alias_config().ParameterHasAlias(
54              value.defining_instruction()->parameter_number(), value.index());
55 }
56 
57 bool IsConstantValue(const HloValue& value) {
58   return value.defining_instruction()->opcode() == HloOpcode::kConstant;
59 }
60 
61 bool ValueIsReadOnly(const HloValue& value) {
62   return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
63 }
64 
65 // Data structure describing the action which should be taken on parts of a
66 // computation's buffers, with respect to adding special-case copies.
67 struct SpecialCaseCopyPolicy {
68   // Insert a copy if the same buffer is found at multiple indices within the
69   // output tuple.
70   bool copy_root_replicated_buffers = false;
71   // If true, insert a copy if a buffer coming from a constant or a parameter
72   // is found within the output tuple.
73   bool copy_parameters_and_constants = false;
74 };
75 
76 SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
77                                                HloModule* module,
78                                                HloComputation* computation) {
79   SpecialCaseCopyPolicy policy;
80   if (computation == module->entry_computation()) {
81     policy.copy_parameters_and_constants = true;
82     policy.copy_root_replicated_buffers = true;
83   }
84   return policy;
85 }
86 
87 bool ShouldCopyRootValue(const HloValue& value,
88                          const SpecialCaseCopyPolicy& policy) {
89   if (policy.copy_parameters_and_constants) {
90     return ValueIsReadOnly(value);
91   }
92   return false;
93 }
94 
95 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
96 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
97 // in the deep copy of 'from' to the respective kCopy instruction in the deep copy
98 // of 'to'.
99 //
100 // Requirements: 'from' and 'to' must have compatible shapes.
101 //
102 // For example, suppose 'from' and 'to' are two-element tuples where index 0 is
103 // the only index to copy. Prior to deep-copying we have:
104 //
105 //
106 //      'from'
107 //         |
108 //        ...
109 //         |
110 //       'to'
111 //
112 // DeepCopyAndAddControlEdges produces:
113 //
114 //       'from'
115 //        /   \
116 //      GTE   GTE
117 //       |     |
118 //     Copy    |
119 //    /   \   /
120 //   |    Tuple
121 //   |      |
122 //  ctrl   ...
123 //  edge    |
124 //   |      |
125 //   |    'to'
126 //   |    /   \
127 //   |  GTE   GTE
128 //    \  |     |
129 //     Copy    |
130 //        \   /
131 //        Tuple
132 //
133 StatusOr<std::pair<HloInstruction*, HloInstruction*>>
134 DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
135                            const ShapeTree<bool>& indices_to_copy) {
136   DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
137   // to/from_copy_tree hold the kCopy instructions produced by the deep
138   // copies. Elements which are not copied (indices_to_copy.element(index) ==
139   // false) have nullptr at that index.
140   ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
141                                             /*init_value=*/nullptr);
142   TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
143                       from->parent()->DeepCopyInstruction(
144                           from, &indices_to_copy, &from_copy_tree));
145 
146   ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
147   TF_ASSIGN_OR_RETURN(
148       HloInstruction * to_deep_copy,
149       to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
150 
151   // Add control edges between the respective kCopy instructions.
152   for (const auto& pair : from_copy_tree) {
153     const ShapeIndex& index = pair.first;
154     HloInstruction* from_copy = pair.second;
155     HloInstruction* to_copy = to_copy_tree.element(index);
156     if (from_copy == nullptr) {
157       TF_RET_CHECK(to_copy == nullptr);
158       continue;
159     }
160     TF_RET_CHECK(to_copy != nullptr);
161     TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
162   }
163 
164   return std::make_pair(from_deep_copy, to_deep_copy);
165 }
166 
167 // Compute the indices of the loop state which need copies in order to avoid
168 // live range interference. Generally, an element in the loop state does not
169 // need to be copied if the element is passed transparently through the
170 // body.
171 //
172 // Returns whether any indices need to be copied.
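//
// For illustration (a hypothetical example, not from the original comments):
// suppose the loop state is a 2-tuple where the body increments element {0}
// and passes element {1} through unchanged:
//
//   body: p = parameter(0)
//         c = get-tuple-element(p), index=0
//         x = get-tuple-element(p), index=1
//         ROOT t = tuple(add(c, one), x)
//
// Element {0} is redefined by the body, so its entry in indices_to_copy is set
// to true; element {1} has the same unique value at the init and at the while
// output, so it does not need a copy.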
173 bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
174                            const HloInstruction* xla_while,
175                            ShapeTree<bool>* indices_to_copy) {
176   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
177 
178   bool any_copies = false;
179   const HloInstruction* init = xla_while->operand(0);
180   for (auto& pair : *indices_to_copy) {
181     const ShapeIndex& index = pair.first;
182     bool& should_copy = pair.second;
183     // If there is any ambiguity, then the loop state must be copied.
184     if (dataflow.GetValueSet(init, index).values().size() > 1 ||
185         dataflow.GetValueSet(xla_while, index).values().size() > 1) {
186       should_copy = true;
187     } else {
188       // If the output of the while instruction is not the same as the init
189       // value of the while, then this element is not passed through the body
190       // transparently and must be copied.
191       should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
192                     dataflow.GetUniqueValueAt(init, index);
193     }
194     any_copies |= should_copy;
195   }
196   return any_copies;
197 }
198 
199 // Compute the indices of the conditional outputs which need copies. Unambiguous
200 // buffers (buffers with only one value) don't need copies.
201 bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow,
202                                  const HloInstruction* xla_conditional,
203                                  ShapeTree<bool>* indices_to_copy) {
204   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(),
205                                xla_conditional->shape()));
206 
207   bool any_copies = false;
208   for (auto& pair : *indices_to_copy) {
209     const ShapeIndex& index = pair.first;
210     bool& should_copy = pair.second;
211 
212     CHECK_EQ(dataflow.GetValueSet(xla_conditional, index).values().size(), 1);
213 
214     auto value = dataflow.GetValueSet(xla_conditional, index).values()[0];
215     // The conditional must be copied if the value is a phi.
216     should_copy =
217         value->is_phi() && value->defining_instruction() == xla_conditional;
218     any_copies |= should_copy;
219   }
220   return any_copies;
221 }
222 
223 // Add kCopy instructions around the given kWhile instruction to eliminate any
224 // possible live range interference of HLO values assuming a dependency-based
225 // ordering. Copies are added conservatively. There likely are copies which are
226 // not strictly necessary, but they are removed later in the pass via
227 // RemoveUnnecessaryCopies.
228 //
229 // Elements (each ShapeIndex) in the loop state are considered independently.  A
230 // copy is added to each element of the loop state which is modified in the
231 // while body. For each such element, a total of three kCopy instructions are
232 // added at following locations:
233 //
234 //   (1) The init value is copied before the kWhile instruction. Before:
235 //
236 //           (Init)
237 //             |
238 //           kWhile
239 //             |
240 //            ...
241 //
242 //       After:
243 //
244 //           (Init)
245 //             |
246 //           kCopy
247 //             |
248 //           kWhile
249 //             |
250 //            ...
251 //
252 //       This copy is necessary in case the init value is simultaneously live
253 //       with the kWhile.
254 //
255 //   (2) Copies are added to the parameter and root of the while body
256 //       computation. Before:
257 //
258 //           kParameter
259 //               |
260 //              ...
261 //               |
262 //           (body root)
263 //
264 //       After:
265 //
266 //           kParameter
267 //               |
268 //             kCopy ----------+
269 //               |             |
270 //              ...           ctrl
271 //               |            edge
272 //           (body root)       |
273 //               |             |
274 //             kCopy <---------+
275 //
276 //       The root kCopy becomes the new root of the computation. Both copies are
277 //       necessary to avoid any potential interference between the parameter value and
278 //       the root value. The control edge prevents potential interference
279 //       between the copies themselves.
280 //
281 // If the loop state is a tuple then the above kCopy instructions are a deep
282 // copy constructed of kCopy, kGetTupleElement, and kTuple instructions as
283 // constructed by HloInstruction::DeepCopyInstruction.
284 Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
285                          HloInstruction* xla_while) {
286   VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
287   TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
288 
289   ShapeTree<bool> indices_to_copy(xla_while->shape());
290   if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
291                              &indices_to_copy)) {
292     VLOG(2) << "No copies necessary for kWhile instruction "
293             << xla_while->name();
294     return Status::OK();
295   }
296 
297   VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
298   for (auto& pair : indices_to_copy) {
299     if (pair.second) {
300       VLOG(2) << "  " << pair.first;
301     }
302   }
303 
304   // Deep copy init.
305   HloInstruction* while_init = xla_while->mutable_operand(0);
306   TF_ASSIGN_OR_RETURN(
307       HloInstruction * while_init_copy,
308       xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
309   TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
310 
311   // Deep copy the parameter and the root. Extend a control edge from the copy
312   // of the parameter value to the corresponding copy value of the root.
313   HloComputation* body = xla_while->while_body();
314   HloInstruction* param = body->parameter_instruction(0);
315   HloInstruction* root = body->root_instruction();
316 
317   // If param is the root then all indices should have been passed through the
318   // while body and we should have returned early above.
319   TF_RET_CHECK(param != root);
320 
321   // Copy users before making a deep copy of the parameter as the deep copy
322   // will create new users of the parameter (eg, the GTE instructions of the
323   // deep copy).
324   std::vector<HloInstruction*> param_users = param->users();
325 
326   TF_ASSIGN_OR_RETURN(auto pair,
327                       DeepCopyAndAddControlEdges(param, root, indices_to_copy));
328 
329   HloInstruction* param_copy = pair.first;
330   HloInstruction* root_copy = pair.second;
331 
332   for (HloInstruction* user : param_users) {
333     TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
334   }
335 
336   body->set_root_instruction(root_copy);
337   return Status::OK();
338 }
339 
340 // Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
341 // will remove the unnecessary copies.
342 Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
343                                     HloInstruction* in_place_op,
344                                     int64_t operand_number) {
345   VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
346   HloInstruction* operand = in_place_op->mutable_operand(operand_number);
347   TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
348                       in_place_op->parent()->DeepCopyInstruction(operand));
349   TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
350   return Status::OK();
351 }
352 
353 // Conservatively adds copies before the root instruction of the entry computation
354 // and each aliased parameter to resolve interference between aliased input and
355 // output buffers. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
356 // ones.
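//
// For illustration (a hypothetical aliasing setup, not from the original
// comments): if parameter 0 at index {1} is aliased with output index {0},
// this function deep-copies parameter 0 at {1}, deep-copies the root at {0},
// and adds a control edge from the parameter copy to the root copy so the two
// copies cannot interfere; RemoveUnnecessaryCopies may later elide whichever
// copies turn out to be safe to remove.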
357 Status AddCopiesForAliasedInputOutputs(HloModule* module) {
358   HloComputation* entry = module->entry_computation();
359   HloInstruction* root = entry->root_instruction();
360 
361   ShapeTree<bool> output_indices_to_copy(root->shape());
362   std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
363       entry->num_parameters());
364   bool has_alias = false;
365   for (auto* param : entry->parameter_instructions()) {
366     bool param_has_alias = false;
367     ShapeTree<bool> param_indices_to_copy(param->shape());
368 
369     module->input_output_alias_config().ForEachAlias(
370         [&](const ShapeIndex& output_index,
371             const HloInputOutputAliasConfig::Alias& alias) {
372           if (alias.parameter_number == param->parameter_number()) {
373             param_has_alias = true;
374             *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
375                 true;
376             *(output_indices_to_copy.mutable_element(output_index)) = true;
377           }
378         });
379 
380     if (!param_has_alias) {
381       continue;
382     }
383 
384     TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
385     TF_RET_CHECK(!copied_parameters[param->parameter_number()]);
386 
387     has_alias = true;
388     // Store a snapshot of users before DeepCopyInstruction, as
389     // DeepCopyInstruction introduces new users of the instruction.
390     std::vector<HloInstruction*> users = param->users();
391     ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
392                                                /*init_value=*/nullptr);
393     TF_ASSIGN_OR_RETURN(HloInstruction * copied,
394                         entry->DeepCopyInstruction(
395                             param, &param_indices_to_copy, &param_copy_tree));
396     for (HloInstruction* user : users) {
397       TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
398     }
399 
400     copied_parameters[param->parameter_number()] = param_copy_tree;
401   }
402 
403   if (!has_alias) {
404     return Status::OK();
405   }
406 
407   // Add copies before root instruction.
408   ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
409                                               /*init_value=*/nullptr);
410 
411   TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
412                       root->parent()->DeepCopyInstruction(
413                           root, &output_indices_to_copy, &output_copy_tree));
414 
415   // Add control dependencies between the input/output copies.
416   TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
417       [&](const ShapeIndex& output_index,
418           const HloInputOutputAliasConfig::Alias& alias) -> Status {
419         if (!copied_parameters[alias.parameter_number]) {
420           return Status::OK();
421         }
422         HloInstruction* from =
423             copied_parameters[alias.parameter_number]->element(
424                 alias.parameter_index);
425         HloInstruction* to = output_copy_tree.element(output_index);
426 
427         TF_RET_CHECK(from != nullptr);
428         TF_RET_CHECK(to != nullptr);
429         TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
430         return Status::OK();
431       }));
432 
433   entry->set_root_instruction(root_copied);
434 
435   return Status::OK();
436 }
437 
438 // Removes any control dependencies to or from the given instruction.
439 Status StripControlDependenciesFrom(HloInstruction* instruction) {
440   while (!instruction->control_successors().empty()) {
441     TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
442         instruction->control_successors().front()));
443   }
444 
445   while (!instruction->control_predecessors().empty()) {
446     TF_RETURN_IF_ERROR(
447         instruction->control_predecessors().front()->RemoveControlDependencyTo(
448             instruction));
449   }
450 
451   return Status::OK();
452 }
453 
454 class LiveRangeRegions {
455  public:
456   struct InstructionInfo {
457     InstructionInfo() : value_definition(nullptr), is_definition(false) {}
458 
459     // The instruction that defines the value being used. It basically saves
460     // the defining instruction of each HloValue.
461     HloInstruction* value_definition;
462     // Whether the instruction defines a new value (or merely uses one). This
463     // basically remembers whether the instruction actually creates an HloValue
464     // or merely uses one, from a collection of given HloValues. Note that if
465     // is_definition = true, it merely says the instruction creates a new
466     // HloValue, with or without defining a new one. For example, kAdd creates
467     // a new HloValue (can be value_definition), but tuple or get-tuple-element
468     // instructions create a new HloValue by aliasing, without defining a new
469     // value (cannot be value_definition).
470     bool is_definition;
471   };
472   // Map instructions that use a value to the defining instruction of the value.
473   // Because all values must belong to the same live range, an instruction can
474   // have at most a single value-defining instruction; otherwise the multiple
475   // incoming active values would share a single buffer, which is not allowed.
476   // The value-defining and value-use instructions do not have to belong to the
477   // same computation, but the value use needs to be nested within the defining
478   // computation.
479   typedef absl::flat_hash_map<HloInstruction*, InstructionInfo> InstructionMap;
480   typedef std::pair<HloInstruction*, InstructionInfo> InstructionEntry;
481   // Map each computation to its immediately contained instructions.
482   typedef absl::flat_hash_map<const HloComputation*, InstructionMap>
483       ComputationMap;
484 
485   InstructionMap& operator[](const HloComputation* computation) {
486     if (computation_map_.find(computation) == computation_map_.end()) {
487       computation_vector_.push_back(computation);
488     }
489     return computation_map_[computation];
490   }
491 
492   const InstructionMap& operator[](const HloComputation* computation) const {
493     ComputationMap::const_iterator p = computation_map_.find(computation);
494     CHECK(p != computation_map_.end());
495     return p->second;
496   }
497   ComputationMap::const_iterator begin() const {
498     return computation_map_.begin();
499   }
500   ComputationMap::const_iterator end() const { return computation_map_.end(); }
501   int64 size() const {
502     CHECK_EQ(computation_vector_.size(), computation_map_.size());
503     return computation_vector_.size();
504   }
505   bool empty() const { return size() == 0; }
506   const HloComputation* Computation(int64_t index) const {
507     return computation_vector_[index];
508   }
509   bool contains(const HloInstruction* instr) const {
510     CHECK_NE(instr, nullptr);
511     auto* computation = instr->parent();
512     auto p = computation_map_.find(computation);
513     if (p == computation_map_.end()) {
514       return false;
515     }
516     auto instr_map = (*p).second;
517     return instr_map.find(instr) != instr_map.end();
518   }
519 
520  private:
521   ComputationMap computation_map_;
522   absl::InlinedVector<const HloComputation*, 5> computation_vector_;
523 };
524 
525 namespace {
526 // Represent relations between the locations of two regions of instructions;
527 // each region can include 0-n instructions.
528 class Relation {
529  public:
530   enum RuntimeOrder {
531     // Indicate that there is no overlap whatsoever between the two regions.
532     kNoOverlap = 0,
533     // Indicate that the first region includes the same set of instructions as
534     // the second region.
535     kSameInstr = 1,
536     // Indicate that the first region is entirely before the second region
537     // starts.
538     kBeforeStart = 2,
539     // Indicate that the first region is before the second region ends.
540     kBeforeStartOrSameInstr = kBeforeStart | kSameInstr,
541     // Indicate that the first region is entirely after the second region ends.
542     kAfterEnd = 4,
543     // Indicate that the first region is after the second region
544     // starts, with some instructions before the second region ends.
545     kAfterEndOrSameInstr = kAfterEnd | kSameInstr,
546     // Indicate that the first region overlaps with the second one, but shares no
547     // common instructions.
548     kBeforeStartOrAfterEnd = kBeforeStart | kAfterEnd,
549     // Indicate that the first region overlaps with the second one, and has
550     // some common instructions.
551     kBeforeOrAfterOrOverlap = kBeforeStart | kAfterEnd | kSameInstr,
552   };
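  // Illustrative note (not part of the original comments): the enum values are
  // bit flags, so compound orderings are formed by bitwise OR. For example,
  //   kBeforeStart | kAfterEnd == 2 | 4 == 6 == kBeforeStartOrAfterEnd.
  // Union() below relies on this encoding, and Subsume(o1, o2) holds exactly
  // when (o1 | o2) == o1, i.e. when o1 already covers every ordering in o2.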
553   Relation() : intercept_def_use_(false) {}
554   explicit Relation(RuntimeOrder order, bool intercept_def_use = false)
555       : intercept_def_use_(intercept_def_use) {
556     orders_.push_back(order);
557   }
558   Relation(const Relation& that)
559       : intercept_def_use_(that.intercept_def_use_), orders_(that.orders_) {}
560   bool operator==(const Relation& that) const {
561     return intercept_def_use_ == that.intercept_def_use_ &&
562            absl::c_equal(orders_, that.orders_);
563   }
564 
565   // Return whether the runtime ordering may imply interception, assuming it
566   // models the relation between a modifying and a use instruction.
567   bool UseImpliesInterception() const {
568     CHECK_EQ(orders_.size(), 1);
569     return UseImpliesInterception(orders_[0]);
570   }
571   // Return whether the runtime ordering may imply interception, assuming it
572   // models the relation between a modifying and a definition instruction.
573   bool DefinitionImpliesInterception() const {
574     CHECK_EQ(orders_.size(), 1);
575     return DefinitionImpliesInterception(orders_[0]);
576   }
577   // Return whether the current relation models a modifying instruction that
578   // intercepts the dataflow of another live range region.
579   bool InterceptDefUse() const { return intercept_def_use_; }
580   // Update interception state to the given value.
581   void UpdateInterception(bool value) {
582     CHECK_EQ(orders_.size(), 1);
583     intercept_def_use_ = value;
584   }
585   Relation::RuntimeOrder GetRuntimeOrder() const {
586     if (orders_.empty()) {
587       return Relation::kNoOverlap;
588     }
589     CHECK_EQ(orders_.size(), 1);
590     return orders_[0];
591   }
592   // Return whether the current relation implies two overlapping regions.
593   bool RuntimeOrderOverlap() const {
594     return absl::c_any_of(orders_, ImpliesOverlap);
595   }
596   bool RuntimeOrderIsUnordered() const {
597     return orders_.size() == 1 && orders_[0] == kBeforeStartOrAfterEnd;
598   }
599   bool RuntimeOrderIsNoOverlap() const {
600     return orders_.empty() || (orders_.size() == 1 && orders_[0] == kNoOverlap);
601   }
602   bool RuntimeOrderIsRunBefore() const {
603     return orders_.size() == 1 && orders_[0] == kBeforeStart;
604   }
605   bool RuntimeOrderIsRunAfter() const {
606     return orders_.size() == 1 && orders_[0] == kAfterEnd;
607   }
608   std::string ToString() const {
609     return absl::StrCat("Interception = ", intercept_def_use_, ";",
610                         absl::StrJoin(orders_, ","));
611   }
612 
613   static bool DefinitionImpliesInterception(RuntimeOrder definition) {
614     return (definition == kAfterEnd || definition == kBeforeStartOrAfterEnd);
615   }
616   static bool UseImpliesInterception(RuntimeOrder use) {
617     return (use == kBeforeStart || use == kBeforeStartOrAfterEnd);
618   }
619 
620   // Summarize additional relations into a single runtime ordering, assuming
621   // both relations are modeling constraints of the same source instruction.
622   void UnionRelationFromSameSource(const Relation& rel) {
623     CHECK_LE(orders_.size(), 1);
624     CHECK_EQ(rel.orders_.size(), 1);
625     if (orders_.empty()) {
626       orders_.push_back(rel.orders_[0]);
627     } else {
628       orders_[0] = Union(orders_[0], rel.orders_[0]);
629     }
630     intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
631   }
632 
633   // Summarize additional relations into disjoint runtime orderings, assuming
634   // the relations are modeling constraints of different source instructions.
635   void UnionRelationFromDifferentSource(const Relation& rel) {
636     if (rel.orders_.empty()) {
637       return;
638     }
639     CHECK_EQ(rel.orders_.size(), 1);
640     intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
641     for (auto& local_order : orders_) {
642       if (OverwriteIfSubsume(rel.orders_[0], &local_order)) {
643         return;
644       }
645     }
646     orders_.push_back(rel.orders_[0]);
647   }
648 
649   static Relation::RuntimeOrder ReverseRuntimeOrder(RuntimeOrder order) {
650     switch (order) {
651       case kNoOverlap:
652       case kSameInstr:
653       case kBeforeStartOrAfterEnd:
654       case kBeforeOrAfterOrOverlap:
655         return order;
656       case kBeforeStart:
657         return kAfterEnd;
658       case kBeforeStartOrSameInstr:
659         return kAfterEndOrSameInstr;
660       case kAfterEnd:
661         return kBeforeStart;
662       case kAfterEndOrSameInstr:
663         return kBeforeStartOrSameInstr;
664     }
665   }
666 
667  private:
668   // Indicate that the second region may intercept the def-use dataflow of the
669   // first region, if their buffers are combined.
670   bool intercept_def_use_;
671   // Remember the different runtime orderings of different instructions.
672   absl::InlinedVector<RuntimeOrder, 4> orders_;
673 
674   static RuntimeOrder Union(RuntimeOrder o1, RuntimeOrder o2) {
675     return static_cast<Relation::RuntimeOrder>(o1 | o2);
676   }
677   static bool ImpliesOverlap(RuntimeOrder o) {
678     return o >= RuntimeOrder::kBeforeStartOrAfterEnd;
679   }
680   // Returns whether ordering constraint o1 includes o2 as a subset, when they
681   // represent runtime orderings (interleavings) of two different regions.
682   static bool Subsume(RuntimeOrder o1, RuntimeOrder o2) {
683     return Union(o1, o2) == o1;
684   }
685   // Overwrites o1 with o2 if o2 subsumes o1 (as defined above by the Subsume
686   // function). Return whether o2 is subsumed by the new value in o1.
687   static bool OverwriteIfSubsume(RuntimeOrder o2, RuntimeOrder* o1) {
688     if (*o1 == o2) {
689       return true;
690     }
691     CHECK_NE(o1, nullptr);
692     // Overwrite o1 with o2 if it is subsumed by o2.
693     if (Subsume(o2, *o1)) {
694       *o1 = o2;
695       return true;
696     } else if (Subsume(*o1, o2)) {
697       // If o2 is already subsumed by o1, do nothing.
698       return true;
699     }
700     // If neither o1 nor o2 is subsumed by the other, return false, so that o2
701     // will be inserted as a separate entry representing all possible orderings.
702     return false;
703   }
704 };
705 
706 class ComputeRelativeLocation {
707  public:
708   typedef LiveRangeRegions::InstructionEntry InstructionEntry;
709   explicit ComputeRelativeLocation(HloOrdering* ordering)
710       : ordering_(ordering) {
711     VLOG(3) << "New analysis\n";
712   }
713 
714   // Compute locationing constraints between two instructions. Here entry2 is
715   // the source instruction, in that the returned value describes the relation
716   // of entry2 in terms of whether it is before or after entry1, and whether it
717   // can intercept the def-use data flow of entry1.
718   Relation Compute(const InstructionEntry& entry1,
719                    const InstructionEntry& entry2, bool instr2_can_modify) {
720     auto def = entry1.second.value_definition;
721     auto use = entry1.first;
722     Relation::RuntimeOrder order =
723         ComputeRuntimeOrdering(entry2.first, entry1.first);
724     if (order == Relation::kSameInstr &&
725         entry1.second.is_definition != entry2.second.is_definition) {
726       if (entry1.second.is_definition) {
727         order = Relation::kBeforeStart;
728       } else {
729         order = Relation::kAfterEnd;
730       }
731     }
732     bool intercept = AlwaysForceInterception(entry2.first);
733     if (def == nullptr || !instr2_can_modify) {
734       return Relation(order, intercept);
735     }
736     // If the definition and use are parameter and return (root) of the parent
737     // computation, then any modification is considered intercepting.
738     if (def->opcode() == HloOpcode::kParameter &&
739         use == use->parent()->root_instruction()) {
740       VLOG(3) << "Setting interception due to parameter/root relation\n";
741       return Relation(order, true);
742     }
743     if (Relation::UseImpliesInterception(order)) {
744       auto order2 = ComputeRuntimeOrdering(entry2.first, def);
745       if (Relation::DefinitionImpliesInterception(order2)) {
746         VLOG(3) << "Setting interception for " << def->ToString()
747                 << " with use:" << entry1.first->ToString() << "\n";
748         intercept = true;
749       }
750     }
751     return Relation(order, intercept);
752   }
753 
754   // Return the relative locations (defined above) of range2 in relation to
755   // instructions in range1. Return kNoOverlap if range2 is outside of range1.
756   Relation Compute(const LiveRangeRegions& range1,
757                    const LiveRangeRegions& range2) {
758     Relation dir_src_dest;
759     for (int64_t index = 0; index < range1.size(); index++) {
760       auto* computation1 = range1.Computation(index);
761       for (const auto& computation_entry2 : range2) {
762         auto* computation2 = computation_entry2.first;
763         for (auto instr_entry2 : computation_entry2.second) {
764           if (!ordering_->call_graph().Dominates(computation1, computation2)) {
765             continue;
766           }
767           VLOG(3) << "Locationing " << instr_entry2.first->ToString();
768           // Saves relations between instr2 and other instructions in range1.
769           bool instr2_can_modify =
770               InstructionCanIntercept(instr_entry2, range1);
771           Relation instr2_relation;
772           std::vector<InstructionEntry> unordered_ops;
773           bool unordered_intercept = false;
774           for (auto instr_entry1 : range1[computation1]) {
775             auto rel = Compute(instr_entry1, instr_entry2, instr2_can_modify);
776             VLOG(3) << "new relation with:" << instr_entry1.first->ToString()
777                     << " = " << rel.ToString() << "\n";
778             if (!rel.RuntimeOrderIsUnordered()) {
779               instr2_relation.UnionRelationFromSameSource(rel);
780             } else {
781               unordered_ops.push_back(instr_entry1);
782               unordered_intercept |= rel.InterceptDefUse();
783             }
784             VLOG(3) << "instr2 relation:" << instr2_relation.ToString() << "\n";
785           }
786           // Here instr2_relation is guaranteed to have at most a single entry,
787           // because it was initialized to be empty, and has been updated only
788           // via instr2_relation.UnionRelationFromSameSource(rel), which
789           // maintains that the updated result has only a single entry.
790           if (!ForceRuntimeOrder(unordered_ops, instr_entry2,
791                                  instr2_relation.GetRuntimeOrder())) {
792             VLOG(3) << "Unable to force ordering of unordered ops\n";
793             instr2_relation.UnionRelationFromSameSource(Relation(
794                 Relation::kBeforeStartOrAfterEnd, unordered_intercept));
795           }
796           dir_src_dest.UnionRelationFromDifferentSource(instr2_relation);
797           VLOG(3) << "Resulting relation : " << dir_src_dest.ToString() << "\n";
798         }
799       }
800     }
801     return dir_src_dest;
802   }
803 
804   // Return whether control dependencies, if they exist, are added successfully.
805   bool AddControlDependenceForUnorderedOps() {
806     if (ctrl_deps_.empty()) {
807       return true;
808     }
809     PredecessorHloOrdering* ordering =
810         dynamic_cast<PredecessorHloOrdering*>(ordering_);
811     if (ordering == nullptr) {
812       // Forcing the ordering of unordered ops is supported only when using
813       // predecessor ordering.
814       return false;
815     }
816     for (const auto& comp_it : ctrl_deps_) {
817       HloComputation* parent = comp_it.first;
818       HloReachabilityMap& reachability_map = ordering->reachability_map(parent);
819       for (const auto& instr_it : comp_it.second) {
820         HloInstruction* entry1 = instr_it.first;
821         for (HloInstruction* entry2 : instr_it.second) {
822           VLOG(3) << "Add control dependence between " << entry2->ToString();
823           VLOG(3) << "\n vs " << entry1->ToString() << "\n";
824           TF_CHECK_OK(entry2->AddControlDependencyTo(entry1));
825         }
826         reachability_map.UpdateReachabilityThroughInstruction(entry1);
827         for (HloInstruction* entry2 : instr_it.second) {
828           DCHECK(ordering_->GetExecutionConstraint(entry1, entry2) ==
829                  HloOrdering::ExecutionConstraint::kRunAfter);
830         }
831       }
832     }
833     return true;
834   }
835 
836  private:
837   enum ComputeStatus {
838     kFullyComputed,
839     kPartiallyComputed,
840     kNotComputed,
841   };
842   typedef std::pair<ComputeStatus, Relation::RuntimeOrder> SavedRelation;
843 
844   // Returns whether it is safe to force the desired_relation ordering between
845   // all operations in unordered_ops and entry2. If safe, save the new enforced
846   // ordering relations.
847   bool ForceRuntimeOrder(absl::Span<const InstructionEntry> unordered_ops,
848                          const InstructionEntry entry2,
849                          Relation::RuntimeOrder desired_relation) {
850     if (unordered_ops.empty()) {
851       return true;
852     }
853     if (desired_relation != Relation::kBeforeStart &&
854         desired_relation != Relation::kAfterEnd) {
855       return false;
856     }
857     auto ModifiesNonCopy = [](HloInstruction* instr, const HloInstruction* op) {
858       auto in_place = HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr);
859       if (in_place.empty()) {
860         return false;
861       }
862       return absl::c_any_of(
863           in_place, [&](const std::pair<HloUse, ShapeIndex>& a) {
864             auto* op2 = instr->operand(a.first.operand_number);
865             return (op == nullptr) ? (op2->opcode() == HloOpcode::kCopy)
866                                    : (op2 == op);
867           });
868     };
869     for (const InstructionEntry& entry1 : unordered_ops) {
870       // Only consider instructions in the same computation.
871       if (entry1.first->parent() != entry2.first->parent()) {
872         return false;
873       }
874       HloInstruction* pred = (desired_relation == Relation::kBeforeStart)
875                                  ? entry2.first
876                                  : entry1.first;
877       HloInstruction* succ = (desired_relation == Relation::kBeforeStart)
878                                  ? entry1.first
879                                  : entry2.first;
880       if (pred == pred->parent()->root_instruction()) {
881         return false;
882       }
883       if (succ->opcode() == HloOpcode::kCopy &&
884           ModifiesNonCopy(pred, succ->operand(0))) {
885         VLOG(3) << "Failed to force unordered op ordering due to copy ordering "
886                 << " between " << pred->ToString() << "\n";
887         VLOG(3) << " vs. " << succ->ToString() << "\n";
888         return false;
889       }
890     }
891     for (const InstructionEntry& entry1 : unordered_ops) {
892       Save(entry2.first, entry1.first, desired_relation, true);
893     }
894     return true;
895   }
896 
897   static bool AlwaysForceInterception(HloInstruction* instr) {
898     // The following communication operations can have some unexpected side
899     // effects when synchronizing across processes. Therefore, we conservatively
900     // try to provide dedicated buffers to these operations instead of allowing
901     // them to share buffers with other operations, as the reuse may cause
902     // unexpected interference.
903     if (HloDataflowAnalysis::IsAsynchronousOperationStart(instr->opcode()) ||
904         HloDataflowAnalysis::IsAsynchronousOperationDone(instr->opcode())) {
905       return true;
906     }
907     switch (instr->opcode()) {
908       // TODO(b/190903339): It appears that collectivePermute needs to be
909       // followed by a copy when escaping through a computation root.
910       case HloOpcode::kCollectivePermute:
911         return true;
912       default:
913         return false;
914     }
915   }
916 
917   // Returns whether the given instr may intercept the def-use flow of another
918   // ongoing live range if its buffer is combined with the other live range.
919   // The function should return true if instr creates a new HloValue that could
920   // overwrite an existing HloValue in the combined buffer.
921   // More specifically, here we are looking for operations that create new
922   // values, e.g., add, subtract, in contrast to HLOs that merely create
923   // aliasings among existing values, e.g., tuple, get-tuple-element. Any of the
924   // new values created by operations such as add or subtract, when included as
925   // definition operations in a live range, are aliases of the buffer to be
926   // allocated to the live range and so are treated as if they may be modifying
927   // the target buffer.
928   bool InstructionCanIntercept(const InstructionEntry& entry,
929                                const LiveRangeRegions& region) {
930     auto instr = entry.first;
931     if (!entry.second.is_definition) {
932       // If the instruction only uses the value, it can intercept only if it
933       // modifies the buffer in place.
934       return !HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr).empty();
935     }
936     switch (instr->opcode()) {
937       // If the copy instruction is used to connect two live range regions,
938       // it does not overwrite the combined buffer with new values.
939       case HloOpcode::kCopy:
940         // Check that the copy simply copies from the other live range with no
941         // layout conflicts.
942         if (region.contains(instr->operand(0)) &&
943             ShapeUtil::Equal(instr->shape(), instr->operand(0)->shape())) {
944           return false;  // Cannot intercept.
945         }
946         return true;
947       // The following operations merely create aliases among the HloValues.
948       case HloOpcode::kParameter:
949       case HloOpcode::kTuple:
950       case HloOpcode::kGetTupleElement:
951       // Here we consider all the compound operations (e.g., conditionals and
952       // while loops) as if they do not modify any HloValue, with the argument
953       // being that any value modifying operation contained inside will be
954       // considered separately to make sure the kIntercept relation is recorded
955       // as appropriate. Since the compound operations may or may not modify,
956       // not treating them as value modifying would make the algorithm less
957       // conservative.
958       case HloOpcode::kWhile:
959       case HloOpcode::kCall:
960       case HloOpcode::kConditional:
961       case HloOpcode::kTupleSelect:
962         return false;
963       default:
964         return true;
965     }
966     return true;
967   }
968 
969   SavedRelation AlreadyComputed(HloInstruction* op1, HloInstruction* op2) {
970     auto p2 = saved_relations_.find(op2);
971     if (p2 != saved_relations_.end()) {
972       auto p1 = (*p2).second.find(op1);
973       if (p1 != (*p2).second.end()) {
974         return SavedRelation(kFullyComputed, (*p1).second);
975       }
976     }
977     p2 = saved_relations_.find(op1);
978     if (p2 != saved_relations_.end()) {
979       auto p1 = (*p2).second.find(op2);
980       if (p1 != (*p2).second.end()) {
981         return SavedRelation(kPartiallyComputed,
982                              Relation::ReverseRuntimeOrder((*p1).second));
983       }
984     }
985     return SavedRelation(kNotComputed, Relation::kNoOverlap);
986   }
987 
988   Relation::RuntimeOrder Save(HloInstruction* entry1, HloInstruction* entry2,
989                               const Relation::RuntimeOrder relation,
990                               bool is_unordered_originally = false) {
991     CHECK_EQ(AlreadyComputed(entry1, entry2).first, kNotComputed);
992     // Do not save unordered relations.
993     CHECK_NE(relation, Relation::kBeforeStartOrAfterEnd);
994     saved_relations_[entry2][entry1] = relation;
995     if (is_unordered_originally) {
996       CHECK(relation == Relation::kBeforeStart ||
997             relation == Relation::kAfterEnd)
998           << relation;
999       HloInstruction* pred =
1000           (relation == Relation::kBeforeStart) ? entry1 : entry2;
1001       HloInstruction* succ =
1002           (relation == Relation::kBeforeStart) ? entry2 : entry1;
1003       VLOG(3) << "Save unordered relation: " << pred->ToString() << "\n";
1004       VLOG(3) << " vs " << succ->ToString() << "\n";
1005       CHECK_EQ(succ->parent(), pred->parent());
1006       auto& dep_vec = ctrl_deps_[succ->parent()][succ];
1007       for (HloInstruction*& op : dep_vec) {
1008         auto rel = AlreadyComputed(pred, op);
1009         if (rel.first != kNotComputed) {
1010           if (rel.second == Relation::kAfterEnd) {
1011             op = pred;
1012           } else {
1013             CHECK(rel.second == Relation::kBeforeStart);
1014           }
1015           return relation;
1016         }
1017       }
1018       VLOG(2) << "Forcing unordered:" << pred->ToString() << "\n";
1019       VLOG(2) << " vs " << succ->ToString() << "\n";
1020       dep_vec.push_back(pred);
1021     }
1022     return relation;
1023   }
1024 
1025   // Compute the runtime ordering constraints between two instructions.
1026   Relation::RuntimeOrder ComputeRuntimeOrdering(HloInstruction* instr1,
1027                                                 HloInstruction* instr2) {
1028     auto saved_relation = AlreadyComputed(instr1, instr2);
1029     if (saved_relation.first != kNotComputed) {
1030       VLOG(3) << "Already computed between " << instr1->ToString() << "\n vs "
1031               << instr2->ToString() << "\n";
1032       return saved_relation.second;
1033     }
1034     auto constraint = ordering_->GetExecutionConstraint(instr1, instr2);
1035     switch (constraint) {
1036       case HloOrdering::ExecutionConstraint::kIsSame:
1037         return Save(instr1, instr2, Relation::kSameInstr);
1038       case HloOrdering::ExecutionConstraint::kRunBeforeEnd:
1039         return Save(instr1, instr2, Relation::kBeforeStartOrSameInstr);
1040       case HloOrdering::ExecutionConstraint::kRunBeforeStart:
1041         return Save(instr1, instr2, Relation::kBeforeStart);
1042       case HloOrdering::ExecutionConstraint::kRunAfter:
1043         return Save(instr1, instr2, Relation::kAfterEnd);
1044       case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
1045       case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
1046         return Save(instr1, instr2, Relation::kNoOverlap);
1047       case HloOrdering::ExecutionConstraint::kUnordered: {
1048         if (instr1->parent() != instr2->parent()) {
1049           return Relation::kBeforeStartOrAfterEnd;
1050         }
1051         auto ControlDependenceBefore = [&](HloInstruction* op1,
1052                                            HloInstruction* op2) {
1053           auto constraint = ComputeRuntimeOrdering(op1, op2);
1054           if (constraint == Relation::kBeforeStart ||
1055               constraint == Relation::kSameInstr ||
1056               constraint == Relation::kBeforeStartOrSameInstr) {
1057             return true;
1058           } else {
1059             return false;
1060           }
1061         };
1062         if (!ctrl_deps_.empty()) {
1063           auto ctrl_deps = ctrl_deps_[instr1->parent()];
1064           if (absl::c_any_of(ctrl_deps[instr2], [&](HloInstruction* pred2) {
1065                 return ControlDependenceBefore(instr1, pred2);
1066               })) {
1067             VLOG(2) << "control-dependent: " << instr1->ToString() << "\n";
1068             VLOG(2) << "vs " << instr2->ToString() << "\n";
1069             return Save(instr1, instr2, Relation::kBeforeStart);
1070           } else if (absl::c_any_of(
1071                          ctrl_deps[instr1], [&](HloInstruction* pred1) {
1072                            return ControlDependenceBefore(instr2, pred1);
1073                          })) {
1074             VLOG(2) << "control-dependent: " << instr2->ToString() << "\n";
1075             VLOG(2) << "vs " << instr1->ToString() << "\n";
1076             return Save(instr1, instr2, Relation::kAfterEnd);
1077           }
1078         }
1079         // Don't save the result for unordered operations, so they can be
1080         // refined later.
1081         return Relation::kBeforeStartOrAfterEnd;
1082       }
1083     }
1084   }
1085 
1086   HloOrdering* ordering_;
1087   absl::flat_hash_map<
1088       HloInstruction*,
1089       absl::flat_hash_map<HloInstruction*, Relation::RuntimeOrder>>
1090       saved_relations_;
1091   absl::flat_hash_map<
1092       HloComputation*,
1093       absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>>
1094       ctrl_deps_;
1095 };
1096 }  // namespace
1097 
1098 // Class which tracks the HLO values within each HLO buffer in the module
1099 // during copy removal.
1100 //
1101 // The values are held in a linked list where there is one list for each
1102 // buffer. Removing a copy instruction merges together the values in the
1103 // source buffer of the copy to the destination buffer of the copy. This class
1104 // tracks these value lists as copies are removed from the graph (and value
1105 // lists are merged).
1106 //
1107 // The CopyRemover object is initialized to match the state of
1108 // HloAliasAnalysis. However, as copies are removed this state diverges. The
1109 // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
1110 // a fully updatable alias analysis is very slow.
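//
// For illustration (an example, not from the original comments): if the source
// buffer of a kCopy holds the value list a0 -> a1 and the destination buffer
// holds b0 -> b1, eliding that copy splices the two circular lists into a
// single list containing a0, a1, b0, and b1, which then represents the merged
// buffer.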
1111 class CopyRemover {
1112  public:
1113   // The values held in a single HLO buffer are represented using a linked
1114   // list. An element type in this list is ValueNode.
1115   //
1116   // This linked list is hand-rolled to enable efficient splicing of lists
1117   // using only references to list elements without knowing which lists are
1118   // being spliced. std::list requires a reference to the list object to
1119   // splice.
1120   struct ValueNode {
1121     explicit ValueNode(const HloValue* v) : value(v) {}
1122 
1123     const HloValue* value;
1124 
1125     // The uses are maintained outside of HloValue::uses() because
1126     // HloValue::uses() is not updatable (a fully updatable dataflow analysis
1127     // is slow).
1128     std::vector<const HloUse*> uses;
1129 
1130     // next/prev elements in the linked list. The list is circularly linked so
1131     // these values are never null for elements in the list.
1132     ValueNode* prev = nullptr;
1133     ValueNode* next = nullptr;
1134   };
1135 
1136   CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
1137               HloOrdering* ordering, bool check_live_range_ordering)
1138       : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
1139     // Construct a list for each HLO buffer in the alias analysis. Maintain a
1140     // map from HloValue to the respective list element representing that
1141     // value. The map is used to construct the copy info map below.
1142     absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
1143     // Perform check only if the default dependence-based ordering is used.
1144     for (const HloBuffer& buffer : alias_analysis.buffers()) {
1145       // No copies should have been inserted within fused computations, so no
1146       // need to remove them. HloOrdering isn't compatible with HloValues inside
1147       // fusions, so skip copy removal for them.
1148       if (buffer.values().at(0)->defining_instruction()->IsFused()) {
1149         continue;
1150       }
1151       if (check_live_range_ordering) {
1152         // Verify values contained in the buffer are strictly ordered. This
1153         // should always be the case after adding copies to eliminate
1154         // interference. Specifically, the addition of the control flow edges
1155         // between copies added around aliased operations (kWhile) guarantees
1156         // this strict order.
1157         for (const HloValue* value_a : buffer.values()) {
1158           if (value_a->shape().IsToken()) {
1159             // Token values have no representation and cannot interfere.
1160             continue;
1161           }
1162           for (const HloValue* value_b : buffer.values()) {
1163             if (value_a != value_b) {
1164               DCHECK(ordering_->LiveRangeStrictlyBefore(
1165                          *value_a, *value_b, dataflow_,
1166                          /*use_is_always_before_def_in_same_instr=*/true) ||
1167                      ordering_->LiveRangeStrictlyBefore(
1168                          *value_b, *value_a, dataflow_,
1169                          /*use_is_always_before_def_in_same_instr=*/true))
1170                   << value_a->ToString() << " and " << value_b->ToString()
1171                   << " are not ordered";
1172             }
1173           }
1174         }
1175       }
1176 
1177       std::vector<const HloValue*> values = buffer.values();
1178       absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
1179         return ordering_->IsDefinedBefore(*a, *b);
1180       });
1181 
1182       // Create a list containing all of the values in the buffer.
1183       AddValueList(values, &value_to_node);
1184     }
1185 
1186     // Create copy_map_ which contains the source and destination values
1187     // of all copies.
1188     CreateCopyMap(module, value_to_node);
1189 
1190     XLA_VLOG_LINES(3, ToString());
1191     TF_DCHECK_OK(Verify());
1192   }
1193 
1194   // Add a list containing the given values to CopyRemover. This
1195   // represents the values contained in a single buffer. For each value in
1196   // 'values' an entry is created in value_to_node which indicates the
1197   // respective ValueNode representing that value.
1198   void AddValueList(
1199       absl::Span<const HloValue* const> values,
1200       absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
1201     ValueNode* tail = nullptr;
1202     ValueNode* head = nullptr;
1203     for (const HloValue* value : values) {
1204       auto new_node = new ValueNode(value);
1205       (*value_to_node)[value] = new_node;
1206 
1207       // Copy the HLO value's uses into the ValueNode for the value. These
1208       // uses in ValueNode are updated as copies are removed.
1209       new_node->uses.reserve(value->uses().size());
1210       for (const HloUse& use : value->uses()) {
1211         new_node->uses.push_back(&use);
1212       }
1213 
1214       // Connect the new node into the linked list.
1215       if (tail == nullptr) {
1216         head = new_node;
1217       } else {
1218         tail->next = new_node;
1219         new_node->prev = tail;
1220       }
1221       tail = new_node;
1222     }
1223 
1224     // The linked list is circular so connect the head and tail.
1225     tail->next = head;
1226     head->prev = tail;
1227     value_lists_.insert(head);
1228   }
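  // For example, for a buffer whose dependency-ordered values are {v0, v1, v2},
  // AddValueList builds the circular list v0 <-> v1 <-> v2 <-> v0 (so
  // tail->next == head and head->prev == tail) and records only the head node
  // v0 in value_lists_.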
1229 
1230   // This method fills in copy_map_, which maps each kCopy instruction to
1231   // the nodes in the value lists corresponding to its source and
1232   // destination values. value_to_node should map each HloValue to its
1233   // respective ValueNode.
1234   void CreateCopyMap(
1235       const HloModule& module,
1236       const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
1237     for (HloComputation* computation : module.MakeNonfusionComputations()) {
1238       for (HloInstruction* instruction : computation->instructions()) {
1239         // Add copies with unambiguous source values to the map. Copies with
1240         // ambiguous sources are not removable.
1241         if (instruction->opcode() == HloOpcode::kCopy) {
1242           const HloValueSet& src_value_set =
1243               dataflow_.GetValueSet(instruction->operand(0));
1244           if (src_value_set.values().size() == 1) {
1245             CopyNodes& copy_node = copy_map_[instruction];
1246             copy_node.dest =
1247                 value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
1248             copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
1249           }
1250         }
1251       }
1252     }
1253   }
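  // For example, given c = copy(x), where the value set of x holds the single
  // value vx and c defines the value vc, this adds
  //   copy_map_[c] = {src: node(vx), dest: node(vc)}.
  // A copy whose operand value set holds more than one value gets no entry and
  // therefore is never considered for elision.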
1254 
1255   ~CopyRemover() {
1256     for (const ValueNode* head : value_lists_) {
1257       const ValueNode* p = head;
1258       do {
1259         const ValueNode* tmp = p->next;
1260         delete p;
1261         p = tmp;
1262       } while (p != head);
1263     }
1264   }
1265 
1266   // Verify invariants within the linked lists.
1267   Status Verify() const {
1268     for (const ValueNode* head : value_lists_) {
1269       const ValueNode* p = head;
1270       do {
1271         // Verify links between elements are consistent.
1272         TF_RET_CHECK(p->prev->next == p);
1273         TF_RET_CHECK(p->next->prev == p);
1274 
1275         const HloInstruction* def = p->value->defining_instruction();
1276         if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
1277           TF_RET_CHECK(copy_map_.at(def).dest == p);
1278         }
1279         for (const HloUse* use : p->uses) {
1280           if (use->instruction->opcode() == HloOpcode::kCopy &&
1281               ContainsKey(copy_map_, use->instruction)) {
1282             TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
1283           }
1284         }
1285 
1286         p = p->next;
1287       } while (p != head);
1288     }
1289     return Status::OK();
1290   }
1291 
1292   // Compute the set of instructions where values are alive and organize these
1293   // instructions by separating them into their respective computations.
1294   LiveRangeRegions ComputeLiveRangeRegions(const ValueNode* head) {
1295     LiveRangeRegions live_range;
1296 
1297     auto VisitValueNode = [&](const ValueNode* node) {
1298       HloInstruction* def_op = node->value->instruction();
1299       HloComputation* def_parent = def_op->parent();
1300       live_range[def_parent][def_op].is_definition = true;
1301       for (const auto& use : node->uses) {
1302         auto* use_op = use->instruction;
1303         HloComputation* use_parent = use_op->parent();
1304         live_range[use_parent][use_op].value_definition = def_op;
1305       }
1306     };
1307     ForEachValueInRange(head, VisitValueNode);
1308     return live_range;
1309   }
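  // For example, if the list holds a value defined by instruction A and used
  // by instruction B, both in computation C, the returned regions contain
  //   live_range[C][A].is_definition = true
  //   live_range[C][B].value_definition = A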
1310 
1311   // Try to elide the given copy. Elision of a copy is possible only if no
1312   // live range interference is introduced by the copy's elimination. If
1313   // elision is possible, then the internal state (value lists) is updated,
1314   // and true is returned. Returns false otherwise.
1315   bool TryElideCopy(const HloInstruction* copy, bool use_region_analysis) {
1316     VLOG(2) << "Trying to remove " << copy->name();
1317 
1318     if (!ContainsKey(copy_map_, copy)) {
1319       VLOG(2) << copy->name() << " is not removable";
1320       return false;
1321     }
1322     if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
1323       VLOG(2) << copy->name() << " is not removable (shape mismatch)";
1324       return false;
1325     }
1326     const CopyNodes& copy_node = copy_map_.at(copy);
1327     DCHECK(copy_node.src != nullptr);
1328     DCHECK(copy_node.dest != nullptr);
1329 
1330     VLOG(3) << copy->name() << " copies value "
1331             << copy_node.src->value->ToShortString();
1332     VLOG(3) << "Source buffer values: " << ValueListToString(copy_node.src);
1333     VLOG(3) << "Dest buffer values: " << ValueListToString(copy_node.dest);
1334     // Checks that live ranges at/before 'src' are before those at/after 'dest'.
1335     auto CheckLiveRangeBefore = [&](ValueNode* src, ValueNode* dest) {
1336       for (ValueNode* next_dest = dest; next_dest != nullptr;
1337            next_dest = Next(*next_dest)) {
1338         for (ValueNode* prev_src = src; prev_src != nullptr;
1339              prev_src = Prev(*prev_src)) {
1340           if (!LiveRangeBefore(*prev_src, *next_dest)) {
1341             VLOG(2) << "Live range of " << prev_src->value->ToShortString()
1342                     << " is not before " << next_dest->value->ToShortString();
1343             return false;
1344           }
1345         }
1346       }
1347       return true;
1348     };
1349     auto CheckLiveRangeInterference = [&](ValueNode* src, ValueNode* dest,
1350                                           const CombineLiveRangeOption option) {
1351       CHECK_NE(src, nullptr);
1352       CHECK_NE(dest, nullptr);
1353       if (!use_region_analysis) {
1354         VLOG(2) << "Configured to not use region-based analysis.\n";
1355         return true;
1356       }
1357       if (ValuesInterfere(src, dest, option)) {
1358         VLOG(2) << "Region-based interference is true. \n";
1359         return true;
1360       }
1361       VLOG(2) << "Region-based interference is false. \n";
1362       return false;
1363     };
1364 
1365     // A kCopy instruction copies an HLO value from a source buffer and
1366     // defines an HLO value in a destination buffer. Most generally, the
1367     // source and destination buffers may each hold more than one value at
1368     // different points in the computation so we define the following:
1369     //
1370     //   Values in source buffer:      {s_0, ..., s_n}
1371     //   Values in destination buffer: {d_0, ..., d_m}
1372     //
1373     // A kCopy instruction between these buffers copies a value s_x in the
1374     // source buffer and defines a value d_y in the destination buffer. The
1375     // elision of a copy merges the source and destination buffers together,
1376     // so the list of values for the source and destination buffers are
1377     // merged.
1378     //
1379     // We handle two different cases for copy elision:
1380     //
1381     //  (1) the kCopy defines the first value in the destination buffer (d_0).
1382     //
1383     //  (2) the kCopy copies the last value in the source buffer (s_n).
1384     //
1385     // For the remaining case where the kCopy copies a not-last value from the
1386     // source buffer to a not-first value of the destination buffer, the kCopy
1387     // instruction cannot be removed. This case is generated, for example, if
1388     // the kCopy copies a while body parameter of the loop state at one tuple
1389     // index to a different tuple index in the while body root. Removal of the
1390     // copy necessarily results in live range interference of values in the
1391     // loop state at the two different tuple indices.
1392     //
1393     //  We can only perform copy elision if the resulting merged values have
1394     //  totally ordered live ranges; otherwise the merged buffer would have
1395     //  live range interference.
1396     if (copy_node.src->next == copy_node.dest) {
1397       // In the process of eliding copies, it's possible for a copy to have the
1398       // same source and destination buffer. In this case, the copy can be
1399       // safely removed.
1400       VLOG(2) << copy->name() << " source and destination buffers are the same.";
1401     } else if (IsHead(*copy_node.dest)) {
1402       // The copy copies an arbitrary value in the source buffer (call it s_x)
1403       // and defines d_0, the first value in the destination buffer. After
1404       // merging, the values in the combined buffer must be strictly ordered
1405       // as follows** to elide the copy:
1406       //
1407       // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
1408       //
1409       // Removing the copy eliminates d_0, and uses of d_0 become uses of
1410       // s_x. In the above ordering, the live range of d_m will be ordered
1411       // before the live range of s_{x+1} and the definition and all uses of
1412       // s_x will be ordered before the definition of d_1. To make sure the
1413       // copy elision is safe, the following code checks that this ordering is
1414       // valid --- in particular we check it is safe to order d_m ahead of all
1415       // the liverages at and after x_{x+1}, and it is safe to order all uses
1416       // the live ranges at and after s_{x+1}, and it is safe to order all uses
1417       // of s_x before the definition of d_1, by checking the live range
1418       // constraints for each pair --- we cannot skip the later checks because
1419       // the live range ordering is not guaranteed to be transitive --- while it
1420       // may be ok to have lr_1 before lr_2, and lr_2 before lr_3 while merging
1421       // their buffers, it may not be ok to merge the buffers of lr_1 and lr_3,
1422       // not transitive.
1423       //
1424       // ** Technically it might be possible to have a non-interfering
1425       //    non-trivial interleaving of the values of the source and
1426       //    destination buffers in the resulting order. This can be potentially
1427       //    supported in the ValuesInterfere function, which performs
1428       //    interference analysis at a more global scope than the alternative
1429       //    LiveRangeBefore analysis which requires strict ordering of all live
1430       //    ranges. Currently, however, this is not yet supported, as
1431       //    we simply check for the case where *all* values of the destination
1432       //    buffer (d_1 through d_m) are spliced into the point where the copy
1433       //    used to be.
1434       VLOG(2) << copy->name() << " defines the first value in its buffer";
1435       bool live_range_before =
1436           // Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
1437           CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest)) &&
1438           // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
1439           CheckLiveRangeBefore(copy_node.dest->prev, Next(*copy_node.src));
1440       VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1441       if (!live_range_before &&
1442           CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1443                                      kMergeFirstDestInSource)) {
1444         return false;
1445       }
1446       VLOG(2) << "Splice dest after source.";
1447       // Splice in destination buffer values list right after 'src'.
1448       SpliceAfter(copy_node.dest, copy_node.src);
1449     } else if (IsTail(*copy_node.src)) {
1450       // The copy copies the last value in the source buffer, s_n, and defines
1451       // an arbitrary value in the destination buffer, d_y.  After
1452       // merging, the values in the combined buffer must be strictly ordered
1453       // as follows** to elide the copy:
1454       //
1455       // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
1456       //
1457       // Removing the copy eliminates d_y, and uses of d_y become uses of
1458       // s_n. To enforce the above order, the live range of d_{y-1} must be
1459       // before the live range of s_0, and the live range of s_n must be
1460       // before the live range of d_{y+1}.
1461       //
1462       // ** See comment above in the code handling Case (1).
1463       VLOG(2) << copy->name() << " copies the last value ("
1464               << copy_node.src->value->ToShortString() << ") in its buffer";
1465       bool live_range_before =
1466           // Live range of d_0, ..., d_{y-1} must be before s_0;
1467           CheckLiveRangeBefore(Prev(*copy_node.dest), copy_node.src->next) &&
1468           // Live range of 'last_src' must be before next_dest d_{y+1}.
1469           CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest));
1470       VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1471       if (!live_range_before &&
1472           CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1473                                      kMergeLastSourceInDest)) {
1474         VLOG(2) << "Region-based analysis concludes interference.\n";
1475         return false;
1476       }
1477       VLOG(2) << "Splice src after prev of dest.";
1478       // Splice source buffer values list right after 'prev_dest'.
1479       SpliceAfter(copy_node.src->next, Prev(*copy_node.dest));
1480     } else {
1481       VLOG(2) << copy->name()
1482               << " copies value in middle of source buffer to value in middle "
1483                  "of destination buffer";
1484       return false;
1485     }
1486 
1487     RemoveCopyValue(copy_node.dest);
1488 
1489     XLA_VLOG_LINES(4, ToString());
1490     TF_DCHECK_OK(Verify());
1491 
1492     return true;
1493   }
1494 
1495   // Delete the given ValueNode associated with an elided kCopy
1496   // instruction. This should be called after splicing the value lists of the
1497   // source and destination buffers together.
1498   void RemoveCopyValue(ValueNode* copy_value_node) {
1499     CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
1500              HloOpcode::kCopy);
1501     ValueNode* operand_node = copy_value_node->prev;
1502     CHECK(operand_node != copy_value_node);
1503 
1504     VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
1505             << " => " << copy_value_node->value->ToShortString();
1506 
1507     // Splice out the copy value node.
1508     operand_node->next = copy_value_node->next;
1509     copy_value_node->next->prev = operand_node;
1510 
1511     // Patch up uses. Remove use of copy from operand_node uses.
1512     auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
1513                                                       const HloUse* use) {
1514       return use->instruction == copy_value_node->value->defining_instruction();
1515     });
1516     CHECK(it != operand_node->uses.end());
1517     operand_node->uses.erase(it);
1518 
1519     // If the elided copy has any uses which are themselves kCopy instructions
1520     // then patch up the copy info to reflect that this kCopy instruction
1521     // has a different operand (the operand of the elided copy).
1522     for (const HloUse* copy_use : copy_value_node->uses) {
1523       operand_node->uses.push_back(copy_use);
1524       if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
1525           ContainsKey(copy_map_, copy_use->instruction)) {
1526         copy_map_.at(copy_use->instruction).src = operand_node;
1527       }
1528     }
1529 
1530     // Delete the copy info and the value node.
1531     copy_map_.erase(copy_value_node->value->defining_instruction());
1532     delete copy_value_node;
1533   }
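  // For example, with the merged list {..., s_x, d_0, d_1, ...}, where d_0 was
  // defined by the elided copy and s_x is its operand, RemoveCopyValue(d_0)
  // leaves {..., s_x, d_1, ...}: the copy's use is dropped from s_x's uses,
  // d_0's remaining uses are transferred to s_x, and any kCopy among those
  // uses has its copy_map_ source re-pointed at s_x.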
1534 
1535   // Returns true if the live range of the given value 'a' is before the live
1536   // range of 'b'.
1537   //
1538   // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
1539   // updated as copies are removed. Also, because the result is used here
1540   // to directly drive copy elision, use_is_always_before_def_in_same_instr is
1541   // set to false.
1542   bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
1543     if (a.uses.empty()) {
1544       VLOG(2) << "Empty uses for " << *a.value;
1545       return ordering_->IsDefinedBefore(*a.value, *b.value);
1546     }
1547     VLOG(3) << "Checking live ranges before :" << ValueListToString(&a)
1548             << " vs " << ValueListToString(&b) << "\n";
1549     return ordering_->UsesBeforeValueDefinition(
1550         a.uses, *b.value, dataflow_,
1551         /* use_is_always_before_def_in_same_instr=*/false);
1552   }
1553 
1554   // Returns whether 'node' is the last node in its list.
1555   bool IsTail(const ValueNode& node) const {
1556     return ContainsKey(value_lists_, node.next);
1557   }
1558 
1559   // Returns whether 'node' is the first node in its list.
1560   bool IsHead(const ValueNode& node) const {
1561     return ContainsKey(value_lists_, &node);
1562   }
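  // Note: the lists are circular and value_lists_ holds only head nodes, so a
  // node is the head exactly when it is registered in value_lists_, and the
  // tail exactly when its 'next' pointer is a registered head.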
1563 
1564   // Returns the next node in the list after 'node'. If 'node' is the
1565   // tail, then nullptr is returned.
1566   ValueNode* Next(const ValueNode& node) const {
1567     if (IsTail(node)) {
1568       return nullptr;
1569     } else {
1570       return node.next;
1571     }
1572   }
1573 
1574   // Returns the previous node in the list before 'node'. If 'node'
1575   // is the head, then nullptr is returned.
1576   ValueNode* Prev(const ValueNode& node) const {
1577     if (IsHead(node)) {
1578       return nullptr;
1579     } else {
1580       return node.prev;
1581     }
1582   }
1583 
1584   // Splices the entire linked list with 'head' as its head right after the
1585   // node 'insert_after' in another linked list.
1586   void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
1587     DCHECK(IsHead(*head));
1588     value_lists_.erase(head);
1589 
1590     ValueNode* tail = head->prev;
1591     tail->next = insert_after->next;
1592     insert_after->next->prev = tail;
1593 
1594     insert_after->next = head;
1595     head->prev = insert_after;
1596   }
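  // For example, SpliceAfter(head = d_0 of {d_0, ..., d_m}, insert_after = s_x
  // of {s_0, ..., s_n}) produces the single circular list
  //   {s_0, ..., s_x, d_0, ..., d_m, s_{x+1}, ..., s_n}
  // and erases d_0 from value_lists_, leaving one head for the merged buffer.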
1597 
1598   enum CombineLiveRangeOption {
1599     kMergeFirstDestInSource = 1,
1600     kMergeLastSourceInDest = 2
1601   };
1602   // This function analyzes all the HloValues that have been grouped together
1603   // with src to share a single buffer, and all the HloValues that have been
1604   // similarly grouped together with dest, to determine whether these two groups
1605   // can be combined by removing the copy operation that defines dest, which
1606   // makes a copy of the buffer in src.
1607   bool ValuesInterfere(const ValueNode* src, const ValueNode* dest,
1608                        CombineLiveRangeOption merge_location) {
1609     // Get the entire range of values sharing the buffers in src and dest.
1610     auto src_live_range = ComputeLiveRangeRegions(src);
1611     auto dest_live_range = ComputeLiveRangeRegions(dest);
1612     ComputeRelativeLocation relative_location_analysis(ordering_);
1613     auto rel1 =
1614         relative_location_analysis.Compute(src_live_range, dest_live_range);
1615     VLOG(3) << "Location of dest in relation to src:" << rel1.ToString()
1616             << " with interception set to " << rel1.InterceptDefUse() << "\n";
1617     auto rel2 =
1618         relative_location_analysis.Compute(dest_live_range, src_live_range);
1619     VLOG(3) << "Location of src in relation to dest:" << rel2.ToString()
1620             << " with interception set to " << rel2.InterceptDefUse() << "\n";
1621     // If src and dest are interleaved with each other, they interfere.
1622     if (rel1.RuntimeOrderOverlap() && rel2.RuntimeOrderOverlap()) {
1623       VLOG(3) << "Both relations are overlap.\n";
1624       return true;
1625     }
1626     // If src and dest belong to the same group of computations and do not
1627     // overlap, they do not interfere.
1628     if (rel1.RuntimeOrderOverlap() || rel2.RuntimeOrderOverlap()) {
1629       VLOG(3) << "At least one relation is overlap.\n";
1630       if (rel1.RuntimeOrderOverlap()) {
1631         VLOG(3) << "rel1 is overlap, with interception = "
1632                 << rel1.InterceptDefUse() << "\n";
1633         if (rel1.InterceptDefUse() ||
1634             (merge_location != kMergeFirstDestInSource &&
1635              rel2.InterceptDefUse())) {
1636           return true;
1637         }
1638       } else {
1639         VLOG(3) << "rel2 is overlap, with interception = "
1640                 << rel2.InterceptDefUse() << "\n";
1641         // Here src is at the end of a nested computation inside dest.
1642         if (rel2.InterceptDefUse() ||
1643             (merge_location != kMergeLastSourceInDest &&
1644              rel1.InterceptDefUse())) {
1645           return true;
1646         }
1647       }
1648     }
1649     if (relative_location_analysis.AddControlDependenceForUnorderedOps()) {
1650       return false;
1651     } else {
1652       // Disallow removing of copy if control deps cannot be added.
1653       return true;
1654     }
1655   }
1656 
1657   // Calls 'visitor' on the sequence of HloValues starting from 'element'.
1658   // If 'element' is not the head, traverse from 'element' to the tail, then
1659   // wrap around. The visiting order matters for live range region analysis.
1660   void ForEachValueInRange(const ValueNode* element,
1661                            std::function<void(const ValueNode*)> visitor) {
1662     const ValueNode* head = element;
1663     std::vector<const ValueNode*> values;
1664     for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
1665       visitor(p);
1666     }
1667     while (!IsHead(*head)) {
1668       head = Prev(*head);
1669     }
1670     for (const ValueNode* p = head; p != element; p = Next(*p)) {
1671       visitor(p);
1672     }
1673   }
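  // For example, for the circular list v0 <-> v1 <-> v2 <-> v0 with
  // element == v1, the visiting order is v1, v2, v0.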
1674 
1675   string ValueListToString(const ValueNode* element) {
1676     std::string result = "{";
1677     auto VisitValueNode = [&](const ValueNode* node) {
1678       if (result == "{") {
1679         StrAppend(&result, node->value->ToShortString());
1680       } else {
1681         StrAppend(&result, ", ");
1682         StrAppend(&result, node->value->ToShortString());
1683       }
1684     };
1685     ForEachValueInRange(element, VisitValueNode);
1686     StrAppend(&result, "}");
1687     return result;
1688   }
1689 
1690   string ToString() const {
1691     string out = absl::StrCat("CopyRemover:\n");
1692     StrAppend(&out, "  Def-use chains in each buffer:\n");
1693     for (const ValueNode* head : value_lists_) {
1694       StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
1695                 ":\n");
1696       const ValueNode* p = head;
1697       do {
1698         StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
1699                   absl::StrJoin(p->uses, "; ",
1700                                 [](string* s, const HloUse* use) {
1701                                   StrAppend(s, use->ToString());
1702                                 }),
1703                   "\n");
1704 
1705         p = p->next;
1706       } while (p != head);
1707     }
1708     StrAppend(&out, "  Potentially removable copies:\n");
1709     for (const auto& pair : copy_map_) {
1710       const HloInstruction* copy = pair.first;
1711       const CopyNodes& copy_info = pair.second;
1712 
1713       StrAppend(&out, "    ", copy->name(), " : ",
1714                 copy_info.src->value->ToShortString(), " => ",
1715                 copy_info.dest->value->ToShortString(), "\n");
1716     }
1717     return out;
1718   }
1719 
1720  private:
1721   const HloDataflowAnalysis& dataflow_;
1722   HloOrdering* ordering_;
1723 
1724   // The heads of all the value lists. Each value list represents the HLO
1725   // values contained in a particular HLO buffer. The values in the list are
1726   // in dependency order.
1727   absl::flat_hash_set<const ValueNode*> value_lists_;
1728 
1729   // Copy removal requires fast access to the value list elements
1730   // corresponding to the source and destination values of the kCopy
1731   // instruction. This data structure holds pointers to these elements for
1732   // each kCopy instruction in the graph.
1733   struct CopyNodes {
1734     // The source and destination values of the kCopy instruction.
1735     ValueNode* src = nullptr;
1736     ValueNode* dest = nullptr;
1737   };
1738   absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
1739 };
1740 
1741 }  // namespace
1742 
1743 // We add copies for all non-phi indices of the branch computation roots, in
1744 // order to resolve interference. We later rely on RemoveUnnecessaryCopies to
1745 // drop the unnecessary ones.
1746 Status CopyInsertion::AddCopiesForConditional(
1747     const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
1748   VLOG(2) << "Adding copies for kConditional instruction "
1749           << conditional->name();
1750   ShapeTree<bool> indices_to_copy(conditional->shape());
1751   TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
1752   if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
1753                                    conditional, &indices_to_copy)) {
1754     VLOG(2) << "No copies necessary for kConditional instruction "
1755             << conditional->name();
1756     return Status::OK();
1757   }
1758 
1759   for (HloComputation* computation : conditional->branch_computations()) {
1760     HloInstruction* root = computation->root_instruction();
1761     std::vector<HloInstruction*> users = root->users();
1762     TF_ASSIGN_OR_RETURN(
1763         HloInstruction * deep_copy,
1764         computation->DeepCopyInstruction(root, &indices_to_copy));
1765     for (HloInstruction* user : users) {
1766       TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
1767     }
1768     computation->set_root_instruction(deep_copy);
1769   }
1770   return Status::OK();
1771 }
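// As a rough illustration (simplified, hypothetical HLO): if a branch root is
// tuple(a, b) and only index {0} is marked in indices_to_copy, the deep copy
// rewrites that branch to end in something like tuple(copy(a), b), which then
// becomes the branch computation's new root; the extra tuple/GTE structure is
// cleaned up later by TupleSimplifier and HloDCE in Run().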
1772 
1773 // Add kCopy instructions to the given module to guarantee there is no
1774 // live-range interference. Generally interference can only occur around kWhile
1775 // instructions which have update-in-place semantics.
1776 Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
1777   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1778                       HloAliasAnalysis::Run(module, can_share_buffer_));
1779 
1780   for (HloComputation* computation : module->MakeNonfusionComputations()) {
1781     for (HloInstruction* instruction :
1782          computation->MakeInstructionPostOrder()) {
1783       if (instruction->opcode() == HloOpcode::kWhile) {
1784         TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
1785       } else if (instruction->opcode() == HloOpcode::kConditional) {
1786         TF_RETURN_IF_ERROR(
1787             AddCopiesForConditional(*alias_analysis, instruction));
1788       } else {
1789         // When an operand is a tuple, we avoid copying the operand multiple
1790         // times by recording and checking the operand number of operands that
1791         // have been copied.
1792         absl::flat_hash_set<int64> copied_operands;
1793         for (const auto& operand_and_output_index :
1794              HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
1795           const HloUse& operand = operand_and_output_index.first;
1796           if (copied_operands.contains(operand.operand_number)) {
1797             continue;
1798           }
1799           copied_operands.insert(operand.operand_number);
1800           TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
1801               *alias_analysis, instruction, operand.operand_number));
1802         }
1803       }
1804     }
1805   }
1806 
1807   TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
1808   return Status::OK();
1809 }
1810 
1811 Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
1812   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1813   return AddSpecialCaseCopies(*call_graph, module);
1814 }
1815 
1816 Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
1817                                            HloModule* module) {
1818   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1819                       HloAliasAnalysis::Run(module, can_share_buffer_));
1820 
1821   // Identify which shape indices of which instructions need to be copied. Store
1822   // these results in 'instructions_to_copy'.
1823   HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
1824   auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
1825                                                    const ShapeIndex& index) {
1826     auto it = instructions_to_copy.find(instruction);
1827     if (it == instructions_to_copy.end()) {
1828       auto it_added = instructions_to_copy.emplace(
1829           std::piecewise_construct, std::forward_as_tuple(instruction),
1830           std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
1831       it = it_added.first;
1832     }
1833     *it->second.mutable_element(index) = true;
1834   };
1835 
1836   // Iterate through values of all constants and entry parameters. These values
1837   // are special because they are held in read-only buffers. If any of these
1838   // values share a buffer with other values (for example, the init value of a
1839   // while is a constant) then copy the value at its definition and replace all
1840   // its uses with the copy.
1841   for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
1842     if (ValueIsReadOnly(*value) &&
1843         alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
1844       VLOG(2) << "Value " << value->ToShortString()
1845               << " is read only, but its buffer contains more than one value. "
1846                  "Copying.";
1847       add_index_to_copy(value->defining_instruction(), value->defining_index());
1848     }
1849   }
1850 
1851   // Identify copies which must be added at root instructions.
1852   for (HloComputation* computation : module->computations()) {
1853     const CallGraphNode& node = call_graph.GetNode(computation);
1854     if (node.context() == CallContext::kParallel) {
1855       continue;
1856     }
1857     TF_RET_CHECK(node.context() == CallContext::kSequential);
1858 
1859     SpecialCaseCopyPolicy policy =
1860         GetSpecialCaseCopyPolicy(node, module, computation);
1861     HloInstruction* root = computation->root_instruction();
1862 
1863     // Mark nondistinct/ambiguous indices.
1864     absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
1865     ShapeUtil::ForEachSubshape(
1866         root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1867           std::vector<const HloBuffer*> buffers_at_index =
1868               alias_analysis->ComputeBuffersAt(root, index);
1869           bool buffer_seen_before = false;
1870           for (const HloBuffer* buffer : buffers_at_index) {
1871             buffer_seen_before |= !seen.emplace(buffer, index).second;
1872           }
1873 
1874           if (buffer_seen_before && policy.copy_root_replicated_buffers &&
1875               computation == module->entry_computation() &&
1876               module->input_output_alias_config().OutputHasAlias(index) &&
1877               buffers_at_index.size() == 1) {
1878             absl::optional<HloInputOutputAliasConfig::Alias> alias =
1879                 module->input_output_alias_config().GetAliasedParameter(index);
1880             CHECK(alias) << "Alias does not exist";
1881             const ShapeIndex& other_index = seen[buffers_at_index[0]];
1882             VLOG(2) << "Output indices " << index.ToString() << " and "
1883                     << other_index.ToString() << " are both aliased to "
1884                     << alias->parameter_number << " copying " << other_index;
1885             add_index_to_copy(root, other_index);
1886             return;
1887           }
1888 
1889           if (buffers_at_index.size() > 1 ||
1890               (buffer_seen_before && policy.copy_root_replicated_buffers)) {
1891             VLOG(2) << "Index " << index << " of computation "
1892                     << computation->name() << " (" << root->name()
1893                     << ") has ambiguous or non-distinct buffer. Copying.";
1894             add_index_to_copy(root, index);
1895           }
1896         });
1897 
1898     for (const auto& pair :
1899          alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
1900       const ShapeIndex& index = pair.first;
1901       const HloValueSet& value_set = pair.second;
1902       for (const HloValue* value : value_set.values()) {
1903         if (ShouldCopyRootValue(*value, policy)) {
1904           VLOG(2) << "Root of (" << root->name() << ") of computation("
1905                   << computation->name()
1906                   << ") has constant or parameter value at index " << index
1907                   << ". Copying.";
1908           add_index_to_copy(root, index);
1909         }
1910       }
1911     }
1912   }
1913 
1914   // Add copy instructions indicated in 'instructions_to_copy' to the module.
1915   for (const auto& pair : instructions_to_copy) {
1916     HloInstruction* instruction = pair.first;
1917     const ShapeTree<bool>& indices_to_copy = pair.second;
1918 
1919     ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
1920     std::vector<HloInstruction*> users = instruction->users();
1921     TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
1922                         instruction->parent()->DeepCopyInstruction(
1923                             instruction, &indices_to_copy, &copies_added));
1924     for (HloInstruction* user : users) {
1925       TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
1926     }
1927     if (instruction == instruction->parent()->root_instruction()) {
1928       instruction->parent()->set_root_instruction(deep_copy);
1929     }
1930   }
1931   return Status::OK();
1932 }
1933 
1934 static int64 GetNumExistingCopies(const HloModule* module) {
1935   int64_t num_existing_copies = 0;
1936   for (HloComputation* computation : module->computations()) {
1937     for (HloInstruction* instruction : computation->instructions()) {
1938       if (instruction->opcode() == HloOpcode::kCopy) {
1939         ++num_existing_copies;
1940       }
1941     }
1942   }
1943   return num_existing_copies;
1944 }
1945 
1946 Status CopyInsertion::RemoveUnnecessaryCopies(HloOrdering* ordering,
1947                                               HloModule* module,
1948                                               bool check_live_range_ordering) {
1949   XLA_VLOG_LINES(4, module->ToString());
1950   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1951                       HloAliasAnalysis::Run(module, can_share_buffer_));
1952 
1953   CopyRemover copy_remover(*module, *alias_analysis, ordering,
1954                            check_live_range_ordering);
1955   if (VLOG_IS_ON(3)) {
1956     LOG(INFO) << "Removing unnecessary copies in " << module->name();
1957     LOG(INFO) << "Buffer values, in dependency order: ";
1958     for (const HloBuffer& buffer : alias_analysis->buffers()) {
1959       LOG(INFO) << "    HloBuffer " << buffer.id();
1960     }
1961   }
1962 
1963   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1964 
1965   int64_t num_existing_copies = GetNumExistingCopies(module);
1966   bool changed = true;
1967   int64_t num_iterations = -1;
1968   while (changed) {
1969     CHECK_LE(++num_iterations, num_existing_copies);
1970     changed = false;
1971     VLOG(2) << "Running fixpoint iteration " << num_iterations
1972             << " of copy elision";
1973     for (HloComputation* computation : module->computations()) {
1974       VLOG(2) << "computation:" << computation->name() << "\n";
1975       for (HloInstruction* instruction : computation->instructions()) {
1976         VLOG(2) << instruction->ToString() << "\n";
1977         if (instruction->opcode() == HloOpcode::kCopy &&
1978             copy_remover.TryElideCopy(instruction,
1979                                       use_region_based_live_range_analysis_)) {
1980           changed = true;
1981           TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
1982           TF_RETURN_IF_ERROR(
1983               instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
1984         }
1985       }
1986     }
1987   }
1988   return Status::OK();
1989 }
1990 
1991 StatusOr<bool> CopyInsertion::Run(HloModule* module) {
1992   // Copy insertion is performed in three steps:
1993   //
1994   // (1) Add copies conservatively to guarantee that there is no live-range
1995   //     interference. This is done simplistically and usually results in more
1996   //     copies than is strictly necessary.
1997   //
1998   // (2) Using a more fine-grained analysis, remove as many copies that were
1999   //     added in (1) as possible while ensuring no live-range interference.
2000   //
2001   // (3) Add copies to resolve issues not related to live range interference
2002   //     such as parameters and constants live out of the entry computation.
2003   //
2004   // We add copies then remove them (step (1) then (2)) rather than simply
2005   // adding only the copies that are necessary because, in general, it is
2006   // difficult to figure out the minimal set of copies to add once there is
2007   // interference. On the other hand, it is easy to determine if removing a copy
2008   // will introduce interference.
2009   //
2010   // The final copy insertion in (3) is done separately to simplify the
2011   // implementation of copy removal in (2) which is the most complicated part of
2012   // the pass. As is, copy removal only has to reason about live range
2013   // interference. If all copies were added in step (1) then copy removal would
2014   // also have to reason about things like constants and parameters live out of
2015   // the computation.
2016   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
2017   if (!call_graph->IsFlattened()) {
2018     return FailedPrecondition(
2019         "Call graph must be flattened before copy insertion.");
2020   }
2021 
2022   TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
2023 
2024   // Simplify the tuple structures introduced by the deep copies. This should be
2025   // done before removing copies (RemoveUnnecessaryCopies) because tuple
2026   // simplification changes dependencies in the graph which changes live range
2027   // interference in the graph. Also run DCE to remove the dead Tuple/GTE
2028   // instructions introduced by tuple simplification.
2029   TupleSimplifier tuple_simplifier;
2030   HloDCE dce;
2031   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2032   TF_RETURN_IF_ERROR(dce.Run(module).status());
2033   DumpHloModuleDuringPassIfEnabled(
2034       name(), "after adding copies to resolve interference", *module);
2035 
2036   DependencyHloOrdering ordering(module);
2037   TF_RETURN_IF_ERROR(
2038       RemoveUnnecessaryCopies(&ordering, module,
2039                               /*check_live_range_ordering=*/true));
2040   DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
2041                                    *module);
2042   TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
2043   DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
2044                                    *module);
2045 
2046   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2047   TF_RETURN_IF_ERROR(dce.Run(module).status());
2048 
2049   if (VLOG_IS_ON(1)) {
2050     int64_t num_total_copies = 0;
2051     for (HloComputation* computation : module->computations()) {
2052       for (HloInstruction* instruction : computation->instructions()) {
2053         if (instruction->opcode() == HloOpcode::kCopy) {
2054           num_total_copies++;
2055         }
2056       }
2057     }
2058     VLOG(1) << "Num copies before copy-insertion: "
2059             << GetNumExistingCopies(module);
2060     VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
2061   }
2062 
2063   return true;
2064 }
2065 }  // namespace xla
2066