1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17
18 #include <optional>
19 #include <sstream>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/any.h"
27 #include "tensorflow/compiler/xla/service/dump.h"
28 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_dce.h"
31 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
35 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
36 #include "tensorflow/compiler/xla/service/logical_buffer.h"
37 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/statusor.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/platform/logging.h"
43
44 namespace xla {
45 namespace {
46
47 using absl::StrAppend;
48
49 bool IsReadonlyEntryParameterValue(const HloValue& value) {
50 const HloComputation* computation = value.defining_instruction()->parent();
51 return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
52 computation == computation->parent()->entry_computation() &&
53 !computation->parent()->input_output_alias_config().ParameterHasAlias(
54 value.defining_instruction()->parameter_number(), value.index());
55 }
56
57 bool IsConstantValue(const HloValue& value) {
58 return value.defining_instruction()->opcode() == HloOpcode::kConstant;
59 }
60
61 bool ValueIsReadOnly(const HloValue& value) {
62 return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
63 }
64
65 // Data structure describing the action which should be taken on parts of a
66 // computation's buffers, with respect to the adding of special case copies.
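// For example (an illustrative case, not taken from a real module): if the
// entry root is tuple(%c, %c) where %c is a kConstant, the
// copy_parameters_and_constants policy calls for a copy so the read-only
// constant buffer does not escape through the output, and
// copy_root_replicated_buffers calls for a copy because the same buffer
// appears at two output indices.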
67 struct SpecialCaseCopyPolicy {
68 // Insert a copy if the same buffer is found at multiple indices within the
69 // output tuple.
70 bool copy_root_replicated_buffers = false;
71 // If true, insert a copy if a buffer coming from a constant or a parameter
72 // is found within the output tuple.
73 bool copy_parameters_and_constants = false;
74 };
75
76 SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
77 HloModule* module,
78 HloComputation* computation) {
79 SpecialCaseCopyPolicy policy;
80 if (computation == module->entry_computation()) {
81 policy.copy_parameters_and_constants = true;
82 policy.copy_root_replicated_buffers = true;
83 }
84 return policy;
85 }
86
87 bool ShouldCopyRootValue(const HloValue& value,
88 const SpecialCaseCopyPolicy& policy) {
89 if (policy.copy_parameters_and_constants) {
90 return ValueIsReadOnly(value);
91 }
92 return false;
93 }
94
95 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
96 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
97 // in the deep copy of 'from' to the respective kCopy instruction in the deep copy
98 // of 'to'.
99 //
100 // Requirements: 'from' and 'to' must have compatible shapes.
101 //
102 // For example, suppose 'from' and 'to' are two-element tuples where index 0 is
103 // the only index to copy. Prior to deep-copying we have:
104 //
105 //
106 // 'from'
107 // |
108 // ...
109 // |
110 // 'to'
111 //
112 // DeepCopyAndAddControlEdges produces:
113 //
114 // 'from'
115 // / \
116 // GTE GTE
117 // | |
118 // Copy |
119 // / \ /
120 // | Tuple
121 // | |
122 // ctrl ...
123 // edge |
124 // | |
125 // | 'to'
126 // | / \
127 // | GTE GTE
128 // \ | |
129 // Copy |
130 // \ /
131 // Tuple
132 //
133 StatusOr<std::pair<HloInstruction*, HloInstruction*>>
134 DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
135 const ShapeTree<bool>& indices_to_copy) {
136 DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
137 // to/from_copy_tree hold the kCopy instructions produced by the deep
138 // copies. Elements which are not copied (indices_to_copy.element(index) ==
139 // false) have nullptr at that index.
140 ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
141 /*init_value=*/nullptr);
142 TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
143 from->parent()->DeepCopyInstruction(
144 from, &indices_to_copy, &from_copy_tree));
145
146 ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
147 TF_ASSIGN_OR_RETURN(
148 HloInstruction * to_deep_copy,
149 to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
150
151 // Add control edges between the respective kCopy instructions.
152 for (const auto& pair : from_copy_tree) {
153 const ShapeIndex& index = pair.first;
154 HloInstruction* from_copy = pair.second;
155 HloInstruction* to_copy = to_copy_tree.element(index);
156 if (from_copy == nullptr) {
157 TF_RET_CHECK(to_copy == nullptr);
158 continue;
159 }
160 TF_RET_CHECK(to_copy != nullptr);
161 TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
162 }
163
164 return std::make_pair(from_deep_copy, to_deep_copy);
165 }
166
167 // Compute the indices of the loop state which need copies in order to avoid
168 // live range interference. Generally, an element in the loop state does not
169 // need to be copied if the element is passed transparently through the
170 // body.
171 //
172 // Returns whether any indices need to be copied.
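// For example (illustrative): with loop state (s32[], s32[]) whose body
// returns tuple(get-tuple-element(param, 0), add(...)), element {0} is passed
// through the body unchanged and needs no copy, while element {1} is
// redefined in the body and is marked for copying.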
173 bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
174 const HloInstruction* xla_while,
175 ShapeTree<bool>* indices_to_copy) {
176 DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
177
178 bool any_copies = false;
179 const HloInstruction* init = xla_while->operand(0);
180 for (auto& pair : *indices_to_copy) {
181 const ShapeIndex& index = pair.first;
182 bool& should_copy = pair.second;
183 // If there is any ambiguity, then loop state must be copied.
184 if (dataflow.GetValueSet(init, index).values().size() > 1 ||
185 dataflow.GetValueSet(xla_while, index).values().size() > 1) {
186 should_copy = true;
187 } else {
188 // If the output of the while instruction is not the same as the init
189 // value of the while, then this element is not passed through the body
190 // transparently and must be copied.
191 should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
192 dataflow.GetUniqueValueAt(init, index);
193 }
194 any_copies |= should_copy;
195 }
196 return any_copies;
197 }
198
199 // Compute the indices of the conditional outputs which need copies. Unambiguous
200 // buffers (buffers with only one value) don't need copies.
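// For example (illustrative): if different HloValues reach an output index
// from the branch computations, dataflow analysis models that index as a phi
// value defined by the conditional itself and it must be copied; an index
// whose single reaching value is not such a phi needs no copy.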
201 bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow,
202 const HloInstruction* xla_conditional,
203 ShapeTree<bool>* indices_to_copy) {
204 DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(),
205 xla_conditional->shape()));
206
207 bool any_copies = false;
208 for (auto& pair : *indices_to_copy) {
209 const ShapeIndex& index = pair.first;
210 bool& should_copy = pair.second;
211
212 CHECK_EQ(dataflow.GetValueSet(xla_conditional, index).values().size(), 1);
213
214 auto value = dataflow.GetValueSet(xla_conditional, index).values()[0];
215 // The conditional must be copied if the value is a phi.
216 should_copy =
217 value->is_phi() && value->defining_instruction() == xla_conditional;
218 any_copies |= should_copy;
219 }
220 return any_copies;
221 }
222
223 // Add kCopy instructions around the given kWhile instruction to eliminate any
224 // possible live range interference of HLO values assuming a dependency-based
225 // ordering. Copies are added conservatively. There likely are copies which are
226 // not strictly necessary, but they are removed later in the pass via
227 // RemoveUnnecessaryCopies.
228 //
229 // Elements (each ShapeIndex) in the loop state are considered independently. A
230 // copy is added to each element of the loop state which is modified in the
231 // while body. For each such element, a total of three kCopy instructions are
232 // added at following locations:
233 //
234 // (1) The init value is copied before the kWhile instruction. Before:
235 //
236 // (Init)
237 // |
238 // kWhile
239 // |
240 // ...
241 //
242 // After:
243 //
244 // (Init)
245 // |
246 // kCopy
247 // |
248 // kWhile
249 // |
250 // ...
251 //
252 // This copy is necessary in case the init value is simultaneously live
253 // with the kWhile.
254 //
255 // (2) Copies are added to the parameter and root of the while body
256 // computation. Before:
257 //
258 // kParameter
259 // |
260 // ...
261 // |
262 // (body root)
263 //
264 // After:
265 //
266 // kParameter
267 // |
268 // kCopy ----------+
269 // | |
270 // ... ctrl
271 // | edge
272 // (body root) |
273 // | |
274 // kCopy <---------+
275 //
276 // The root kCopy becomes the new root of the computation. Both copies are
277 // necessary to avoid any potential interference between the parameter value and
278 // the root value. The control edge prevents potential interference
279 // between the copies themselves.
280 //
281 // If the loop state is a tuple then the above kCopy instructions are a deep
282 // copy constructed of kCopy, kGetTupleElement, and kTuple instructions, as
283 // constructed by HloInstruction::DeepCopyInstruction.
284 Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
285 HloInstruction* xla_while) {
286 VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
287 TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
288
289 ShapeTree<bool> indices_to_copy(xla_while->shape());
290 if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
291 &indices_to_copy)) {
292 VLOG(2) << "No copies necessary for kWhile instruction "
293 << xla_while->name();
294 return Status::OK();
295 }
296
297 VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
298 for (auto& pair : indices_to_copy) {
299 if (pair.second) {
300 VLOG(2) << " " << pair.first;
301 }
302 }
303
304 // Deep copy init.
305 HloInstruction* while_init = xla_while->mutable_operand(0);
306 TF_ASSIGN_OR_RETURN(
307 HloInstruction * while_init_copy,
308 xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
309 TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
310
311 // Deep copy the parameter and the root. Extend a control edge from the copy
312 // of the parameter value to the corresponding copy of the root value.
313 HloComputation* body = xla_while->while_body();
314 HloInstruction* param = body->parameter_instruction(0);
315 HloInstruction* root = body->root_instruction();
316
317 // If param is the root then all indices should have been passed through the
318 // while body and we should have returned early above.
319 TF_RET_CHECK(param != root);
320
321 // Copy users before making a deep copy of the parameter as the deep copy
322 // will create new users of the parameter (e.g., the GTE instructions of the
323 // deep copy).
324 std::vector<HloInstruction*> param_users = param->users();
325
326 TF_ASSIGN_OR_RETURN(auto pair,
327 DeepCopyAndAddControlEdges(param, root, indices_to_copy));
328
329 HloInstruction* param_copy = pair.first;
330 HloInstruction* root_copy = pair.second;
331
332 for (HloInstruction* user : param_users) {
333 TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
334 }
335
336 body->set_root_instruction(root_copy);
337 return Status::OK();
338 }
339
340 // Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
341 // will remove the unnecessary copies.
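// For example (an illustrative sketch): if %dus = dynamic-update-slice(%p, %u,
// %i) updates %p in place while %p has other users, the operand is replaced
// with a copy, i.e. %copy = copy(%p); %dus = dynamic-update-slice(%copy, %u,
// %i), so the in-place write cannot clobber the value still needed by the
// other users.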
342 Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
343 HloInstruction* in_place_op,
344 int64_t operand_number) {
345 VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
346 HloInstruction* operand = in_place_op->mutable_operand(operand_number);
347 TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
348 in_place_op->parent()->DeepCopyInstruction(operand));
349 TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
350 return Status::OK();
351 }
352
353 // Conservatively adds copies before the root instruction of the entry
354 // computation and each aliased parameter to resolve interference between
355 // aliased input and output buffers. We later rely on RemoveUnnecessaryCopies
356 // to drop the unnecessary ones.
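// For example (illustrative): given an alias between parameter 0 at index {}
// and the output at index {0}, a copy is inserted after the parameter and
// another before the root at that output index, with a control edge from the
// parameter copy to the root copy, so the aliased input buffer is not
// overwritten while its original value may still be read.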
357 Status AddCopiesForAliasedInputOutputs(HloModule* module) {
358 HloComputation* entry = module->entry_computation();
359 HloInstruction* root = entry->root_instruction();
360
361 ShapeTree<bool> output_indices_to_copy(root->shape());
362 std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
363 entry->num_parameters());
364 bool has_alias = false;
365 for (auto* param : entry->parameter_instructions()) {
366 bool param_has_alias = false;
367 ShapeTree<bool> param_indices_to_copy(param->shape());
368
369 module->input_output_alias_config().ForEachAlias(
370 [&](const ShapeIndex& output_index,
371 const HloInputOutputAliasConfig::Alias& alias) {
372 if (alias.parameter_number == param->parameter_number()) {
373 param_has_alias = true;
374 *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
375 true;
376 *(output_indices_to_copy.mutable_element(output_index)) = true;
377 }
378 });
379
380 if (!param_has_alias) {
381 continue;
382 }
383
384 TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
385 TF_RET_CHECK(!copied_parameters[param->parameter_number()]);
386
387 has_alias = true;
388 // Store a snapshot of users before DeepCopyInstruction, as
389 // DeepCopyInstruction introduces new users of the instruction.
390 std::vector<HloInstruction*> users = param->users();
391 ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
392 /*init_value=*/nullptr);
393 TF_ASSIGN_OR_RETURN(HloInstruction * copied,
394 entry->DeepCopyInstruction(
395 param, &param_indices_to_copy, &param_copy_tree));
396 for (HloInstruction* user : users) {
397 TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
398 }
399
400 copied_parameters[param->parameter_number()] = param_copy_tree;
401 }
402
403 if (!has_alias) {
404 return Status::OK();
405 }
406
407 // Add copies before root instruction.
408 ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
409 /*init_value=*/nullptr);
410
411 TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
412 root->parent()->DeepCopyInstruction(
413 root, &output_indices_to_copy, &output_copy_tree));
414
415 // Add control dependencies between the input/output copies.
416 TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
417 [&](const ShapeIndex& output_index,
418 const HloInputOutputAliasConfig::Alias& alias) -> Status {
419 if (!copied_parameters[alias.parameter_number]) {
420 return Status::OK();
421 }
422 HloInstruction* from =
423 copied_parameters[alias.parameter_number]->element(
424 alias.parameter_index);
425 HloInstruction* to = output_copy_tree.element(output_index);
426
427 TF_RET_CHECK(from != nullptr);
428 TF_RET_CHECK(to != nullptr);
429 TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
430 return Status::OK();
431 }));
432
433 entry->set_root_instruction(root_copied);
434
435 return Status::OK();
436 }
437
438 // Removes any control dependencies to or from the given instruction.
439 Status StripControlDependenciesFrom(HloInstruction* instruction) {
440 while (!instruction->control_successors().empty()) {
441 TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
442 instruction->control_successors().front()));
443 }
444
445 while (!instruction->control_predecessors().empty()) {
446 TF_RETURN_IF_ERROR(
447 instruction->control_predecessors().front()->RemoveControlDependencyTo(
448 instruction));
449 }
450
451 return Status::OK();
452 }
453
454 class LiveRangeRegions {
455 public:
456 struct InstructionInfo {
457 InstructionInfo() : value_definition(nullptr), is_definition(false) {}
458
459 // The instruction that defines the value being used. It basically saves
460 // the defining instruction of each HloValue.
461 HloInstruction* value_definition;
462 // Whether the instruction defines a new value (or merely uses one). This
463 // basically remembers whether the instruction actually creates an HloValue
464 // or merely uses one, from a collection of given HloValues. Note that if
465 // is_definition = true, it merely says the instruction creates a new
466 // HloValue with or without defining a new one. For example, kAdd creates a
467 // new HloValue (can be value_definition), but tuple or get-tuple-element
468 // create an aliasing HloValue without defining a new value (cannot be
469 // value_definition).
470 bool is_definition;
471 };
472 // Map instructions that use a value to the defining instruction of the value.
473 // Because all values must belong to the same live range, an instruction can
474 // have at most a single value-defining instruction; otherwise the multiple
475 // incoming active values would share a single buffer, which is not allowed.
476 // The value-defining and value-use instructions do not have to belong to the
477 // same computation, but the value use needs to be nested within the defining
478 // computation.
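// For example (illustrative): for a value defined inside a while-body
// computation and used by that body's root, the body's InstructionMap records
// the defining instruction with is_definition = true and records the root as
// a use whose value_definition points back at the defining instruction.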
479 typedef absl::flat_hash_map<HloInstruction*, InstructionInfo> InstructionMap;
480 typedef std::pair<HloInstruction*, InstructionInfo> InstructionEntry;
481 // Map each computation to its immediately contained instructions.
482 typedef absl::flat_hash_map<const HloComputation*, InstructionMap>
483 ComputationMap;
484
485 InstructionMap& operator[](const HloComputation* computation) {
486 if (computation_map_.find(computation) == computation_map_.end()) {
487 computation_vector_.push_back(computation);
488 }
489 return computation_map_[computation];
490 }
491
492 const InstructionMap& operator[](const HloComputation* computation) const {
493 ComputationMap::const_iterator p = computation_map_.find(computation);
494 CHECK(p != computation_map_.end());
495 return p->second;
496 }
497 ComputationMap::const_iterator begin() const {
498 return computation_map_.begin();
499 }
500 ComputationMap::const_iterator end() const { return computation_map_.end(); }
501 int64 size() const {
502 CHECK_EQ(computation_vector_.size(), computation_map_.size());
503 return computation_vector_.size();
504 }
505 bool empty() const { return size() == 0; }
506 const HloComputation* Computation(int64_t index) const {
507 return computation_vector_[index];
508 }
509 bool contains(const HloInstruction* instr) const {
510 CHECK_NE(instr, nullptr);
511 auto* computation = instr->parent();
512 auto p = computation_map_.find(computation);
513 if (p == computation_map_.end()) {
514 return false;
515 }
516 auto instr_map = (*p).second;
517 return instr_map.find(instr) != instr_map.end();
518 }
519
520 private:
521 ComputationMap computation_map_;
522 absl::InlinedVector<const HloComputation*, 5> computation_vector_;
523 };
524
525 namespace {
526 // Represent relations between the locations of two regions of instructions;
527 // each region can include 0-n instructions.
528 class Relation {
529 public:
530 enum RuntimeOrder {
531 // Indicate that there is no overlap whatsoever between the two regions.
532 kNoOverlap = 0,
533 // Indicate that the first region includes the same set of instructions as
534 // the second region.
535 kSameInstr = 1,
536 // Indicate that the first region is entirely before the second region
537 // starts.
538 kBeforeStart = 2,
539 // Indicate that the first region is before the second region ends.
540 kBeforeStartOrSameInstr = kBeforeStart | kSameInstr,
541 // Indicate that the first region is entirely after the second region ends.
542 kAfterEnd = 4,
543 // Indicate that the first region is after the second region
544 // starts, with some instructions before the second region ends.
545 kAfterEndOrSameInstr = kAfterEnd | kSameInstr,
546 // Indicate that the first region overlaps with the second one, but shares no
547 // common instructions.
548 kBeforeStartOrAfterEnd = kBeforeStart | kAfterEnd,
549 // Indicate that the first region overlaps with the second one, and shares
550 // some common instructions.
551 kBeforeOrAfterOrOverlap = kBeforeStart | kAfterEnd | kSameInstr,
552 };
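// The enum values are bit masks, so the composite orderings above are unions
// of the basic ones; e.g. kBeforeStartOrAfterEnd == (kBeforeStart | kAfterEnd)
// == 6, which is why Union() below combines orderings with a bitwise-or.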
553 Relation() : intercept_def_use_(false) {}
554 explicit Relation(RuntimeOrder order, bool intercept_def_use = false)
555 : intercept_def_use_(intercept_def_use) {
556 orders_.push_back(order);
557 }
558 Relation(const Relation& that)
559 : intercept_def_use_(that.intercept_def_use_), orders_(that.orders_) {}
560 bool operator==(const Relation& that) const {
561 return intercept_def_use_ == that.intercept_def_use_ &&
562 absl::c_equal(orders_, that.orders_);
563 }
564
565 // Return whether the runtime ordering may imply interception, assuming it
566 // models the relation between a modifying and a use instruction.
567 bool UseImpliesInterception() const {
568 CHECK_EQ(orders_.size(), 1);
569 return UseImpliesInterception(orders_[0]);
570 }
571 // Return whether the runtime ordering may imply interception, assuming it
572 // models the relation between a modifying and a definition instruction.
573 bool DefinitionImpliesInterception() const {
574 CHECK_EQ(orders_.size(), 1);
575 return DefinitionImpliesInterception(orders_[0]);
576 }
577 // Return whether the current relation models a modifying instruction that
578 // intercepts the dataflow of another live range region.
579 bool InterceptDefUse() const { return intercept_def_use_; }
580 // Update interception state to the given value.
581 void UpdateInterception(bool value) {
582 CHECK_EQ(orders_.size(), 1);
583 intercept_def_use_ = value;
584 }
585 Relation::RuntimeOrder GetRuntimeOrder() const {
586 if (orders_.empty()) {
587 return Relation::kNoOverlap;
588 }
589 CHECK_EQ(orders_.size(), 1);
590 return orders_[0];
591 }
592 // Return whether the current relation implies two overlapping regions.
593 bool RuntimeOrderOverlap() const {
594 return absl::c_any_of(orders_, ImpliesOverlap);
595 }
596 bool RuntimeOrderIsUnordered() const {
597 return orders_.size() == 1 && orders_[0] == kBeforeStartOrAfterEnd;
598 }
599 bool RuntimeOrderIsNoOverlap() const {
600 return orders_.empty() || (orders_.size() == 1 && orders_[0] == kNoOverlap);
601 }
602 bool RuntimeOrderIsRunBefore() const {
603 return orders_.size() == 1 && orders_[0] == kBeforeStart;
604 }
605 bool RuntimeOrderIsRunAfter() const {
606 return orders_.size() == 1 && orders_[0] == kAfterEnd;
607 }
608 std::string ToString() const {
609 return absl::StrCat("Interception = ", intercept_def_use_, ";",
610 absl::StrJoin(orders_, ","));
611 }
612
613 static bool DefinitionImpliesInterception(RuntimeOrder definition) {
614 return (definition == kAfterEnd || definition == kBeforeStartOrAfterEnd);
615 }
616 static bool UseImpliesInterception(RuntimeOrder use) {
617 return (use == kBeforeStart || use == kBeforeStartOrAfterEnd);
618 }
619
620 // Summarize additional relations into a single runtime ordering, assuming
621 // both relations are modeling constraints of the same source instruction.
622 void UnionRelationFromSameSource(const Relation& rel) {
623 CHECK_LE(orders_.size(), 1);
624 CHECK_EQ(rel.orders_.size(), 1);
625 if (orders_.empty()) {
626 orders_.push_back(rel.orders_[0]);
627 } else {
628 orders_[0] = Union(orders_[0], rel.orders_[0]);
629 }
630 intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
631 }
632
633 // Summarize additional relations into disjoint runtime orderings, assuming
634 // the relations are modeling constraints of different source instructions.
635 void UnionRelationFromDifferentSource(const Relation& rel) {
636 if (rel.orders_.empty()) {
637 return;
638 }
639 CHECK_EQ(rel.orders_.size(), 1);
640 intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
641 for (auto& local_order : orders_) {
642 if (OverwriteIfSubsume(rel.orders_[0], &local_order)) {
643 return;
644 }
645 }
646 orders_.push_back(rel.orders_[0]);
647 }
648
649 static Relation::RuntimeOrder ReverseRuntimeOrder(RuntimeOrder order) {
650 switch (order) {
651 case kNoOverlap:
652 case kSameInstr:
653 case kBeforeStartOrAfterEnd:
654 case kBeforeOrAfterOrOverlap:
655 return order;
656 case kBeforeStart:
657 return kAfterEnd;
658 case kBeforeStartOrSameInstr:
659 return kAfterEndOrSameInstr;
660 case kAfterEnd:
661 return kBeforeStart;
662 case kAfterEndOrSameInstr:
663 return kBeforeStartOrSameInstr;
664 }
665 }
666
667 private:
668 // Indicate that the second region may intercept the def-use dataflow of the
669 // first region, if their buffers are combined.
670 bool intercept_def_use_;
671 // Remember the different runtime orderings of different instructions.
672 absl::InlinedVector<RuntimeOrder, 4> orders_;
673
674 static RuntimeOrder Union(RuntimeOrder o1, RuntimeOrder o2) {
675 return static_cast<Relation::RuntimeOrder>(o1 | o2);
676 }
677 static bool ImpliesOverlap(RuntimeOrder o) {
678 return o >= RuntimeOrder::kBeforeStartOrAfterEnd;
679 }
680 // Returns whether ordering constraint o1 includes o2 as a subset, when they
681 // represent runtime orderings (interleavings) of two different regions.
682 static bool Subsume(RuntimeOrder o1, RuntimeOrder o2) {
683 return Union(o1, o2) == o1;
684 }
685 // Overwrites o1 with o2 if o2 subsumes o1 (as defined above by the Subsume
686 // function). Return whether o2 is subsumed by the new value in o1.
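// For example (illustrative): OverwriteIfSubsume(kBeforeStartOrAfterEnd, &o)
// with o == kBeforeStart rewrites o to kBeforeStartOrAfterEnd and returns
// true, since Union(kBeforeStartOrAfterEnd, kBeforeStart) ==
// kBeforeStartOrAfterEnd; with o == kSameInstr neither order subsumes the
// other, so the call returns false and o is left untouched.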
687 static bool OverwriteIfSubsume(RuntimeOrder o2, RuntimeOrder* o1) {
688 if (*o1 == o2) {
689 return true;
690 }
691 CHECK_NE(o1, nullptr);
692 // Overwrite o1 with o2 if it is subsumed by o2.
693 if (Subsume(o2, *o1)) {
694 *o1 = o2;
695 return true;
696 } else if (Subsume(*o1, o2)) {
697 // If o2 is already subsumed by o1, do nothing.
698 return true;
699 }
700 // If neither o1 nor o2 is subsumed by the other, return false, so that o2
701 // will be inserted as a separate entry representing all possible orderings.
702 return false;
703 }
704 };
705
706 class ComputeRelativeLocation {
707 public:
708 typedef LiveRangeRegions::InstructionEntry InstructionEntry;
709 explicit ComputeRelativeLocation(HloOrdering* ordering)
710 : ordering_(ordering) {
711 VLOG(3) << "New analysis\n";
712 }
713
714 // Compute locationing constraints between two instructions. Here entry2 is
715 // the source instruction, in that the returned value describes the relation
716 // of entry2 in terms of whether it is before or after entry1, and whether it
717 // can intercept the def-use data flow of entry1.
718 Relation Compute(const InstructionEntry& entry1,
719 const InstructionEntry& entry2, bool instr2_can_modify) {
720 auto def = entry1.second.value_definition;
721 auto use = entry1.first;
722 Relation::RuntimeOrder order =
723 ComputeRuntimeOrdering(entry2.first, entry1.first);
724 if (order == Relation::kSameInstr &&
725 entry1.second.is_definition != entry2.second.is_definition) {
726 if (entry1.second.is_definition) {
727 order = Relation::kBeforeStart;
728 } else {
729 order = Relation::kAfterEnd;
730 }
731 }
732 bool intercept = AlwaysForceInterception(entry2.first);
733 if (def == nullptr || !instr2_can_modify) {
734 return Relation(order, intercept);
735 }
736 // If the definition and use are parameter and return (root) of the parent
737 // computation, then any modification is considered intercepting.
738 if (def->opcode() == HloOpcode::kParameter &&
739 use == use->parent()->root_instruction()) {
740 VLOG(3) << "Setting interception due to parameter/root relation\n";
741 return Relation(order, true);
742 }
743 if (Relation::UseImpliesInterception(order)) {
744 auto order2 = ComputeRuntimeOrdering(entry2.first, def);
745 if (Relation::DefinitionImpliesInterception(order2)) {
746 VLOG(3) << "Setting interception for " << def->ToString()
747 << " with use:" << entry1.first->ToString() << "\n";
748 intercept = true;
749 }
750 }
751 return Relation(order, intercept);
752 }
753
754 // Return the relative locations (defined above) of range2 in relation to
755 // instructions in range1. Return kNoOverlap if range2 is outside of range1.
756 Relation Compute(const LiveRangeRegions& range1,
757 const LiveRangeRegions& range2) {
758 Relation dir_src_dest;
759 for (int64_t index = 0; index < range1.size(); index++) {
760 auto* computation1 = range1.Computation(index);
761 for (const auto& computation_entry2 : range2) {
762 auto* computation2 = computation_entry2.first;
763 for (auto instr_entry2 : computation_entry2.second) {
764 if (!ordering_->call_graph().Dominates(computation1, computation2)) {
765 continue;
766 }
767 VLOG(3) << "Locationing " << instr_entry2.first->ToString();
768 // Saves relations between instr2 and other instructions in range1.
769 bool instr2_can_modify =
770 InstructionCanIntercept(instr_entry2, range1);
771 Relation instr2_relation;
772 std::vector<InstructionEntry> unordered_ops;
773 bool unordered_intercept = false;
774 for (auto instr_entry1 : range1[computation1]) {
775 auto rel = Compute(instr_entry1, instr_entry2, instr2_can_modify);
776 VLOG(3) << "new relation with:" << instr_entry1.first->ToString()
777 << " = " << rel.ToString() << "\n";
778 if (!rel.RuntimeOrderIsUnordered()) {
779 instr2_relation.UnionRelationFromSameSource(rel);
780 } else {
781 unordered_ops.push_back(instr_entry1);
782 unordered_intercept |= rel.InterceptDefUse();
783 }
784 VLOG(3) << "instr2 relation:" << instr2_relation.ToString() << "\n";
785 }
786 // Here instr2_relation is guaranteed to have at most a single entry,
787 // because it was initialized to be empty, and has been updated only
788 // via instr2_relation.UnionRelationFromSameSource(rel), which
789 // maintains that the updated result has only a single entry.
790 if (!ForceRuntimeOrder(unordered_ops, instr_entry2,
791 instr2_relation.GetRuntimeOrder())) {
792 VLOG(3) << "Unable to force ordering of unordered ops\n";
793 instr2_relation.UnionRelationFromSameSource(Relation(
794 Relation::kBeforeStartOrAfterEnd, unordered_intercept));
795 }
796 dir_src_dest.UnionRelationFromDifferentSource(instr2_relation);
797 VLOG(3) << "Resulting relation : " << dir_src_dest.ToString() << "\n";
798 }
799 }
800 }
801 return dir_src_dest;
802 }
803
804 // Return whether control dependencies, if any exist, are added successfully.
805 bool AddControlDependenceForUnorderedOps() {
806 if (ctrl_deps_.empty()) {
807 return true;
808 }
809 PredecessorHloOrdering* ordering =
810 dynamic_cast<PredecessorHloOrdering*>(ordering_);
811 if (ordering == nullptr) {
812 // Support forcing the ordering of unordered ops only when using
813 // predecessor ordering.
814 return false;
815 }
816 for (const auto& comp_it : ctrl_deps_) {
817 HloComputation* parent = comp_it.first;
818 HloReachabilityMap& reachability_map = ordering->reachability_map(parent);
819 for (const auto& instr_it : comp_it.second) {
820 HloInstruction* entry1 = instr_it.first;
821 for (HloInstruction* entry2 : instr_it.second) {
822 VLOG(3) << "Add control dependence between " << entry2->ToString();
823 VLOG(3) << "\n vs " << entry1->ToString() << "\n";
824 TF_CHECK_OK(entry2->AddControlDependencyTo(entry1));
825 }
826 reachability_map.UpdateReachabilityThroughInstruction(entry1);
827 for (HloInstruction* entry2 : instr_it.second) {
828 DCHECK(ordering_->GetExecutionConstraint(entry1, entry2) ==
829 HloOrdering::ExecutionConstraint::kRunAfter);
830 }
831 }
832 }
833 return true;
834 }
835
836 private:
837 enum ComputeStatus {
838 kFullyComputed,
839 kPartiallyComputed,
840 kNotComputed,
841 };
842 typedef std::pair<ComputeStatus, Relation::RuntimeOrder> SavedRelation;
843
844 // Returns whether it is safe to force the desired_relation ordering between
845 // all operations in unordered_ops and entry2. If safe, save the new enforced
846 // ordering relations.
847 bool ForceRuntimeOrder(absl::Span<const InstructionEntry> unordered_ops,
848 const InstructionEntry entry2,
849 Relation::RuntimeOrder desired_relation) {
850 if (unordered_ops.empty()) {
851 return true;
852 }
853 if (desired_relation != Relation::kBeforeStart &&
854 desired_relation != Relation::kAfterEnd) {
855 return false;
856 }
857 auto ModifiesNonCopy = [](HloInstruction* instr, const HloInstruction* op) {
858 auto in_place = HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr);
859 if (in_place.empty()) {
860 return false;
861 }
862 return absl::c_any_of(
863 in_place, [&](const std::pair<HloUse, ShapeIndex>& a) {
864 auto* op2 = instr->operand(a.first.operand_number);
865 return (op == nullptr) ? (op2->opcode() == HloOpcode::kCopy)
866 : (op2 == op);
867 });
868 };
869 for (const InstructionEntry& entry1 : unordered_ops) {
870 // Only consider instructions in the same computation.
871 if (entry1.first->parent() != entry2.first->parent()) {
872 return false;
873 }
874 HloInstruction* pred = (desired_relation == Relation::kBeforeStart)
875 ? entry2.first
876 : entry1.first;
877 HloInstruction* succ = (desired_relation == Relation::kBeforeStart)
878 ? entry1.first
879 : entry2.first;
880 if (pred == pred->parent()->root_instruction()) {
881 return false;
882 }
883 if (succ->opcode() == HloOpcode::kCopy &&
884 ModifiesNonCopy(pred, succ->operand(0))) {
885 VLOG(3) << "Failed to force unordered op ordering due to copy ordering "
886 << " between " << pred->ToString() << "\n";
887 VLOG(3) << " vs. " << succ->ToString() << "\n";
888 return false;
889 }
890 }
891 for (const InstructionEntry& entry1 : unordered_ops) {
892 Save(entry2.first, entry1.first, desired_relation, true);
893 }
894 return true;
895 }
896
897 static bool AlwaysForceInterception(HloInstruction* instr) {
898 // The following communication operations can have some unexpected side
899 // effects when synchronizing across processes. Therefore, we
900 // conservatively try to provide dedicated buffers to these operations instead
901 // of allowing them to share buffers with other operations, as the reuse may
902 // cause unexpected interferences.
903 if (HloDataflowAnalysis::IsAsynchronousOperationStart(instr->opcode()) ||
904 HloDataflowAnalysis::IsAsynchronousOperationDone(instr->opcode())) {
905 return true;
906 }
907 switch (instr->opcode()) {
908 // TODO(b/190903339): It appears that collectivePermute needs to be
909 // followed by a copy when escaping through a computation root.
910 case HloOpcode::kCollectivePermute:
911 return true;
912 default:
913 return false;
914 }
915 }
916
917 // Returns whether the given instr may intercept the def-use flow of another
918 // ongoing live range if its buffer is combined with the other live range.
919 // The function should return true if instr creates a new HloValue that could
920 // overwrite an existing HloValue in the combined buffer.
921 // More specifically, here we are looking for operations that create new
922 // values, e.g., add, subtract, in contrast to HLOs that merely create
923 // aliasings among existing values, e.g., tuple, get-tuple-element. Any of the
924 // new values created by operations such as add or subtract, when included as
925 // definition operations in a live range, are aliases of the buffer to be
926 // allocated to the live range and so are treated as if they may modify the
927 // target buffer.
928 bool InstructionCanIntercept(const InstructionEntry& entry,
929 const LiveRangeRegions& region) {
930 auto instr = entry.first;
931 if (!entry.second.is_definition) {
932 // If the instruction only uses the value, it can intercept only if it
933 // modifies the buffer in place.
934 return !HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr).empty();
935 }
936 switch (instr->opcode()) {
937 // If the copy instruction is used to connect two live range regions,
938 // it does not overwrite the combined buffer with new values.
939 case HloOpcode::kCopy:
940 // Check that the copy simply copies from the other live range with no
941 // layout conflicts.
942 if (region.contains(instr->operand(0)) &&
943 ShapeUtil::Equal(instr->shape(), instr->operand(0)->shape())) {
944 return false; // Cannot intercept.
945 }
946 return true;
947 // The following operations merely create aliases among the HloValues.
948 case HloOpcode::kParameter:
949 case HloOpcode::kTuple:
950 case HloOpcode::kGetTupleElement:
951 // Here we consider all the compound operations (e.g., conditionals and
952 // while loops) as if they do not modify any HloValue, with the argument
953 // being that any value modifying operation contained inside will be
954 // considered separately to make sure the kIntercept relation is
955 // recorded as appropriate. Since the compound operations may or may not
956 // modify, not treating them as value modifying would make the algorithm
957 // less conservative.
958 case HloOpcode::kWhile:
959 case HloOpcode::kCall:
960 case HloOpcode::kConditional:
961 case HloOpcode::kTupleSelect:
962 return false;
963 default:
964 return true;
965 }
966 return true;
967 }
968
969 SavedRelation AlreadyComputed(HloInstruction* op1, HloInstruction* op2) {
970 auto p2 = saved_relations_.find(op2);
971 if (p2 != saved_relations_.end()) {
972 auto p1 = (*p2).second.find(op1);
973 if (p1 != (*p2).second.end()) {
974 return SavedRelation(kFullyComputed, (*p1).second);
975 }
976 }
977 p2 = saved_relations_.find(op1);
978 if (p2 != saved_relations_.end()) {
979 auto p1 = (*p2).second.find(op2);
980 if (p1 != (*p2).second.end()) {
981 return SavedRelation(kPartiallyComputed,
982 Relation::ReverseRuntimeOrder((*p1).second));
983 }
984 }
985 return SavedRelation(kNotComputed, Relation::kNoOverlap);
986 }
987
988 Relation::RuntimeOrder Save(HloInstruction* entry1, HloInstruction* entry2,
989 const Relation::RuntimeOrder relation,
990 bool is_unordered_originally = false) {
991 CHECK_EQ(AlreadyComputed(entry1, entry2).first, kNotComputed);
992 // Do not save unordered relations.
993 CHECK_NE(relation, Relation::kBeforeStartOrAfterEnd);
994 saved_relations_[entry2][entry1] = relation;
995 if (is_unordered_originally) {
996 CHECK(relation == Relation::kBeforeStart ||
997 relation == Relation::kAfterEnd)
998 << relation;
999 HloInstruction* pred =
1000 (relation == Relation::kBeforeStart) ? entry1 : entry2;
1001 HloInstruction* succ =
1002 (relation == Relation::kBeforeStart) ? entry2 : entry1;
1003 VLOG(3) << "Save unordered relation: " << pred->ToString() << "\n";
1004 VLOG(3) << " vs " << succ->ToString() << "\n";
1005 CHECK_EQ(succ->parent(), pred->parent());
1006 auto& dep_vec = ctrl_deps_[succ->parent()][succ];
1007 for (HloInstruction*& op : dep_vec) {
1008 auto rel = AlreadyComputed(pred, op);
1009 if (rel.first != kNotComputed) {
1010 if (rel.second == Relation::kAfterEnd) {
1011 op = pred;
1012 } else {
1013 CHECK(rel.second == Relation::kBeforeStart);
1014 }
1015 return relation;
1016 }
1017 }
1018 VLOG(2) << "Forcing unordered:" << pred->ToString() << "\n";
1019 VLOG(2) << " vs " << succ->ToString() << "\n";
1020 dep_vec.push_back(pred);
1021 }
1022 return relation;
1023 }
1024
1025 // Compute the runtime ordering constraints between two instructions.
1026 Relation::RuntimeOrder ComputeRuntimeOrdering(HloInstruction* instr1,
1027 HloInstruction* instr2) {
1028 auto saved_relation = AlreadyComputed(instr1, instr2);
1029 if (saved_relation.first != kNotComputed) {
1030 VLOG(3) << "Already computed between " << instr1->ToString() << "\n vs "
1031 << instr2->ToString() << "\n";
1032 return saved_relation.second;
1033 }
1034 auto constraint = ordering_->GetExecutionConstraint(instr1, instr2);
1035 switch (constraint) {
1036 case HloOrdering::ExecutionConstraint::kIsSame:
1037 return Save(instr1, instr2, Relation::kSameInstr);
1038 case HloOrdering::ExecutionConstraint::kRunBeforeEnd:
1039 return Save(instr1, instr2, Relation::kBeforeStartOrSameInstr);
1040 case HloOrdering::ExecutionConstraint::kRunBeforeStart:
1041 return Save(instr1, instr2, Relation::kBeforeStart);
1042 case HloOrdering::ExecutionConstraint::kRunAfter:
1043 return Save(instr1, instr2, Relation::kAfterEnd);
1044 case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
1045 case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
1046 return Save(instr1, instr2, Relation::kNoOverlap);
1047 case HloOrdering::ExecutionConstraint::kUnordered: {
1048 if (instr1->parent() != instr2->parent()) {
1049 return Relation::kBeforeStartOrAfterEnd;
1050 }
1051 auto ControlDependenceBefore = [&](HloInstruction* op1,
1052 HloInstruction* op2) {
1053 auto constraint = ComputeRuntimeOrdering(op1, op2);
1054 if (constraint == Relation::kBeforeStart ||
1055 constraint == Relation::kSameInstr ||
1056 constraint == Relation::kBeforeStartOrSameInstr) {
1057 return true;
1058 } else {
1059 return false;
1060 }
1061 };
1062 if (!ctrl_deps_.empty()) {
1063 auto ctrl_deps = ctrl_deps_[instr1->parent()];
1064 if (absl::c_any_of(ctrl_deps[instr2], [&](HloInstruction* pred2) {
1065 return ControlDependenceBefore(instr1, pred2);
1066 })) {
1067 VLOG(2) << "control-dependent: " << instr1->ToString() << "\n";
1068 VLOG(2) << "vs " << instr2->ToString() << "\n";
1069 return Save(instr1, instr2, Relation::kBeforeStart);
1070 } else if (absl::c_any_of(
1071 ctrl_deps[instr1], [&](HloInstruction* pred1) {
1072 return ControlDependenceBefore(instr2, pred1);
1073 })) {
1074 VLOG(2) << "control-dependent: " << instr2->ToString() << "\n";
1075 VLOG(2) << "vs " << instr1->ToString() << "\n";
1076 return Save(instr1, instr2, Relation::kAfterEnd);
1077 }
1078 }
1079 // Don't save the result for unordered operations, so they can be
1080 // refined later.
1081 return Relation::kBeforeStartOrAfterEnd;
1082 }
1083 }
1084 }
1085
1086 HloOrdering* ordering_;
1087 absl::flat_hash_map<
1088 HloInstruction*,
1089 absl::flat_hash_map<HloInstruction*, Relation::RuntimeOrder>>
1090 saved_relations_;
1091 absl::flat_hash_map<
1092 HloComputation*,
1093 absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>>
1094 ctrl_deps_;
1095 };
1096 } // namespace
1097
1098 // Class which tracks the HLO values within each HLO buffer in the module
1099 // during copy removal.
1100 //
1101 // The values are held in a linked list where there is one list for each
1102 // buffer. Removing a copy instruction merges together the values in the
1103 // source buffer of the copy to the destination buffer of the copy. This class
1104 // tracks these value lists as copies are removed from the graph (and value
1105 // lists are merged).
1106 //
1107 // The CopyRemover object is initialized to match the state of
1108 // HloAliasAnalysis. However, as copies are removed this state diverges. The
1109 // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
1110 // a fully updatable alias analysis is very slow.
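// For example (illustrative): if the source buffer of a kCopy holds the value
// list {s_0, s_1} and the destination buffer holds only {d_0}, the value
// defined by the copy, then eliding the copy splices the lists into a single
// buffer {s_0, s_1} and former uses of d_0 become uses of the copied source
// value.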
1111 class CopyRemover {
1112 public:
1113 // The values held in a single HLO buffer are represented using a linked
1114 // list. An element type in this list is ValueNode.
1115 //
1116 // This linked list is hand-rolled to enable efficient splicing of lists
1117 // using only references to list elements without knowing which lists are
1118 // being spliced. std::list requires a reference to the list object to
1119 // splice.
1120 struct ValueNode {
1121 explicit ValueNode(const HloValue* v) : value(v) {}
1122
1123 const HloValue* value;
1124
1125 // The uses are maintained outside of HloValue::uses() because
1126 // HloValue::uses() is not updatable (a fully updatable dataflow analysis
1127 // is slow).
1128 std::vector<const HloUse*> uses;
1129
1130 // next/prev elements in the linked list. The list is circularly linked so
1131 // these values are never null for elements in the list.
1132 ValueNode* prev = nullptr;
1133 ValueNode* next = nullptr;
1134 };
1135
1136 CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
1137 HloOrdering* ordering, bool check_live_range_ordering)
1138 : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
1139 // Construct a list for each HLO buffer in the alias analysis. Maintain a
1140 // map from HloValue to the respective list element representing that
1141 // value. The map is used to construct the copy info map below.
1142 absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
1143 // Perform check only if the default dependence-based ordering is used.
1144 for (const HloBuffer& buffer : alias_analysis.buffers()) {
1145 // No copies should have been inserted within fused computations, so no
1146 // need to remove them. HloOrdering isn't compatible with HloValues inside
1147 // fusions, so skip copy removal for them.
1148 if (buffer.values().at(0)->defining_instruction()->IsFused()) {
1149 continue;
1150 }
1151 if (check_live_range_ordering) {
1152 // Verify values contained in the buffer are strictly ordered. This
1153 // should always be the case after adding copies to eliminate
1154 // interference. Specifically, the addition of the control flow edges
1155 // between copies added around aliased operations (kWhile) guarantees
1156 // this strict order.
1157 for (const HloValue* value_a : buffer.values()) {
1158 if (value_a->shape().IsToken()) {
1159 // Token values have no representation and cannot interfere.
1160 continue;
1161 }
1162 for (const HloValue* value_b : buffer.values()) {
1163 if (value_a != value_b) {
1164 DCHECK(ordering_->LiveRangeStrictlyBefore(
1165 *value_a, *value_b, dataflow_,
1166 /*use_is_always_before_def_in_same_instr=*/true) ||
1167 ordering_->LiveRangeStrictlyBefore(
1168 *value_b, *value_a, dataflow_,
1169 /*use_is_always_before_def_in_same_instr=*/true))
1170 << value_a->ToString() << " and " << value_b->ToString()
1171 << " are not ordered";
1172 }
1173 }
1174 }
1175 }
1176
1177 std::vector<const HloValue*> values = buffer.values();
1178 absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
1179 return ordering_->IsDefinedBefore(*a, *b);
1180 });
1181
1182 // Create a list containing all of the values in the buffer.
1183 AddValueList(values, &value_to_node);
1184 }
1185
1186 // Create copy_map_ which contains the source and destination values
1187 // of all copies.
1188 CreateCopyMap(module, value_to_node);
1189
1190 XLA_VLOG_LINES(3, ToString());
1191 TF_DCHECK_OK(Verify());
1192 }
1193
1194 // Add a list containing the given values to CopyRemover. This
1195 // represents the values contained in a single buffer. For each value in
1196 // 'values' an entry is created in value_to_node which indicates the
1197 // respective ValueNode representing that value.
1198 void AddValueList(
1199 absl::Span<const HloValue* const> values,
1200 absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
1201 ValueNode* tail = nullptr;
1202 ValueNode* head = nullptr;
1203 for (const HloValue* value : values) {
1204 auto new_node = new ValueNode(value);
1205 (*value_to_node)[value] = new_node;
1206
1207 // Copy the HLO value's uses into the ValueNode for the value. These
1208 // uses in ValueNode are updated as copies are removed.
1209 new_node->uses.reserve(value->uses().size());
1210 for (const HloUse& use : value->uses()) {
1211 new_node->uses.push_back(&use);
1212 }
1213
1214 // Connect the new node into the linked list.
1215 if (tail == nullptr) {
1216 head = new_node;
1217 } else {
1218 tail->next = new_node;
1219 new_node->prev = tail;
1220 }
1221 tail = new_node;
1222 }
1223
1224 // The linked list is circular so connect the head and tail.
1225 tail->next = head;
1226 head->prev = tail;
1227 value_lists_.insert(head);
1228 }
1229
1230 // This method also fills in copy_map_ which indicates which nodes
1231 // in the value lists correspond to the source and destination values of
1232 // kCopy instructions. value_to_node should map each HloValue to its
1233 // respective ValueNode.
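// For example (illustrative): for %c = copy(%x) where %x has a single
// reaching value, copy_map_[%c].src is the ValueNode of that value and
// copy_map_[%c].dest is the ValueNode of the value defined by %c; a copy
// whose operand has an ambiguous value set gets no entry and is never elided.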
1234 void CreateCopyMap(
1235 const HloModule& module,
1236 const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
1237 for (HloComputation* computation : module.MakeNonfusionComputations()) {
1238 for (HloInstruction* instruction : computation->instructions()) {
1239 // Add copies with unambiguous source values to the map. Copies with
1240 // ambiguous sources are not removable.
1241 if (instruction->opcode() == HloOpcode::kCopy) {
1242 const HloValueSet& src_value_set =
1243 dataflow_.GetValueSet(instruction->operand(0));
1244 if (src_value_set.values().size() == 1) {
1245 CopyNodes& copy_node = copy_map_[instruction];
1246 copy_node.dest =
1247 value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
1248 copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
1249 }
1250 }
1251 }
1252 }
1253 }
1254
1255 ~CopyRemover() {
1256 for (const ValueNode* head : value_lists_) {
1257 const ValueNode* p = head;
1258 do {
1259 const ValueNode* tmp = p->next;
1260 delete p;
1261 p = tmp;
1262 } while (p != head);
1263 }
1264 }
1265
1266 // Verify invariants within the linked lists.
1267 Status Verify() const {
1268 for (const ValueNode* head : value_lists_) {
1269 const ValueNode* p = head;
1270 do {
1271 // Verify links between elements are consistent.
1272 TF_RET_CHECK(p->prev->next == p);
1273 TF_RET_CHECK(p->next->prev == p);
1274
1275 const HloInstruction* def = p->value->defining_instruction();
1276 if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
1277 TF_RET_CHECK(copy_map_.at(def).dest == p);
1278 }
1279 for (const HloUse* use : p->uses) {
1280 if (use->instruction->opcode() == HloOpcode::kCopy &&
1281 ContainsKey(copy_map_, use->instruction)) {
1282 TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
1283 }
1284 }
1285
1286 p = p->next;
1287 } while (p != head);
1288 }
1289 return Status::OK();
1290 }
1291
1292 // Compute the set of instructions where values are alive and organize these
1293 // instructions by separating them into their respective computations.
1294 LiveRangeRegions ComputeLiveRangeRegions(const ValueNode* head) {
1295 LiveRangeRegions live_range;
1296
1297 auto VisitValueNode = [&](const ValueNode* node) {
1298 HloInstruction* def_op = node->value->instruction();
1299 HloComputation* def_parent = def_op->parent();
1300 live_range[def_parent][def_op].is_definition = true;
1301 for (const auto& use : node->uses) {
1302 auto* use_op = use->instruction;
1303 HloComputation* use_parent = use_op->parent();
1304 live_range[use_parent][use_op].value_definition = def_op;
1305 }
1306 };
1307 ForEachValueInRange(head, VisitValueNode);
1308 return live_range;
1309 }
1310
1311 // Try to elide the given copy. Elision of a copy is possible only if no
1312 // live range interference is introduced by the copy's elimination. If
1313 // elision is possible, then the internal state (value lists) is updated,
1314 // and true is returned. Returns false otherwise.
1315 bool TryElideCopy(const HloInstruction* copy, bool use_region_analysis) {
1316 VLOG(2) << "Trying to remove " << copy->name();
1317
1318 if (!ContainsKey(copy_map_, copy)) {
1319 VLOG(2) << copy->name() << " is not removable";
1320 return false;
1321 }
1322 if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
1323 VLOG(2) << copy->name() << " is not removable (shape mismatch)";
1324 return false;
1325 }
1326 const CopyNodes& copy_node = copy_map_.at(copy);
1327 DCHECK(copy_node.src != nullptr);
1328 DCHECK(copy_node.dest != nullptr);
1329
1330 VLOG(3) << copy->name() << " copies value "
1331 << copy_node.src->value->ToShortString();
1332 VLOG(3) << "Source buffer values: " << ValueListToString(copy_node.src);
1333 VLOG(3) << "Dest buffer values: " << ValueListToString(copy_node.dest);
1334 // Checks whether the live range at src is before that defined by dest.
1335 auto CheckLiveRangeBefore = [&](ValueNode* src, ValueNode* dest) {
1336 for (ValueNode* next_dest = dest; next_dest != nullptr;
1337 next_dest = Next(*next_dest)) {
1338 for (ValueNode* prev_src = src; prev_src != nullptr;
1339 prev_src = Prev(*prev_src)) {
1340 if (!LiveRangeBefore(*prev_src, *next_dest)) {
1341 VLOG(2) << "Live range of " << prev_src->value->ToShortString()
1342 << " is not before " << next_dest->value->ToShortString();
1343 return false;
1344 }
1345 }
1346 }
1347 return true;
1348 };
1349 auto CheckLiveRangeInterference = [&](ValueNode* src, ValueNode* dest,
1350 const CombineLiveRangeOption option) {
1351 CHECK_NE(src, nullptr);
1352 CHECK_NE(dest, nullptr);
1353 if (!use_region_analysis) {
1354 VLOG(2) << "Configured to not use region-based analysis.\n";
1355 return true;
1356 }
1357 if (ValuesInterfere(src, dest, option)) {
1358 VLOG(2) << "Region-based interference is true. \n";
1359 return true;
1360 }
1361 VLOG(2) << "Region-based interference is false. \n";
1362 return false;
1363 };
1364
1365 // A kCopy instruction copies an HLO value from a source buffer and
1366 // defines an HLO value in a destination buffer. Most generally, the
1367 // source and destination buffers may each hold more than one value at
1368 // different points in the computation so we define the following:
1369 //
1370 // Values in source buffer: {s_0, ..., s_n}
1371 // Values in destination buffer: {d_0, ..., d_m}
1372 //
1373 // A kCopy instruction between these buffers copies a value s_x in the
1374 // source buffer and defines a value d_y in the destination buffer. The
1375 // elision of a copy merges the source and destination buffers together,
1376 // so the list of values for the source and destination buffers are
1377 // merged.
1378 //
1379 // We handle two different cases for copy elision:
1380 //
1381 // (1) the kCopy defines the first value in the destination buffer (d_0).
1382 //
1383 // (2) the kCopy copies the last value in the source buffer (s_n).
1384 //
1385 // For the remaining case where the kCopy copies a not-last value from the
1386 // source buffer to a not-first value of the destination buffer, the kCopy
1387 // instruction cannot be removed. This case is generated, for example, if
1388 // the kCopy copies a while body parameter of the loop state at one tuple
1389 // index to a different tuple index in the while body root. Removal of the
1390 // copy necessarily results in live range interference of values in the
1391 // loop state at the two different tuple indices.
1392 //
1393 // We can only perform copy elision if the resulting merged values have
1394 // totally ordered live ranges; otherwise the merged buffer would have
1395 // live range interference.
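    // As a concrete illustration, suppose the source buffer holds {s_0, s_1}
    // and the destination buffer holds {d_0, d_1}, with the copy reading s_0
    // and defining d_0. That is case (1): eliding the copy merges the value
    // lists into {s_0, d_1, s_1}, provided the live-range checks below
    // succeed.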
1396 if (copy_node.src->next == copy_node.dest) {
1397       // In the process of eliding copies, it's possible for a copy to have the
1398       // same source and destination buffer. In this case, the copy can be
1399       // safely removed.
1400       VLOG(2) << copy->name() << " source and destination buffers are the same.";
1401 } else if (IsHead(*copy_node.dest)) {
1402 // The copy copies an arbitrary value in the source buffer (call it s_x)
1403 // and defines d_0, the first value in the destination buffer. After
1404 // merging, the values in the combined buffer must be strictly ordered
1405 // as follows** to elide the copy:
1406 //
1407 // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
1408 //
1409 // Removing the copy eliminates d_0, and uses of d_0 become uses of
1410 // s_x. In the above ordering, the live range of d_m will be ordered
1411 // before the live range of s_{x+1} and the definition and all uses of
1412 // s_x will be ordered before the definition of d_1. To make sure the
1413 // copy elision is safe, the following code checks that this ordering is
1414       // valid --- in particular we check it is safe to order d_m before all
1415       // the live ranges at and after s_{x+1}, and it is safe to order all uses
1416       // of s_x before the definition of d_1, by checking the live range
1417       // constraints for each pair --- we cannot skip the later checks because
1418       // the live range ordering is not guaranteed to be transitive --- while it
1419       // may be ok to have lr_1 before lr_2, and lr_2 before lr_3 while merging
1420       // their buffers, it may not be ok to merge the buffers of lr_1 and lr_3,
1421 // because the exclusiveness relation of non-overlapping computations is
1422 // not transitive.
1423 //
1424 // ** Technically it might be possible to have a non-interfering
1425 // non-trivial interleaving of the values of the source and
1426 // destination buffers in the resulting order. This can be potentially
1427 // supported in the ValuesInterfere function, which performs
1428 // interference analysis at a more global scope than the alternative
1429 // LiveRangeBefore analysis which requires strict ordering of all live
1430 // ranges. Currently, however, this is not yet supported, as
1431 // we simply check for the case where *all* values of the destination
1432 // buffer (d_1 through d_m) are spliced into the point where the copy
1433 // used to be.
1434 VLOG(2) << copy->name() << " defines the first value in its buffer";
1435 bool live_range_before =
1436 // Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
1437 CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest)) &&
1438 // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
1439 CheckLiveRangeBefore(copy_node.dest->prev, Next(*copy_node.src));
1440 VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1441 if (!live_range_before &&
1442 CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1443 kMergeFirstDestInSource)) {
1444 return false;
1445 }
1446 VLOG(2) << "Splice dest after source.";
1447 // Splice in destination buffer values list right after 'src'.
1448 SpliceAfter(copy_node.dest, copy_node.src);
1449 } else if (IsTail(*copy_node.src)) {
1450 // The copy copies the last value in the source buffer, s_n, and defines
1451 // an arbitrary value in the destination buffer, d_y. After
1452 // merging, the values in the combined buffer must be strictly ordered
1453 // as follows** to elide the copy:
1454 //
1455 // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
1456 //
1457 // Removing the copy eliminates d_y, and uses of d_y become uses of
1458 // s_n. To enforce the above order, the live range of d_{y-1} must be
1459 // before the live range of s_0, and the live range of s_n must be
1460 // before the live range of d_{y+1}.
1461 //
1462 // ** See comment above in the code handling Case (1).
1463 VLOG(2) << copy->name() << " copies the last value ("
1464 << copy_node.src->value->ToShortString() << ") in its buffer";
1465 bool live_range_before =
1466 // Live range of d_0, ..., d_{y-1} must be before s_0;
1467 CheckLiveRangeBefore(Prev(*copy_node.dest), copy_node.src->next) &&
1468 // Live range of 'last_src' must be before next_dest d_{y+1}.
1469 CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest));
1470 VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1471 if (!live_range_before &&
1472 CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1473 kMergeLastSourceInDest)) {
1474 VLOG(2) << "Region-based analysis concludes interference.\n";
1475 return false;
1476 }
1477 VLOG(2) << "Splice src after prev of dest.";
1478 // Splice source buffer values list right after 'prev_dest'.
1479 SpliceAfter(copy_node.src->next, Prev(*copy_node.dest));
1480 } else {
1481 VLOG(2) << copy->name()
1482 << " copies value in middle of source buffer to value in middle "
1483 "of destination buffer";
1484 return false;
1485 }
1486
1487 RemoveCopyValue(copy_node.dest);
1488
1489 XLA_VLOG_LINES(4, ToString());
1490 TF_DCHECK_OK(Verify());
1491
1492 return true;
1493 }
1494
1495   // Delete the given ValueNode associated with an elided kCopy
1496 // instruction. This should be called after splicing the value lists of the
1497 // source and destination buffers together.
1498   void RemoveCopyValue(ValueNode* copy_value_node) {
1499 CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
1500 HloOpcode::kCopy);
1501 ValueNode* operand_node = copy_value_node->prev;
1502 CHECK(operand_node != copy_value_node);
1503
1504 VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
1505 << " => " << copy_value_node->value->ToShortString();
1506
1507 // Splice out the copy value node.
1508 operand_node->next = copy_value_node->next;
1509 copy_value_node->next->prev = operand_node;
1510
1511 // Patch up uses. Remove use of copy from operand_node uses.
1512 auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
1513 const HloUse* use) {
1514 return use->instruction == copy_value_node->value->defining_instruction();
1515 });
1516 CHECK(it != operand_node->uses.end());
1517 operand_node->uses.erase(it);
1518
1519 // If the elided copy has any uses which are themselves kCopy instructions
1520   // then patch up the copy info to reflect that this kCopy instruction
1521 // has a different operand (the operand of the elided copy).
1522 for (const HloUse* copy_use : copy_value_node->uses) {
1523 operand_node->uses.push_back(copy_use);
1524 if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
1525 ContainsKey(copy_map_, copy_use->instruction)) {
1526 copy_map_.at(copy_use->instruction).src = operand_node;
1527 }
1528 }
1529
1530 // Delete the copy info and the value node.
1531 copy_map_.erase(copy_value_node->value->defining_instruction());
1532 delete copy_value_node;
1533 }
1534
1535 // Returns true if the live range of given value 'a' is before the live
1536 // range of 'b'.
1537 //
1538 // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
1539   // updated as copies are removed. Also, because the result here is used
1540 // to directly drive copy elision, use_is_always_before_def_in_same_instr is
1541 // set to false.
1542   bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
1543 if (a.uses.empty()) {
1544 VLOG(2) << "Empty uses for " << *a.value;
1545 return ordering_->IsDefinedBefore(*a.value, *b.value);
1546 }
1547     VLOG(3) << "Checking live ranges before: " << ValueListToString(&a)
1548 << " vs " << ValueListToString(&b) << "\n";
1549 return ordering_->UsesBeforeValueDefinition(
1550 a.uses, *b.value, dataflow_,
1551 /* use_is_always_before_def_in_same_instr=*/false);
1552 }
1553
1554 // Returns whether 'node' is the last node in its list.
1555   bool IsTail(const ValueNode& node) const {
1556 return ContainsKey(value_lists_, node.next);
1557 }
1558
1559 // Returns whether 'node' is the first node in its list.
1560   bool IsHead(const ValueNode& node) const {
1561 return ContainsKey(value_lists_, &node);
1562 }
1563
1564 // Returns the next node in the list after 'node'. If 'node' is the
1565 // tail, then nullptr is returned.
1566   ValueNode* Next(const ValueNode& node) const {
1567 if (IsTail(node)) {
1568 return nullptr;
1569 } else {
1570 return node.next;
1571 }
1572 }
1573
1574 // Returns the previous node in the list before 'node'. If 'node'
1575 // is the head, then nullptr is returned.
1576   ValueNode* Prev(const ValueNode& node) const {
1577 if (IsHead(node)) {
1578 return nullptr;
1579 } else {
1580 return node.prev;
1581 }
1582 }
1583
1584 // Splices the entire linked list with 'head' as its head right after the
1585 // node 'insert_after' in another linked list.
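  // For example, if one list is {d_0, d_1} (with head d_0) and another is
  // {s_0, s_1, s_2}, then SpliceAfter(d_0, s_1) yields the single list
  // {s_0, s_1, d_0, d_1, s_2} and d_0 stops being tracked as a list head.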
1586   void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
1587 DCHECK(IsHead(*head));
1588 value_lists_.erase(head);
1589
1590 ValueNode* tail = head->prev;
1591 tail->next = insert_after->next;
1592 insert_after->next->prev = tail;
1593
1594 insert_after->next = head;
1595 head->prev = insert_after;
1596 }
1597
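  // Identifies which of the two elision cases above a region-based
  // interference query refers to: kMergeFirstDestInSource corresponds to the
  // case where the copy defines the first value of the destination buffer
  // (the destination list is spliced into the source list), and
  // kMergeLastSourceInDest corresponds to the case where the copy reads the
  // last value of the source buffer (the source list is spliced into the
  // destination list).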
1598 enum CombineLiveRangeOption {
1599 kMergeFirstDestInSource = 1,
1600 kMergeLastSourceInDest = 2
1601 };
1602   // This function analyzes all the HloValues that have been grouped together
1603   // with 'src' to share a single buffer, and all the HloValues that have been
1604   // similarly grouped together with 'dest', to determine whether these two
1605   // groups can be combined by removing the copy instruction that defines
1606   // 'dest' from the buffer of 'src'.
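  // Returns true if the two groups interfere and therefore cannot be merged,
  // and false if merging them is safe.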
1607   bool ValuesInterfere(const ValueNode* src, const ValueNode* dest,
1608 CombineLiveRangeOption merge_location) {
1609 // Get the entire range of values sharing the buffers in src and dest.
1610 auto src_live_range = ComputeLiveRangeRegions(src);
1611 auto dest_live_range = ComputeLiveRangeRegions(dest);
1612 ComputeRelativeLocation relative_location_analysis(ordering_);
1613 auto rel1 =
1614 relative_location_analysis.Compute(src_live_range, dest_live_range);
1615 VLOG(3) << "Location of dest in relation to src:" << rel1.ToString()
1616 << " with interception set to " << rel1.InterceptDefUse() << "\n";
1617 auto rel2 =
1618 relative_location_analysis.Compute(dest_live_range, src_live_range);
1619 VLOG(3) << "Location of src in relation to dest:" << rel2.ToString()
1620             << " with interception set to " << rel2.InterceptDefUse() << "\n";
1621 // If src and dest are interleaved with each other, they interfere.
1622 if (rel1.RuntimeOrderOverlap() && rel2.RuntimeOrderOverlap()) {
1623       VLOG(3) << "Both relations overlap.\n";
1624 return true;
1625 }
1626 // If src and dest belong to the same group of computations and do not
1627 // overlap, they do not interfere.
1628 if (rel1.RuntimeOrderOverlap() || rel2.RuntimeOrderOverlap()) {
1629       VLOG(3) << "At least one relation overlaps.\n";
1630 if (rel1.RuntimeOrderOverlap()) {
1631         VLOG(3) << "rel1 overlaps, with interception = "
1632 << rel1.InterceptDefUse() << "\n";
1633 if (rel1.InterceptDefUse() ||
1634 (merge_location != kMergeFirstDestInSource &&
1635 rel2.InterceptDefUse())) {
1636 return true;
1637 }
1638 } else {
1639         VLOG(3) << "rel2 overlaps, with interception = "
1640 << rel2.InterceptDefUse() << "\n";
1641 // Here src is at the end of a nested computation inside dest.
1642 if (rel2.InterceptDefUse() ||
1643 (merge_location != kMergeLastSourceInDest &&
1644 rel1.InterceptDefUse())) {
1645 return true;
1646 }
1647 }
1648 }
1649 if (relative_location_analysis.AddControlDependenceForUnorderedOps()) {
1650 return false;
1651 } else {
1652       // Disallow removing the copy if control dependencies cannot be added.
1653 return true;
1654 }
1655 }
1656
1657   // Calls 'visitor' on every value in the list containing 'element', starting
1658   // at 'element': first from 'element' to the tail, then wrapping around from
1659   // the head up to (but not including) 'element'. This order matters for live range region analysis.
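  // For example, for a circular list with head a and order a -> b -> c -> d,
  // ForEachValueInRange(c, visitor) visits c, d, a, b in that order.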
1660   void ForEachValueInRange(const ValueNode* element,
1661 std::function<void(const ValueNode*)> visitor) {
1662 const ValueNode* head = element;
1664 for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
1665 visitor(p);
1666 }
1667 while (!IsHead(*head)) {
1668 head = Prev(*head);
1669 }
1670 for (const ValueNode* p = head; p != element; p = Next(*p)) {
1671 visitor(p);
1672 }
1673 }
1674
1675   string ValueListToString(const ValueNode* element) {
1676 std::string result = "{";
1677 auto VisitValueNode = [&](const ValueNode* node) {
1678 if (result == "{") {
1679         StrAppend(&result, node->value->ToShortString());
1680 } else {
1681 StrAppend(&result, ", ");
1682 StrAppend(&result, node->value->ToShortString());
1683 }
1684 };
1685     ForEachValueInRange(element, VisitValueNode);
1686 StrAppend(&result, "}");
1687 return result;
1688 }
1689
1690   string ToString() const {
1691 string out = absl::StrCat("CopyRemover:\n");
1692 StrAppend(&out, " Def-use chains in each buffer:\n");
1693 for (const ValueNode* head : value_lists_) {
1694 StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
1695 ":\n");
1696 const ValueNode* p = head;
1697 do {
1698 StrAppend(&out, " ", p->value->ToShortString(), ", uses: ",
1699 absl::StrJoin(p->uses, "; ",
1700 [](string* s, const HloUse* use) {
1701 StrAppend(s, use->ToString());
1702 }),
1703 "\n");
1704
1705 p = p->next;
1706 } while (p != head);
1707 }
1708 StrAppend(&out, " Potentially removable copies:\n");
1709 for (const auto& pair : copy_map_) {
1710 const HloInstruction* copy = pair.first;
1711 const CopyNodes& copy_info = pair.second;
1712
1713 StrAppend(&out, " ", copy->name(), " : ",
1714 copy_info.src->value->ToShortString(), " => ",
1715 copy_info.dest->value->ToShortString(), "\n");
1716 }
1717 return out;
1718 }
1719
1720 private:
1721 const HloDataflowAnalysis& dataflow_;
1722 HloOrdering* ordering_;
1723
1724 // The heads of all the value lists. Each value list represents the HLO
1725 // values contained in a particular HLO buffer. The values in the list are
1726 // in dependency order.
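  // Each list is circular and doubly linked: the head's prev pointer is the
  // tail, which is what lets SpliceAfter and the do/while traversals above
  // work without a null terminator.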
1727 absl::flat_hash_set<const ValueNode*> value_lists_;
1728
1729 // Copy removal requires fast access to the value list elements
1730 // corresponding to the source and destination values of the kCopy
1731 // instruction. This data structure holds pointers to these elements for
1732 // each kCopy instruction in the graph.
1733 struct CopyNodes {
1734 // The source and destinations values of the kCopy instruction.
1735 ValueNode* src = nullptr;
1736 ValueNode* dest = nullptr;
1737 };
1738 absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
1739 };
1740
1741 } // namespace
1742
1743 // We add copies for all non-phi indices of the branch computation roots, in
1744 // order to resolve interference. We later rely on
1745 // RemoveUnnecessaryCopies to drop the unnecessary ones.
1746 Status CopyInsertion::AddCopiesForConditional(
1747 const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
1748 VLOG(2) << "Adding copies for kConditional instruction "
1749 << conditional->name();
1750 ShapeTree<bool> indices_to_copy(conditional->shape());
1751 TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
1752 if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
1753 conditional, &indices_to_copy)) {
1754     VLOG(2) << "No copies necessary for kConditional instruction "
1755 << conditional->name();
1756 return Status::OK();
1757 }
1758
1759 for (HloComputation* computation : conditional->branch_computations()) {
1760 HloInstruction* root = computation->root_instruction();
1761 std::vector<HloInstruction*> users = root->users();
1762 TF_ASSIGN_OR_RETURN(
1763 HloInstruction * deep_copy,
1764 computation->DeepCopyInstruction(root, &indices_to_copy));
1765 for (HloInstruction* user : users) {
1766 TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
1767 }
1768 computation->set_root_instruction(deep_copy);
1769 }
1770 return Status::OK();
1771 }
1772
1773 // Add kCopy instructions to the given module to guarantee there is no
1774 // live-range interference. Generally interference can only occur around kWhile
1775 // instructions which have update-in-place semantics.
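// For example, if a while body's root swaps two elements of the loop state
// tuple, the values at those two indices would otherwise have to share
// buffers with overlapping live ranges across iterations; inserting copies
// around the loop breaks that aliasing, and the later elision step keeps
// only the copies that are truly needed.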
1776 Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
1777 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1778 HloAliasAnalysis::Run(module, can_share_buffer_));
1779
1780 for (HloComputation* computation : module->MakeNonfusionComputations()) {
1781 for (HloInstruction* instruction :
1782 computation->MakeInstructionPostOrder()) {
1783 if (instruction->opcode() == HloOpcode::kWhile) {
1784 TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
1785 } else if (instruction->opcode() == HloOpcode::kConditional) {
1786 TF_RETURN_IF_ERROR(
1787 AddCopiesForConditional(*alias_analysis, instruction));
1788 } else {
1789 // When an operand is a tuple, we avoid copying the operand multiple
1790 // times by recording and checking the operand number of operands that
1791 // have been copied.
1792 absl::flat_hash_set<int64> copied_operands;
1793 for (const auto& operand_and_output_index :
1794 HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
1795 const HloUse& operand = operand_and_output_index.first;
1796 if (copied_operands.contains(operand.operand_number)) {
1797 continue;
1798 }
1799 copied_operands.insert(operand.operand_number);
1800 TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
1801 *alias_analysis, instruction, operand.operand_number));
1802 }
1803 }
1804 }
1805 }
1806
1807 TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
1808 return Status::OK();
1809 }
1810
1811 Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
1812 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1813 return AddSpecialCaseCopies(*call_graph, module);
1814 }
1815
1816 Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
1817 HloModule* module) {
1818 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1819 HloAliasAnalysis::Run(module, can_share_buffer_));
1820
1821 // Identify which shape indices of which instructions need to be copied. Store
1822 // these results in 'instructions_to_copy'.
1823 HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
1824 auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
1825 const ShapeIndex& index) {
1826 auto it = instructions_to_copy.find(instruction);
1827 if (it == instructions_to_copy.end()) {
1828 auto it_added = instructions_to_copy.emplace(
1829 std::piecewise_construct, std::forward_as_tuple(instruction),
1830 std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
1831 it = it_added.first;
1832 }
1833 *it->second.mutable_element(index) = true;
1834 };
1835
1836 // Iterate through values of all constants and entry parameters. These values
1837 // are special because they are held in read-only buffers. If any of these
1838 // values share a buffer with other values (for example, the init value of a
1839 // while is a constant) then copy the value at its definition and replace all
1840 // its uses with the copy.
1841 for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
1842 if (ValueIsReadOnly(*value) &&
1843 alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
1844 VLOG(2) << "Value " << value->ToShortString()
1845 << " is read only, but its buffer contains more than one value. "
1846 "Copying.";
1847 add_index_to_copy(value->defining_instruction(), value->defining_index());
1848 }
1849 }
1850
1851 // Identify copies which must be added at root instructions
1852 for (HloComputation* computation : module->computations()) {
1853 const CallGraphNode& node = call_graph.GetNode(computation);
1854 if (node.context() == CallContext::kParallel) {
1855 continue;
1856 }
1857 TF_RET_CHECK(node.context() == CallContext::kSequential);
1858
1859 SpecialCaseCopyPolicy policy =
1860 GetSpecialCaseCopyPolicy(node, module, computation);
1861 HloInstruction* root = computation->root_instruction();
1862
1863 // Mark nondistinct/ambiguous indices.
1864 absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
1865 ShapeUtil::ForEachSubshape(
1866 root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1867 std::vector<const HloBuffer*> buffers_at_index =
1868 alias_analysis->ComputeBuffersAt(root, index);
1869 bool buffer_seen_before = false;
1870 for (const HloBuffer* buffer : buffers_at_index) {
1871 buffer_seen_before |= !seen.emplace(buffer, index).second;
1872 }
1873
1874 if (buffer_seen_before && policy.copy_root_replicated_buffers &&
1875 computation == module->entry_computation() &&
1876 module->input_output_alias_config().OutputHasAlias(index) &&
1877 buffers_at_index.size() == 1) {
1878 absl::optional<HloInputOutputAliasConfig::Alias> alias =
1879 module->input_output_alias_config().GetAliasedParameter(index);
1880 CHECK(alias) << "Alias does not exist";
1881 const ShapeIndex& other_index = seen[buffers_at_index[0]];
1882 VLOG(2) << "Output indices " << index.ToString() << " and "
1883                 << other_index.ToString() << " are both aliased to parameter "
1884                 << alias->parameter_number << "; copying " << other_index;
1885 add_index_to_copy(root, other_index);
1886 return;
1887 }
1888
1889 if (buffers_at_index.size() > 1 ||
1890 (buffer_seen_before && policy.copy_root_replicated_buffers)) {
1891 VLOG(2) << "Index " << index << " of computation "
1892 << computation->name() << " (" << root->name()
1893 << ") has ambiguous or non-distinct buffer. Copying.";
1894 add_index_to_copy(root, index);
1895 }
1896 });
1897
1898 for (const auto& pair :
1899 alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
1900 const ShapeIndex& index = pair.first;
1901 const HloValueSet& value_set = pair.second;
1902 for (const HloValue* value : value_set.values()) {
1903 if (ShouldCopyRootValue(*value, policy)) {
1904           VLOG(2) << "Root of (" << root->name() << ") of computation ("
1905 << computation->name()
1906 << ") has constant or parameter value at index " << index
1907 << ". Copying.";
1908 add_index_to_copy(root, index);
1909 }
1910 }
1911 }
1912 }
1913
1914 // Add copy instructions indicated in 'instructions_to_copy' to the module.
1915 for (const auto& pair : instructions_to_copy) {
1916 HloInstruction* instruction = pair.first;
1917 const ShapeTree<bool>& indices_to_copy = pair.second;
1918
1919 ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
1920 std::vector<HloInstruction*> users = instruction->users();
1921 TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
1922 instruction->parent()->DeepCopyInstruction(
1923 instruction, &indices_to_copy, &copies_added));
1924 for (HloInstruction* user : users) {
1925 TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
1926 }
1927 if (instruction == instruction->parent()->root_instruction()) {
1928 instruction->parent()->set_root_instruction(deep_copy);
1929 }
1930 }
1931 return Status::OK();
1932 }
1933
1934 static int64 GetNumExistingCopies(const HloModule* module) {
1935 int64_t num_existing_copies = 0;
1936 for (HloComputation* computation : module->computations()) {
1937 for (HloInstruction* instruction : computation->instructions()) {
1938 if (instruction->opcode() == HloOpcode::kCopy) {
1939 ++num_existing_copies;
1940 }
1941 }
1942 }
1943 return num_existing_copies;
1944 }
1945
1946 Status CopyInsertion::RemoveUnnecessaryCopies(HloOrdering* ordering,
1947 HloModule* module,
1948 bool check_live_range_ordering) {
1949 XLA_VLOG_LINES(4, module->ToString());
1950 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1951 HloAliasAnalysis::Run(module, can_share_buffer_));
1952
1953 CopyRemover copy_remover(*module, *alias_analysis, ordering,
1954 check_live_range_ordering);
1955 if (VLOG_IS_ON(3)) {
1956 LOG(INFO) << "Removing unnecessary copies in " << module->name();
1957 LOG(INFO) << "Buffer values, in dependency order: ";
1958 for (const HloBuffer& buffer : alias_analysis->buffers()) {
1959 LOG(INFO) << " HloBuffer " << buffer.id();
1960 }
1961 }
1962
1963 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1964
1965 int64_t num_existing_copies = GetNumExistingCopies(module);
1966 bool changed = true;
1967 int64_t num_iterations = -1;
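  // Every iteration that sets 'changed' elides at least one copy, so the
  // number of fixpoint iterations is bounded by the number of copies that
  // existed before elision started; the CHECK below guards against an
  // infinite loop.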
1968 while (changed) {
1969 CHECK_LE(++num_iterations, num_existing_copies);
1970 changed = false;
1971 VLOG(2) << "Running fixpoint iteration " << num_iterations
1972 << " of copy elision";
1973 for (HloComputation* computation : module->computations()) {
1974 VLOG(2) << "computation:" << computation->name() << "\n";
1975 for (HloInstruction* instruction : computation->instructions()) {
1976 VLOG(2) << instruction->ToString() << "\n";
1977 if (instruction->opcode() == HloOpcode::kCopy &&
1978 copy_remover.TryElideCopy(instruction,
1979 use_region_based_live_range_analysis_)) {
1980 changed = true;
1981 TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
1982 TF_RETURN_IF_ERROR(
1983 instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
1984 }
1985 }
1986 }
1987 }
1988 return Status::OK();
1989 }
1990
1991 StatusOr<bool> CopyInsertion::Run(HloModule* module) {
1992 // Copy insertion is performed in three steps:
1993 //
1994 // (1) Add copies conservatively to guarantee that there is no live-range
1995 // interference. This is done simplistically and usually results in more
1996 // copies than is strictly necessary.
1997 //
1998 // (2) Using a more fine-grained analysis, remove as many of the copies added
1999 //     in (1) as possible, while ensuring no live-range interference.
2000 //
2001 // (3) Add copies to resolve issues not related to live range interference
2002 // such as parameters and constants live out of the entry computation.
2003 //
2004 // We add copies then remove them (step (1) then (2)) rather than simply
2005 // adding only the copies that are necessary because, in general, it is
2006 // difficult to figure out the minimal set of copies to add once there is
2007 // interference. On the other hand, it is easy to determine if removing a copy
2008 // will introduce interference.
2009 //
2010 // The final copy insertion in (3) is done separately to simplify the
2011 // implementation of copy removal in (2) which is the most complicated part of
2012 // the pass. As is, copy removal only has to reason about live range
2013 // interference. If all copies were added in step (1) then copy removal would
2014 // also have to reason about things like constants and parameters live out of
2015 // the computation.
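  // For reference only (not exercised here), a backend pipeline would
  // typically invoke this pass roughly as follows, assuming it holds an
  // HloModule* named 'module':
  //
  //   CopyInsertion copy_insertion;
  //   TF_ASSIGN_OR_RETURN(bool changed, copy_insertion.Run(module));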
2016 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
2017 if (!call_graph->IsFlattened()) {
2018 return FailedPrecondition(
2019 "Call graph must be flattened before copy insertion.");
2020 }
2021
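
  // Record how many copies exist before this pass adds any, so the VLOG
  // summary at the end can report a meaningful "before" number.
  const int64_t num_copies_before = GetNumExistingCopies(module);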
2022 TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
2023
2024 // Simplify the tuple structures introduced by the deep copies. This should be
2025 // done before removing copies (RemoveUnnecessaryCopies) because tuple
2026 // simplification changes dependencies in the graph, which in turn changes
2027 // live range interference. Also run DCE to remove the dead Tuple/GTE
2028 // instructions introduced by tuple simplification.
2029 TupleSimplifier tuple_simplifier;
2030 HloDCE dce;
2031 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2032 TF_RETURN_IF_ERROR(dce.Run(module).status());
2033 DumpHloModuleDuringPassIfEnabled(
2034 name(), "after adding copies to resolve interference", *module);
2035
2036 DependencyHloOrdering ordering(module);
2037 TF_RETURN_IF_ERROR(
2038 RemoveUnnecessaryCopies(&ordering, module,
2039 /*check_live_range_ordering=*/true));
2040 DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
2041 *module);
2042 TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
2043 DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
2044 *module);
2045
2046 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2047 TF_RETURN_IF_ERROR(dce.Run(module).status());
2048
2049 if (VLOG_IS_ON(1)) {
2050 int64_t num_total_copies = 0;
2051 for (HloComputation* computation : module->computations()) {
2052 for (HloInstruction* instruction : computation->instructions()) {
2053 if (instruction->opcode() == HloOpcode::kCopy) {
2054 num_total_copies++;
2055 }
2056 }
2057 }
2058     VLOG(1) << "Num copies before copy-insertion: " << num_copies_before;
2060 VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
2061 }
2062
2063 return true;
2064 }
2065 } // namespace xla
2066