1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46
47 namespace xla {
48
49 using absl::StrCat;
50
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52 HloInstruction* root_instruction) {
53 int parameter_count = 0;
54 for (auto& instruction : instructions_) {
55 if (instruction->opcode() == HloOpcode::kParameter) {
56 parameter_count++;
57 }
58 }
59 // If root_instruction is not specified use the last added instruction.
60 HloInstruction* root =
61 root_instruction ? root_instruction : last_added_instruction_;
62 CHECK_NE(nullptr, root);
63 return absl::WrapUnique(new HloComputation(
64 name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66
// Private constructor, reached via Builder::Build and CreateFromProto. Takes
// ownership of every instruction in `*instructions` (each one is moved into
// this computation's instruction list).
HloComputation::HloComputation(
    const string& name, int parameter_count,
    std::vector<std::unique_ptr<HloInstruction>>* instructions,
    HloInstruction* root_instruction, HloInstruction* fusion_instruction)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(-1),  // Assigned later, once the computation is in a module.
      root_instruction_(root_instruction),
      fusion_instruction_(fusion_instruction) {
  // Pre-size the parameter table; each slot must be filled exactly once by
  // the loop below.
  param_instructions_.resize(parameter_count, nullptr);
  bool root_found = false;
  for (auto& instruction : *instructions) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      // Parameter numbers must form a dense, duplicate-free range
      // [0, parameter_count).
      int64 param_no = instruction->parameter_number();
      CHECK(param_no >= 0 && param_no < parameter_count)
          << "\nERROR: invalid parameter number. Expected [0, "
          << parameter_count << "), got " << param_no;
      CHECK(param_instructions_[param_no] == nullptr)
          << "\nERROR: parameter number " << param_no
          << " already allocated in this computation";
      param_instructions_[param_no] = instruction.get();
    }
    root_found |= instruction.get() == root_instruction_;
    AddInstructionInternal(std::move(instruction));
  }
  // The designated root must be one of the instructions handed in.
  CHECK(root_found)
      << "\nERROR: root instruction is not present in computation.";
}
94
AddInstruction(std::unique_ptr<HloInstruction> instruction,const std::string & new_name)95 HloInstruction* HloComputation::AddInstruction(
96 std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
97 CHECK(instruction->opcode() != HloOpcode::kParameter)
98 << "Parameter instructions cannot be added to a computation after "
99 << "it has been built";
100 if (!new_name.empty()) {
101 instruction->SetAndSanitizeName(new_name);
102 }
103 return AddInstructionInternal(std::move(instruction));
104 }
105
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)106 HloInstruction* HloComputation::AddInstructionInternal(
107 std::unique_ptr<HloInstruction> instruction) {
108 if (parent() != nullptr) {
109 instruction->UniquifyName(&parent()->instruction_name_uniquer());
110 instruction->SetUniqueId(parent()->NewUniqueInstructionId());
111 }
112 instruction->set_parent(this);
113 HloInstruction* pinst = instruction.get();
114 instruction_iterators_[pinst] =
115 instructions_.insert(instructions_.end(), std::move(instruction));
116 return pinst;
117 }
118
// Appends a new parameter to a fusion computation. Unlike AddInstruction,
// adding a parameter after construction is permitted here, but only for
// fusion computations, and only when the fusion instruction already has an
// operand for each existing parameter (the new parameter pairs with a newly
// added fusion operand).
HloInstruction* HloComputation::AddParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK(instruction->opcode() == HloOpcode::kParameter);
  CHECK(IsFusionComputation());
  CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
  instruction->set_parent(this);
  param_instructions_.push_back(instruction.get());
  AddInstructionInternal(std::move(instruction));
  return instructions_.back().get();
}
129
// Appends a new parameter to the entry computation. The instruction's
// parameter number must be the next unused one, and the module's entry
// computation layout is extended with the new parameter's shape.
HloInstruction* HloComputation::AddEntryComputationParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK_EQ(instruction->parameter_number(), num_parameters());
  CHECK(parent()->entry_computation() == this);

  // The module config is updated copy-modify-set style; it is not directly
  // mutable from here.
  HloModuleConfig config = parent()->config();
  config.mutable_entry_computation_layout()->add_parameter_layout(
      ShapeLayout(instruction->shape()));
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_.push_back(instruction.get());
  AddInstructionInternal(std::move(instruction));

  return instructions_.back().get();
}
147
// Replaces parameter `param_no` of the entry computation with a new
// parameter instruction, rewriting the module's entry layout entry to the new
// shape. The old instruction is force-removed (bypassing the usual "cannot
// remove a non-fusion parameter" safety check).
Status HloComputation::ReplaceEntryComputationParameter(
    int64 param_no, HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK(parent()->entry_computation() == this);

  // Copy-modify-set the module config to install the new parameter layout.
  HloModuleConfig config = parent()->config();
  *config.mutable_entry_computation_layout()->mutable_parameter_layout(
      param_no) = ShapeLayout(instruction->shape());
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_[param_no] = instruction.get();
  AddInstructionInternal(std::move(instruction));

  return ForceRemoveInstruction(old_instruction);
}
167
// Removes parameter `param_no` from a fusion computation. Parameter numbers
// must stay dense, so every parameter after the removed one is replaced by a
// freshly created parameter instruction whose number is shifted down by one.
Status HloComputation::RemoveParameter(int64 param_no) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(IsFusionComputation());
  HloInstruction* param_instruction = param_instructions_[param_no];
  auto param_instruction_iterator = param_instructions_.begin() + param_no;
  param_instructions_.erase(param_instruction_iterator);
  // Throw removed fused parameter instruction away.
  TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));

  // Renumber all parameters that followed the removed one: create a new
  // parameter with the shifted number, redirect all uses to it, then drop the
  // stale instruction.
  while (param_no < param_instructions_.size()) {
    param_instruction = param_instructions_[param_no];
    HloInstruction* new_instr =
        AddInstructionInternal(HloInstruction::CreateParameter(
            param_no, param_instruction->shape(), StrCat("param_", param_no)));
    TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
    param_instructions_[param_no] = new_instr;
    TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
    param_no++;
  }

  return Status::OK();
}
191
RemoveUnusedParametersFromFusedComputation()192 Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
193 return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
194 }
195
RemoveUnusedParametersFromAnyComputation()196 Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
197 return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
198 }
199
// Removes every parameter that has no users and is not the root, then
// compacts the survivors so their parameter numbers stay dense (each shifted
// parameter is recreated with its new number and spliced in via
// ReplaceAllUsesWith).
Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
  CHECK(allow_non_fusion || IsFusionComputation());
  int64 removed = 0;
  for (int64 i = 0; i < param_instructions_.size(); ++i) {
    HloInstruction* param_instruction = param_instructions_[i];
    if (param_instruction->user_count() == 0 &&
        param_instruction != root_instruction()) {
      // Unused: remove it now; the table itself is compacted at the end.
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
      ++removed;
      continue;
    }

    if (removed > 0) {
      // A parameter before this one was removed, so this one's number must
      // shift down by `removed`. Recreate it with the new number and move it
      // into the compacted slot.
      const int64 param_no = i - removed;
      HloInstruction* new_instr = AddInstructionInternal(
          HloInstruction::CreateParameter(param_no, param_instruction->shape(),
                                          StrCat("param_", param_no)));
      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
      param_instructions_[param_no] = new_instr;
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
    }
  }
  // The trailing `removed` slots are now stale; shrink the table.
  param_instructions_.resize(param_instructions_.size() - removed);
  return Status::OK();
}
227
IsSafelyRemovable(const HloInstruction * instruction)228 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
229 // If the instruction has control predecessors or successors then we cannot
230 // remove the instruction without violating ordering constraints (added, for
231 // example, to avert interference due to buffer aliasing).
232 if (!instruction->control_predecessors().empty() ||
233 !instruction->control_successors().empty()) {
234 return false;
235 }
236
237 if (instruction->opcode() == HloOpcode::kParameter &&
238 !IsFusionComputation()) {
239 return false;
240 }
241
242 return true;
243 }
244
HasSideEffect() const245 bool HloComputation::HasSideEffect() const {
246 for (auto* instruction : instructions()) {
247 if (instruction->HasSideEffect()) {
248 return true;
249 }
250 }
251 return false;
252 }
253
IsMarkedAsDead(const HloInstruction * inst)254 bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
255 return inst->IsMarkedAsDead();
256 }
257
// Removes `instruction` and, transitively, any operands that become dead as a
// result. `cleanup`, when provided, is invoked on each instruction just
// before it is removed. Side-effecting instructions other than `instruction`
// itself are never removed.
Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
  TF_RET_CHECK(root_instruction() != instruction);

  TF_RET_CHECK(instruction->user_count() == 0);
  TF_RET_CHECK(IsSafelyRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();
  absl::flat_hash_set<HloInstruction*> removed;
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);
  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();

    // Skip anything already removed, still used, the root, unsafe to remove,
    // or side-effecting (unless it is the instruction we were asked to
    // remove).
    if (removed.contains(item) || item->user_count() != 0 ||
        item == root_instruction() || !IsSafelyRemovable(item) ||
        (item->HasSideEffect() && item != instruction)) {
      continue;
    }
    // Operands may become dead once `item` is gone; queue them for a look.
    for (int i = 0; i < item->operand_count(); ++i) {
      worklist.push(item->mutable_operand(i));
    }

    if (cleanup) {
      cleanup(item);
    }
    TF_RETURN_IF_ERROR(RemoveInstruction(item));
    removed.insert(item);
  }
  return Status::OK();
}
289
RemoveInstruction(HloInstruction * instruction)290 Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
291 return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
292 }
293
ForceRemoveInstruction(HloInstruction * instruction)294 Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
295 return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
296 }
297
// Shared implementation for RemoveInstruction / ForceRemoveInstruction. The
// instruction must be non-root, user-free, and have no control edges. It is
// unlinked from the graph and marked dead; its storage is moved onto
// to_be_deleted_, so actual destruction is deferred rather than immediate.
Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
                                             bool ignore_safety_check) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(root_instruction() != instruction)
      << "cannot remove root instruction " << instruction->name();
  TF_RET_CHECK(instruction->user_count() == 0)
      << "instruction " << instruction->name()
      << " has users and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  auto inst_it = instruction_iterators_.find(instruction);
  TF_RET_CHECK(inst_it != instruction_iterators_.end());
  (*inst_it->second)->set_parent(nullptr);
  // Transfer ownership to the deferred-deletion list, then fully detach the
  // instruction from operands, users, and called computations.
  to_be_deleted_.emplace_back(inst_it->second->release());
  to_be_deleted_.back()->DetachFromOperandsAndUsers();
  // Clear all operands to avoid Null operands.
  to_be_deleted_.back()->RemoveAllOperands();
  to_be_deleted_.back()->ClearCalledComputations();
  to_be_deleted_.back()->MarkAsDead();
  instructions_.erase(inst_it->second);
  instruction_iterators_.erase(inst_it);
  return Status::OK();
}
329
// Changes this computation's root to `new_root_instruction`, which must
// already be a member of the computation. Unless `accept_different_shape` is
// set, non-fusion computations require the new root's shape to be compatible
// with the old one's.
void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
                                          bool accept_different_shape) {
  // The shape of the root (ignoring layout) is an invariant of the computation
  // for non-fusion cases.
  if (!IsFusionComputation() && !accept_different_shape) {
    CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
                                root_instruction_->shape()))
        << new_root_instruction->shape() << " is incompatible with "
        << root_instruction_->shape();
  }
  // Debug-build-only membership check: the new root must live here.
  bool root_found = false;
  for (auto& instruction : instructions_) {
    if (new_root_instruction == instruction.get()) {
      root_found = true;
      break;
    }
  }
  DCHECK(root_found);

  // For the entry computation, a root with a different (layout-ignored) shape
  // invalidates the module's aliasing metadata, which is keyed on the output
  // shape.
  if (parent() && parent()->has_entry_computation() &&
      parent()->entry_computation() == this) {
    if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
                                       root_instruction_->shape())) {
      // Rebuild input output alias config now that we have a new output shape.
      parent()->input_output_alias_config() =
          HloInputOutputAliasConfig(new_root_instruction->shape());
    }
  }

  root_instruction_ = new_root_instruction;
}
361
362 namespace {
363
364 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)365 void ComputeComputationPostOrder(HloComputation* computation,
366 absl::flat_hash_set<HloComputation*>* visited,
367 std::vector<HloComputation*>* post_order) {
368 if (visited->insert(computation).second) {
369 for (auto* instruction : computation->instructions()) {
370 for (HloComputation* called_computation :
371 instruction->called_computations()) {
372 ComputeComputationPostOrder(called_computation, visited, post_order);
373 }
374 }
375 post_order->push_back(computation);
376 }
377 }
378
379 } // namespace
380
// Iterative DFS that appends, in post order, all instructions reachable from
// `root` to `post_order`. Instructions sharing a channel id (per
// `channel_dependency_group`) are treated as a unit: reaching one member pulls
// in the predecessors of every member, keeping the group correctly ordered.
// `visited` carries the DFS coloring (kVisiting = on stack, kVisited =
// already emitted) and is shared across successive calls so repeated roots
// don't re-emit instructions.
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  std::vector<HloInstruction*> dfs_stack;
  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    const auto current = dfs_stack.back();
    CHECK_EQ(current->parent(), this)
        << "Instruction " << current->name()
        << " is not in the current computation (" << name() << ").";
    auto it = visited->find(current);
    if (it != visited->end()) {
      if (it->second == kVisited) {
        // Already visited.
        dfs_stack.pop_back();
        continue;
      }
      // Visit this node.
      CHECK_EQ(kVisiting, it->second);
      dfs_stack.pop_back();
      post_order->push_back(current);
      it->second = kVisited;
      continue;
    }

    visited->insert({current, kVisiting});

    // Channel id of `inst` for the opcodes that participate in channel
    // grouping here (kRecvDone, kAllReduce); nullopt for everything else.
    const auto get_channel_id =
        [](HloInstruction* inst) -> absl::optional<int64> {
      switch (inst->opcode()) {
        case HloOpcode::kRecvDone:
          return inst->channel_id();
        case HloOpcode::kAllReduce:
          return inst->channel_id();
        default:
          return absl::nullopt;
      }
    };

    // When adding a predecessor to the dfs_stack, we need to also add its
    // associated channel dependencies.
    const auto add_dfs_stack = [&](HloInstruction* inst) {
      auto channel_id = get_channel_id(inst);
      if (channel_id && channel_dependency_group.count(*channel_id)) {
        auto it = channel_dependency_group.find(*channel_id);
        for (HloInstruction* cinst : it->second) {
          dfs_stack.emplace_back(cinst);
        }
      } else {
        dfs_stack.emplace_back(inst);
      }
    };

    const auto add_predecessors = [&](HloInstruction* inst) {
      // Add the operands to the stack in reverse order so the first operand is
      // processed first. This will produce a more natural ordering and a nicer
      // result for things like HLO stringification.
      const auto& operands = inst->operands();
      for (int64 i = operands.size() - 1; i >= 0; --i) {
        add_dfs_stack(operands[i]);
      }

      for (HloInstruction* op : inst->control_predecessors()) {
        add_dfs_stack(op);
      }
    };

    // If the current instruction is a channel instruction, add the dependencies
    // from all associated instructions of the channel.
    auto channel_id = get_channel_id(current);
    if (channel_id && channel_dependency_group.count(*channel_id)) {
      auto it = channel_dependency_group.find(*channel_id);
      for (HloInstruction* cinst : it->second) {
        add_predecessors(cinst);
      }
    } else {
      add_predecessors(current);
    }
  }
}
462
463 HloComputation::ChannelDependencyGroup
ComputeChannelDependencies() const464 HloComputation::ComputeChannelDependencies() const {
465 ChannelDependencyGroup channel_dependency_group;
466 for (const auto& instruction : instructions_) {
467 switch (instruction->opcode()) {
468 case HloOpcode::kSend:
469 case HloOpcode::kRecvDone:
470 case HloOpcode::kAllReduce: {
471 auto channel_id = instruction->channel_id();
472 if (channel_id) {
473 channel_dependency_group[channel_id.value()].push_back(
474 instruction.get());
475 }
476 break;
477 }
478 default:
479 break;
480 }
481 }
482 return channel_dependency_group;
483 }
484
HasOnlyTraceUsers(const HloInstruction * instruction)485 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
486 return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
487 return user->opcode() == HloOpcode::kTrace;
488 });
489 }
490
// Returns every instruction of this computation in a valid post order
// (operands and control predecessors before users, with channel groups kept
// together). kTrace instructions are excluded from the DFS and appended at
// the very end; the DFS is seeded only from instructions whose users are all
// traces, i.e. the effective roots of the graph.
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
  auto channel_dependency_group = ComputeChannelDependencies();
  std::vector<HloInstruction*> post_order;
  post_order.reserve(instruction_count());
  std::vector<HloInstruction*> trace_instructions;
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  visited.reserve(instruction_count());
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kTrace) {
      // Trace instructions aren't handled by the DFS visitor. Add trace
      // instructions to the post order at the end (necessarily they have no
      // users).
      trace_instructions.push_back(instruction.get());
    } else if (HasOnlyTraceUsers(instruction.get())) {
      ComputeInstructionPostOrder(channel_dependency_group, &post_order,
                                  instruction.get(), &visited);
    }
  }
  post_order.insert(post_order.end(), trace_instructions.begin(),
                    trace_instructions.end());
  // Sanity check: every instruction must appear exactly once in the result.
  CHECK_EQ(instructions_.size(), post_order.size())
      << "number of instructions does not match post order size";
  return post_order;
}
515
MakeEmbeddedComputationsList() const516 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
517 const {
518 absl::flat_hash_set<HloComputation*> visited;
519 std::vector<HloComputation*> post_order;
520
521 // To avoid special handling of this computation, cast away const of
522 // 'this'. 'this' is immediately removed from the post order after
523 // construction.
524 //
525 // TODO(b/78350259): This violates const-correctness, since while the original
526 // computation is not returned, we still retrieve non-const computations from
527 // a const one. Consider also avoiding const for HloComputation, or review XLA
528 // for const-correctness of non-HloInstruction* types like this.
529 ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
530 &post_order);
531
532 // We don't want to include this computation in the post order.
533 CHECK_EQ(this, post_order.back());
534 post_order.pop_back();
535
536 return post_order;
537 }
538
ToString(const HloPrintOptions & options) const539 string HloComputation::ToString(const HloPrintOptions& options) const {
540 return ToString(options, MakeInstructionPostOrder());
541 }
542
// Stringifies this computation using a caller-supplied instruction order
// (which must contain exactly this computation's instructions). Depending on
// `options`, only a window of instructions around each "required" instruction
// is printed; elided runs are collapsed to "...".
string HloComputation::ToString(
    const HloPrintOptions& options,
    absl::Span<const HloInstruction* const> instruction_order) const {
  CHECK_EQ(instruction_order.size(), instruction_count());

  const string tab(2 * options.indent_amount(), ' ');

  std::ostringstream s;
  s << tab;

  if (!options.is_in_nested_computation()) {
    if (options.print_percent()) {
      s << "%";
    }
    if (options.print_ids()) {
      // Exclude entry computation's name because it includes and leads to
      // non-deterministic fingerprint.
      s << PrintName(name(), options.print_ids()) << " ";
    }
  }

  if (options.print_program_shape()) {
    s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
      << " ";
  }
  s << "{\n";

  // There are instructions which are required to be printed. Additionally, we
  // print some instructions before and after required ones. The resulting
  // output has the following format.
  //
  // computation {
  //   ...
  //   additional_instructions
  //   required_instructions
  //   additional_instructions
  //   ...
  //   additional_instructions
  //   required_instructions
  //   additional_instructions
  //   ...
  // }
  std::set<int> instructions_to_print;
  {
    // Find all the instructions that should be printed.
    // Inserts a single in-range index into instructions_to_print.
    auto add_instruction = [&instructions_to_print,
                           &instruction_order](int index) {
      if (index < 0 || index >= instruction_order.size()) {
        return;
      }
      instructions_to_print.insert(index);
    };

    // Inserts `index` plus the configured number of leading and trailing
    // context instructions on each side.
    auto add_instructions_arround = [&add_instruction, &options](int index) {
      for (int i = index - options.leading_and_trailing_instructions_number();
           i <= index + options.leading_and_trailing_instructions_number();
           ++i) {
        add_instruction(i);
      }
    };

    for (int i = 0; i < instruction_order.size(); ++i) {
      const HloInstruction* instruction = instruction_order[i];
      CHECK_EQ(this, instruction->parent());
      if (options.print_instruction(instruction)) {
        add_instructions_arround(i);
      }
    }
  }

  {
    // Print the instructions in this computation.
    HloPrintOptions new_options = options;
    new_options.set_indent_amount(options.indent_amount() + 1)
        .set_is_in_nested_computation(true);

    const string new_tab(2 * new_options.indent_amount(), ' ');

    CanonicalNameMap name_map;

    // print_prev tracks whether the previous index was printed, so each
    // elided run produces exactly one "..." line.
    bool print_prev = true;
    for (int index = 0; index < instruction_order.size(); ++index) {
      const HloInstruction* instruction = instruction_order[index];
      if (instructions_to_print.find(index) != instructions_to_print.end()) {
        s << new_options.format_instruction(
                 instruction,
                 instruction->ToStringWithCanonicalNameMap(new_options,
                                                           &name_map),
                 new_options.indent_amount(), instruction == root_instruction_)
          << "\n";
        print_prev = true;
      } else if (print_prev) {
        s << new_tab << "...\n";
        print_prev = false;
      }
    }
  }

  s << tab << "}";
  return s.str();
}
644
ToProto() const645 HloComputationProto HloComputation::ToProto() const {
646 HloComputationProto proto;
647 CHECK(unique_id_ != -1)
648 << "This computation does not have a valid id. Please make sure the "
649 "computation is inside a module before dumping it.";
650 proto.set_id(unique_id_);
651 proto.set_name(name_);
652 for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
653 HloInstructionProto instruction_proto = instruction->ToProto();
654 proto.add_instructions()->Swap(&instruction_proto);
655 }
656 proto.set_root_id(root_instruction()->unique_id());
657 *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
658 return proto;
659 }
660
// Deserializes an HloComputationProto into an HloComputation.
// `computation_map` supplies already-deserialized computations referenced by
// the instructions. Validates that the root id resolves and that parameter
// numbers form a dense, duplicate-free range.
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    const HloComputationProto& proto,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  absl::flat_hash_map<int64, HloInstruction*> instruction_map;
  absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  int64 parameter_count = 0;
  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    // Instructions appear in the proto with operands before users, so
    // instruction_map already contains everything this one refers to.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            instruction_proto, instruction_map, computation_map,
                            prohibit_empty_literal));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
    TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
    instruction_map[instruction_proto.id()] = instruction.get();
    to_proto_id[instruction.get()] = instruction_proto.id();
    instructions.push_back(std::move(instruction));
  }

  TF_RET_CHECK(proto.root_id() != -1);
  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
  HloInstruction* root = instruction_map.at(proto.root_id());

  // Sort the instructions in the proto id's order.
  absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
                                 const std::unique_ptr<HloInstruction>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Validate that the parameter numbers cover [0, parameter_count) exactly
  // once each before handing the instructions to the constructor.
  TF_RETURN_IF_ERROR([&]() -> Status {
    std::vector<bool> parameters_seen(parameter_count);
    int parameters_seen_count = 0;
    for (auto& instruction : instructions) {
      if (instruction->opcode() == HloOpcode::kParameter) {
        int64 param_no = instruction->parameter_number();
        TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
            << "Invalid parameter number. Expected [0, " << parameter_count
            << "), got " << param_no;
        TF_RET_CHECK(!parameters_seen[param_no])
            << "Parameter number " << param_no
            << " already allocated in this computation";
        parameters_seen[param_no] = true;
        parameters_seen_count++;
      }
    }
    TF_RET_CHECK(parameters_seen_count == parameter_count)
        << "Not all parameters in range [0, " << parameter_count
        << ") were referenced";
    return Status::OK();
  }());

  auto computation = absl::WrapUnique(
      new HloComputation(proto.name(), parameter_count, &instructions, root,
                         /*fusion_instruction=*/nullptr));
  computation->unique_id_ = proto.id();
  return std::move(computation);
}
722
// Moves `instructions_to_fuse` into `fusion_instruction`. The first element
// must be the root of the fused subgraph: its uses are redirected to the
// fusion instruction and it is removed first; the remaining instructions are
// fused in order and removed from this computation once they have no users
// left.
void HloComputation::FuseInstructionsInto(
    absl::Span<HloInstruction* const> instructions_to_fuse,
    HloInstruction* fusion_instruction) {
  CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
  HloInstruction* root = instructions_to_fuse.front();
  TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
  if (root == root_instruction()) {
    set_root_instruction(fusion_instruction);
  }
  TF_CHECK_OK(RemoveInstruction(root));
  for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
    HloInstruction* instruction = instructions_to_fuse[i];
    fusion_instruction->FuseInstruction(instruction);
    // Instructions still used outside the fusion (e.g. by other fused ops not
    // yet processed) are kept until their users disappear.
    if (instruction->user_count() == 0) {
      TF_CHECK_OK(RemoveInstruction(instruction));
    }
  }
}
741
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)742 HloInstruction* HloComputation::CreateFusionInstruction(
743 absl::Span<HloInstruction* const> instructions_to_fuse,
744 HloInstruction::FusionKind fusion_kind) {
745 HloInstruction* root = instructions_to_fuse.front();
746 HloInstruction* fusion_instruction = AddInstruction(
747 HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
748 FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
749 return fusion_instruction;
750 }
751
// Recursive helper for the DeepCopy* methods. Walks the shape tree of
// `instruction`: tuples are decomposed with get-tuple-element and rebuilt as
// a fresh tuple, tokens are passed through uncopied, and array leaves are
// delegated to `copy_leaf`. `index` tracks the current position within the
// overall shape (pushed/popped around each tuple element).
StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
    HloInstruction* instruction, ShapeIndex* index,
    const std::function<
        HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* computation)>& copy_leaf) {
  if (instruction->shape().IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         i++) {
      HloInstruction* gte =
          AddInstruction(HloInstruction::CreateGetTupleElement(
              ShapeUtil::GetTupleElementShape(instruction->shape(), i),
              instruction, i));

      index->push_back(i);
      TF_ASSIGN_OR_RETURN(HloInstruction * element,
                          DeepCopyHelper(gte, index, copy_leaf));
      elements.push_back(element);
      index->pop_back();
    }
    return AddInstruction(HloInstruction::CreateTuple(elements));
  }
  if (instruction->shape().IsToken()) {
    // Tokens have no on-device representation and cannot be copied. Pass
    // through transparently.
    return instruction;
  }

  // Array shape.
  TF_RET_CHECK(instruction->shape().IsArray());
  return copy_leaf(instruction, *index, this);
}
784
// Deep-copies `instruction` (which must belong to this computation) by
// inserting a kCopy for each array leaf of its shape. When `indices_to_copy`
// is given, only leaves marked true are copied; the rest pass through
// unchanged. When `copies_added` is given, the copy made at each leaf index
// is recorded there.
StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
    HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
    ShapeTree<HloInstruction*>* copies_added) {
  if (instruction->parent() != this) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: instruction is not in computation %s",
        instruction->name(), name());
  }
  if (indices_to_copy != nullptr &&
      !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: given shape tree of indices to copy "
        "has incompatible shapes: %s vs. %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanString(indices_to_copy->shape()));
  }

  ShapeIndex index;
  // Leaf handler for DeepCopyHelper: copy the leaf if its index is selected,
  // otherwise pass it through.
  auto copy_leaf = [indices_to_copy, copies_added](
                       HloInstruction* leaf, const ShapeIndex& leaf_index,
                       HloComputation* computation) {
    if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
      HloInstruction* copy = computation->AddInstruction(
          HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
      if (copies_added != nullptr) {
        *copies_added->mutable_element(leaf_index) = copy;
      }
      return copy;
    }
    // Elements which are not to be copied are passed through
    // transparently.
    return leaf;
  };
  return DeepCopyHelper(instruction, &index, copy_leaf);
}
820
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)821 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
822 HloInstruction* instruction,
823 const std::function<
824 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
825 HloComputation* computation)>& copy_leaf) {
826 if (instruction->parent() != this) {
827 return FailedPrecondition(
828 "Can't deep copy instruction %s: instruction is not in computation %s",
829 instruction->name(), name());
830 }
831 ShapeIndex index;
832 return DeepCopyHelper(instruction, &index, copy_leaf);
833 }
834
ComputeProgramShape(bool include_ids) const835 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
836 ProgramShape program_shape;
837
838 for (auto* param_instruction : param_instructions_) {
839 *program_shape.add_parameters() = param_instruction->shape();
840 *program_shape.add_parameter_names() =
841 PrintName(param_instruction->name(), include_ids);
842 }
843 *program_shape.mutable_result() = root_instruction_->shape();
844
845 return program_shape;
846 }
847
EqualInternal(const HloComputation & other,bool is_layout_sensitive,bool ignore_channel_id_values) const848 bool HloComputation::EqualInternal(const HloComputation& other,
849 bool is_layout_sensitive,
850 bool ignore_channel_id_values) const {
851 if (this == &other) {
852 return true;
853 }
854 absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
855 visited;
856 std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
857
858 worklist.push_back({root_instruction(), other.root_instruction()});
859
860 while (!worklist.empty()) {
861 auto pair = worklist.back();
862 worklist.pop_back();
863
864 if (visited.contains(pair)) {
865 continue;
866 }
867 visited.emplace(pair);
868 // TODO(b/123082518): Avoid recursively invoking Equal because it may
869 // cause a stack overflow with deeply nested subcomputations.
870 auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
871 return true;
872 };
873 auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
874 return a->EqualInternal(*b, is_layout_sensitive,
875 ignore_channel_id_values);
876 };
877 bool identical_ignoring_operands =
878 ignore_channel_id_values
879 ? pair.first->IdenticalIgnoringChannelIdValues(
880 *pair.second, operands_eq, comp_eq, is_layout_sensitive)
881 : pair.first->Identical(*pair.second, operands_eq, comp_eq,
882 is_layout_sensitive);
883 if (!identical_ignoring_operands) {
884 return false;
885 }
886 for (size_t i = 0; i < pair.first->operands().size(); ++i) {
887 worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
888 }
889 }
890 return true;
891 }
892
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)893 Status HloComputation::ReplaceWithNewInstruction(
894 HloInstruction* old_instruction,
895 std::unique_ptr<HloInstruction> new_instruction) {
896 return ReplaceInstruction(old_instruction,
897 AddInstruction(std::move(new_instruction)));
898 }
899
ReplaceWithNewEntryComputationParameter(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)900 Status HloComputation::ReplaceWithNewEntryComputationParameter(
901 HloInstruction* old_instruction,
902 std::unique_ptr<HloInstruction> new_instruction) {
903 return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
904 std::move(new_instruction)));
905 }
906
// Replaces all uses of old_instruction with new_instruction, then removes
// old_instruction (and any operands left dead by its removal). The two
// instructions must have compatible shapes.
Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                          HloInstruction* new_instruction) {
  TF_RET_CHECK(
      ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
      << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
      << ShapeUtil::HumanString(new_instruction->shape());

  VLOG(10) << "transformed " << old_instruction->ToString() << " to "
           << new_instruction->ToString();
  // Try to add metadata for HLO instructions that are created to replace
  // existing HLO instructions (e.g. during optimizations). The assumption is
  // that the old instruction and the new instruction would perform the same
  // function, and that they would be correlated to the same TF op. This might
  // not always be correct since HLO optimizations can cross TF op boundaries.
  // But still this seems to be better than nothing.
  bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
                           !old_instruction->metadata().op_name().empty();
  bool overwrite_pass_id =
      new_instruction->metadata().op_name().empty() &&
      new_instruction->metadata().logical_creation_pass_id() == 0 &&
      old_instruction->metadata().logical_creation_pass_id() != 0;
  if (overwrite_op_name || overwrite_pass_id) {
    // Copies the whole metadata proto, not just the op name / pass id fields.
    new_instruction->set_metadata(old_instruction->metadata());
  }
  // Frontend attributes are propagated in the same spirit: only when the new
  // instruction doesn't already carry its own.
  if (new_instruction->frontend_attributes().map().empty()) {
    new_instruction->set_frontend_attributes(
        old_instruction->frontend_attributes());
  }

  // Like the metadata above, if the user didn't specify any sharding
  // information on the new instruction we should copy the old sharding
  // information (if any).
  if (!new_instruction->has_sharding()) {
    new_instruction->set_sharding(old_instruction->sharding_ptr());
  }

  TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
  return RemoveInstructionAndUnusedOperands(old_instruction);
}
946
CollectUnreachableRoots() const947 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
948 std::vector<HloInstruction*> unreachable_roots;
949 for (auto* instruction : instructions()) {
950 if (instruction->user_count() == 0 &&
951 instruction->control_successors().empty() &&
952 instruction != root_instruction()) {
953 unreachable_roots.push_back(instruction);
954 }
955 }
956 VLOG(3) << "Unreachable roots:"
957 << absl::StrJoin(unreachable_roots, "\n\t",
958 [](string* out, const HloInstruction* hlo) {
959 absl::StrAppend(out, hlo->ToString());
960 });
961 return unreachable_roots;
962 }
963
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const964 Status HloComputation::AcceptWithOperandOrder(
965 DfsHloVisitor* visitor,
966 const HloInstruction::CompareFunction& operand_order) const {
967 // Visit unreachable roots. Beware that the visitor might delete the currently
968 // visited root, which would invalidate iterators if the unreachable roots
969 // weren't computed ahead of time.
970 for (HloInstruction* root : CollectUnreachableRoots()) {
971 TF_RETURN_IF_ERROR(
972 root->AcceptWithOperandOrder(visitor, operand_order,
973 /*call_finish_visit=*/false));
974 }
975 // Visit the computation root instruction last.
976 return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
977 /*call_finish_visit=*/true);
978 }
979
Clone(const string & suffix,HloCloneContext * context)980 std::unique_ptr<HloComputation> HloComputation::Clone(
981 const string& suffix, HloCloneContext* context) {
982 return CloneWithReplacements(
983 /*replacements=*/absl::flat_hash_map<const HloInstruction*,
984 std::unique_ptr<HloInstruction>>(),
985 /*extra_parameters=*/{}, context, suffix);
986 }
987
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)988 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
989 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
990 HloCloneContext* context, const string& suffix) {
991 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
992 replacements;
993 replacements.emplace(std::move(r1));
994 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
995 context, suffix);
996 }
997
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)998 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
999 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1000 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1001 HloCloneContext* context, const string& suffix) {
1002 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1003 replacements;
1004 replacements.emplace(std::move(r1));
1005 replacements.emplace(std::move(r2));
1006 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1007 context, suffix);
1008 }
1009
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)1010 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1011 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1012 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1013 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
1014 HloCloneContext* context, const string& suffix) {
1015 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1016 replacements;
1017 replacements.emplace(std::move(r1));
1018 replacements.emplace(std::move(r2));
1019 replacements.emplace(std::move(r3));
1020 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1021 context, suffix);
1022 }
1023
// Clones this computation into a new one named "<name>.<suffix>". Each key in
// `replacements` is substituted by its mapped value (a null value removes the
// instruction entirely); `extra_parameters` are prepended as additional
// parameter instructions; `new_root` (default: the current root) becomes the
// clone's root. The clone is registered in `context` (one is created locally
// if the caller passed none).
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const string& suffix,
    const HloInstruction* new_root) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }
  if (new_root == nullptr) {
    new_root = root_instruction();
  }

  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present in
  // the new computation.
  auto replace = [&](const HloInstruction* instr) {
    auto it = replacements.find(instr);
    return it != replacements.end() ? it->second.get() : instr;
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that will
  // generate a postorder of plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so let's just do it
  // ourselves.
  std::vector<const HloInstruction*> postorder;
  absl::flat_hash_map<const HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<const HloInstruction*> dfs_stack;
    const HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      // A null replacement means this instruction is dropped from the clone.
      continue;
    }
    dfs_stack.push_back(new_instr);

    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        // Second encounter while still kVisiting: all operands are done, so
        // the node itself can now be emitted in postorder.
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      // First encounter: mark in-progress and push (replaced) operands so
      // they are emitted before this node.
      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        const HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Clone each instruction in postorder; operands are guaranteed to have been
  // cloned already, so they can be looked up through the context.
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, context);
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parameter_replicated_at_leaf_buffers().has_value()) {
      // Preserve per-leaf-buffer replication annotations on parameters.
      new_instr->set_parameter_replicated_at_leaf_buffers(
          instr->parameter_replicated_at_leaf_buffers().value());
    }
    instructions.push_back(std::move(new_instr));
  }
  Builder builder(name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(replace(new_root)));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }
  context->MapComputation(this, result.get());
  return result;
}
1141
// Renames this computation so its name is unique within the scope tracked by
// `name_uniquer`.
void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}
1145
GetInstructionWithName(absl::string_view name)1146 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1147 auto instructions_in_computation = instructions();
1148 auto it = absl::c_find_if(
1149 instructions_in_computation,
1150 [&](HloInstruction* instr) { return instr->name() == name; });
1151 return it == instructions_in_computation.end() ? nullptr : *it;
1152 }
1153
// Returns true iff this computation is the entry computation of its parent
// module.
bool HloComputation::IsEntryComputation() const {
  return parent()->entry_computation() == this;
}
1157 } // namespace xla
1158