1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46
47 namespace xla {
48
49 using absl::StrCat;
50
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52 HloInstruction* root_instruction) {
53 int parameter_count = 0;
54 for (auto& instruction : instructions_) {
55 if (instruction->opcode() == HloOpcode::kParameter) {
56 parameter_count++;
57 }
58 }
59 // If root_instruction is not specified use the last added instruction.
60 HloInstruction* root =
61 root_instruction ? root_instruction : last_added_instruction_;
62 CHECK_NE(nullptr, root);
63 return absl::WrapUnique(new HloComputation(
64 name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66
HloComputation::HloComputation(
    const string& name, int parameter_count,
    std::vector<std::unique_ptr<HloInstruction>>* instructions,
    HloInstruction* root_instruction, HloInstruction* fusion_instruction)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(-1),  // A valid id is assigned once the computation joins a
                       // module (see CHECK in ToProto).
      root_instruction_(root_instruction),
      fusion_instruction_(fusion_instruction),
      is_fusion_computation_(fusion_instruction != nullptr),
      custom_call_instruction_(nullptr),
      is_custom_call_computation_(false) {
  // Take ownership of every instruction, slotting parameters into
  // param_instructions_ by parameter number and verifying that the designated
  // root is among the instructions.
  param_instructions_.resize(parameter_count, nullptr);
  bool root_found = false;
  for (auto& instruction : *instructions) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      int64_t param_no = instruction->parameter_number();
      CHECK(param_no >= 0 && param_no < parameter_count)
          << "\nERROR: invalid parameter number. Expected [0, "
          << parameter_count << "), got " << param_no;
      // Parameter numbers must be unique within the computation.
      CHECK(param_instructions_[param_no] == nullptr)
          << "\nERROR: parameter number " << param_no
          << " already allocated in this computation";
      param_instructions_[param_no] = instruction.get();
    }
    root_found |= instruction.get() == root_instruction_;
    AddInstructionInternal(std::move(instruction));
  }
  CHECK(root_found)
      << "\nERROR: root instruction is not present in computation.";
}
97
~HloComputation()98 HloComputation::~HloComputation() {
99 if (fusion_instruction_ != nullptr) {
100 CHECK(fusion_instruction_->fused_instructions_computation() == this);
101 fusion_instruction_->ClearCalledComputations();
102 fusion_instruction_ = nullptr;
103 }
104 }
105
HloInstruction* HloComputation::AddInstruction(
    std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
  // Appends `instruction` to this computation, optionally renaming it first.
  // Returns the raw pointer to the (now owned) instruction.
  CHECK(instruction->opcode() != HloOpcode::kParameter)
      << "Parameter instructions cannot be added to a computation after "
      << "it has been built";
  if (!new_name.empty()) {
    // Sanitize caller-provided names so they are valid HLO identifiers.
    instruction->SetAndSanitizeName(new_name);
  }
  return AddInstructionInternal(std::move(instruction));
}
116
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)117 HloInstruction* HloComputation::AddInstructionInternal(
118 std::unique_ptr<HloInstruction> instruction) {
119 if (parent() != nullptr) {
120 instruction->UniquifyName(&parent()->instruction_name_uniquer());
121 instruction->SetUniqueId(parent()->NewUniqueInstructionId());
122 }
123 instruction->set_parent(this);
124 HloInstruction* pinst = instruction.get();
125 instruction_iterators_[pinst] =
126 instructions_.insert(instructions_.end(), std::move(instruction));
127 return pinst;
128 }
129
AddParameter(std::unique_ptr<HloInstruction> instruction)130 HloInstruction* HloComputation::AddParameter(
131 std::unique_ptr<HloInstruction> instruction) {
132 CHECK(instruction->opcode() == HloOpcode::kParameter);
133 CHECK(IsFusionComputation());
134 CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
135 instruction->set_parent(this);
136 param_instructions_.push_back(instruction.get());
137 AddInstructionInternal(std::move(instruction));
138 return instructions_.back().get();
139 }
140
AddEntryComputationParameter(std::unique_ptr<HloInstruction> instruction)141 HloInstruction* HloComputation::AddEntryComputationParameter(
142 std::unique_ptr<HloInstruction> instruction) {
143 CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
144 CHECK_EQ(instruction->parameter_number(), num_parameters());
145 CHECK(parent()->entry_computation() == this);
146
147 HloModuleConfig config = parent()->config();
148 config.mutable_entry_computation_layout()->add_parameter_layout(
149 ShapeLayout(instruction->shape()));
150 parent()->set_config(config);
151
152 instruction->set_parent(this);
153 param_instructions_.push_back(instruction.get());
154 AddInstructionInternal(std::move(instruction));
155
156 return instructions_.back().get();
157 }
158
Status HloComputation::ReplaceEntryComputationParameter(
    int64_t param_no, HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> instruction) {
  // Swaps an existing entry-computation parameter for a new one with the same
  // parameter number, updating the entry layout to the new parameter's shape.
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK(parent()->entry_computation() == this);

  // Rewrite this parameter's layout slot in the module config.
  HloModuleConfig config = parent()->config();
  *config.mutable_entry_computation_layout()->mutable_parameter_layout(
      param_no) = ShapeLayout(instruction->shape());
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_[param_no] = instruction.get();
  AddInstructionInternal(std::move(instruction));

  // Force-remove: parameters are normally not safely removable outside
  // fusion computations (see IsSafelyRemovable).
  return ForceRemoveInstruction(old_instruction);
}
178
Status HloComputation::RemoveParameter(int64_t param_no) {
  // Removes parameter `param_no` from a fusion computation and renumbers all
  // subsequent parameters downward to keep parameter numbers dense.
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(IsFusionComputation());
  HloInstruction* param_instruction = param_instructions_[param_no];
  auto param_instruction_iterator = param_instructions_.begin() + param_no;
  param_instructions_.erase(param_instruction_iterator);
  // Throw removed fused parameter instruction away.
  TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));

  // Each parameter after the removed one has shifted down a slot, so it is
  // replaced by a freshly created parameter instruction carrying the new
  // number; all uses are redirected before the stale one is removed.
  while (param_no < param_instructions_.size()) {
    param_instruction = param_instructions_[param_no];
    HloInstruction* new_instr =
        AddInstructionInternal(HloInstruction::CreateParameter(
            param_no, param_instruction->shape(), StrCat("param_", param_no)));
    TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
    param_instructions_[param_no] = new_instr;
    TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
    param_no++;
  }

  return Status::OK();
}
202
// Removes unused parameters; CHECK-fails if this is not a fusion computation.
Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
}
206
// Removes unused parameters from any computation, fusion or not.
Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
}
210
Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
  CHECK(allow_non_fusion || IsFusionComputation());
  // Number of parameters removed so far. Surviving parameters after a removal
  // must be renumbered downward by this amount to keep numbering dense.
  int64_t removed = 0;
  for (int64_t i = 0; i < param_instructions_.size(); ++i) {
    HloInstruction* param_instruction = param_instructions_[i];
    if (param_instruction->user_count() == 0 &&
        param_instruction != root_instruction()) {
      // Unused and not the root: drop it.
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
      ++removed;
      continue;
    }

    if (removed > 0) {
      // This parameter survives but its number has shifted down; re-create it
      // with the new number and redirect all uses to the replacement.
      const int64_t param_no = i - removed;
      HloInstruction* new_instr = AddInstructionInternal(
          HloInstruction::CreateParameter(param_no, param_instruction->shape(),
                                          StrCat("param_", param_no)));
      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
      param_instructions_[param_no] = new_instr;
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
    }
  }
  // Trailing slots now hold stale pointers; shrink to the surviving count.
  param_instructions_.resize(param_instructions_.size() - removed);
  return Status::OK();
}
238
IsSafelyRemovable(const HloInstruction * instruction)239 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
240 // If the instruction has control predecessors or successors then we cannot
241 // remove the instruction without violating ordering constraints (added, for
242 // example, to avert interference due to buffer aliasing).
243 if (!instruction->control_predecessors().empty() ||
244 !instruction->control_successors().empty()) {
245 return false;
246 }
247
248 if (instruction->opcode() == HloOpcode::kParameter &&
249 !IsFusionComputation()) {
250 return false;
251 }
252
253 return true;
254 }
255
HasSideEffect() const256 bool HloComputation::HasSideEffect() const {
257 for (auto* instruction : instructions()) {
258 if (instruction->HasSideEffect()) {
259 return true;
260 }
261 }
262 return false;
263 }
264
// Returns true if `inst` has been removed and marked dead (pending deferred
// deletion; see RemoveInstructionImpl).
bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
  return inst->IsMarkedAsDead();
}
268
Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
  TF_RET_CHECK(root_instruction() != instruction);

  TF_RET_CHECK(instruction->user_count() == 0);
  TF_RET_CHECK(IsSafelyRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();
  // BFS from `instruction` through its operands, removing every instruction
  // that becomes dead along the way.
  absl::flat_hash_set<HloInstruction*> removed;
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);
  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();

    // Skip anything already removed, still used, the root, unsafe to remove,
    // or side-effecting — except `instruction` itself, which the caller
    // explicitly asked to remove even if side-effecting.
    if (removed.contains(item) || item->user_count() != 0 ||
        item == root_instruction() || !IsSafelyRemovable(item) ||
        (item->HasSideEffect() && item != instruction)) {
      continue;
    }
    // Operands may become dead once `item` is gone; consider them next.
    for (int i = 0; i < item->operand_count(); ++i) {
      worklist.push(item->mutable_operand(i));
    }

    // Let the caller observe the instruction before it is removed.
    if (cleanup) {
      cleanup(item);
    }
    TF_RETURN_IF_ERROR(RemoveInstruction(item));
    removed.insert(item);
  }
  return Status::OK();
}
300
// Removes `instruction` after verifying it is safe to do so.
Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
}
304
// Removes `instruction`, bypassing the IsSafelyRemovable() check. The other
// preconditions (no users, not the root, no control deps) still apply.
Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
}
308
Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
                                             bool ignore_safety_check) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(root_instruction() != instruction)
      << "cannot remove root instruction " << instruction->name();
  TF_RET_CHECK(instruction->user_count() == 0)
      << "instruction " << instruction->name()
      << " has users and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  // Detach the instruction from the graph and transfer ownership to
  // to_be_deleted_, marking it dead. Destruction is deferred — presumably so
  // outstanding pointers remain valid until the deferred list is flushed;
  // confirm against the to_be_deleted_ consumer in the header.
  auto inst_it = instruction_iterators_.find(instruction);
  TF_RET_CHECK(inst_it != instruction_iterators_.end());
  (*inst_it->second)->set_parent(nullptr);
  to_be_deleted_.emplace_back(inst_it->second->release());
  to_be_deleted_.back()->DetachFromOperandsAndUsers();
  // Clear all operands to avoid Null operands.
  to_be_deleted_.back()->RemoveAllOperands();
  to_be_deleted_.back()->ClearCalledComputations();
  to_be_deleted_.back()->MarkAsDead();
  instructions_.erase(inst_it->second);
  instruction_iterators_.erase(inst_it);
  return Status::OK();
}
340
set_root_instruction(HloInstruction * new_root_instruction,bool accept_different_shape)341 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
342 bool accept_different_shape) {
343 // The shape of the root (ignoring layout) is an invariant of the computation
344 // for non-fusion cases.
345 if (!IsFusionComputation() && !accept_different_shape) {
346 CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
347 root_instruction_->shape()))
348 << new_root_instruction->shape() << " is incompatible with "
349 << root_instruction_->shape();
350 }
351 bool root_found = false;
352 for (auto& instruction : instructions_) {
353 if (new_root_instruction == instruction.get()) {
354 root_found = true;
355 break;
356 }
357 }
358 DCHECK(root_found);
359
360 if (parent() && parent()->has_entry_computation() &&
361 parent()->entry_computation() == this) {
362 if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
363 root_instruction_->shape())) {
364 // Rebuild input output alias config now that we have a new output shape.
365 parent()->input_output_alias_config() =
366 HloInputOutputAliasConfig(new_root_instruction->shape());
367 }
368 }
369
370 root_instruction_ = new_root_instruction;
371 }
372
373 namespace {
374
375 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)376 void ComputeComputationPostOrder(HloComputation* computation,
377 absl::flat_hash_set<HloComputation*>* visited,
378 std::vector<HloComputation*>* post_order) {
379 if (visited->insert(computation).second) {
380 for (auto* instruction : computation->instructions()) {
381 for (HloComputation* called_computation :
382 instruction->called_computations()) {
383 ComputeComputationPostOrder(called_computation, visited, post_order);
384 }
385 }
386 post_order->push_back(computation);
387 }
388 }
389
390 } // namespace
391
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  // Iterative DFS from `root`, appending instructions to `post_order` once
  // all of their predecessors (operands, control predecessors, and
  // same-channel peers) have been emitted. `visited` distinguishes nodes
  // currently on the stack (kVisiting) from emitted ones (kVisited) and is
  // shared across calls for different roots.
  std::vector<HloInstruction*> dfs_stack;
  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    const auto current = dfs_stack.back();
    CHECK_EQ(current->parent(), this)
        << "Instruction " << current->name()
        << " is not in the current computation (" << name() << ").";
    auto it = visited->find(current);
    if (it != visited->end()) {
      if (it->second == kVisited) {
        // Already visited.
        dfs_stack.pop_back();
        continue;
      }
      // Visit this node.
      CHECK_EQ(kVisiting, it->second);
      dfs_stack.pop_back();
      post_order->push_back(current);
      it->second = kVisited;
      continue;
    }

    visited->insert({current, kVisiting});

    // Returns the channel id for opcodes that participate in channel
    // dependency grouping (see ComputeChannelDependencies), else nullopt.
    const auto get_channel_id =
        [](HloInstruction* inst) -> absl::optional<int64> {
      switch (inst->opcode()) {
        case HloOpcode::kRecvDone:
        case HloOpcode::kAllReduce:
        case HloOpcode::kAllGather:
        case HloOpcode::kAllToAll:
        case HloOpcode::kReduceScatter:
          return inst->channel_id();
        default:
          return absl::nullopt;
      }
    };

    // When adding a predecessor to the dfs_stack, we need to also add its
    // associated channel dependencies.
    const auto add_dfs_stack = [&](HloInstruction* inst) {
      auto channel_id = get_channel_id(inst);
      if (channel_id && channel_dependency_group.count(*channel_id)) {
        auto it = channel_dependency_group.find(*channel_id);
        for (HloInstruction* cinst : it->second) {
          dfs_stack.emplace_back(cinst);
        }
      } else {
        dfs_stack.emplace_back(inst);
      }
    };

    const auto add_predecessors = [&](HloInstruction* inst) {
      // Add the operands to the stack in reverse order so the first operand is
      // processed first. This will produce a more natural ordering and a nicer
      // result for things like HLO stringification.
      const auto& operands = inst->operands();
      for (int64_t i = operands.size() - 1; i >= 0; --i) {
        add_dfs_stack(operands[i]);
      }

      for (HloInstruction* op : inst->control_predecessors()) {
        add_dfs_stack(op);
      }
    };

    // If the current instruction is a channel instruction, add the dependencies
    // from all associated instructions of the channel.
    auto channel_id = get_channel_id(current);
    if (channel_id && channel_dependency_group.count(*channel_id)) {
      auto it = channel_dependency_group.find(*channel_id);
      for (HloInstruction* cinst : it->second) {
        add_predecessors(cinst);
      }
    } else {
      add_predecessors(current);
    }
  }
}
475
// Groups channel instructions (send/recv-done and channel-carrying
// collectives) by channel id so post-ordering can treat instructions sharing
// a channel as mutually dependent.
HloComputation::ChannelDependencyGroup
HloComputation::ComputeChannelDependencies() const {
  ChannelDependencyGroup channel_dependency_group;
  // NOTE(review): grouping is skipped for single-computation device
  // assignments and SPMD-partitioned modules — presumably channel ordering
  // constraints do not apply there; confirm against callers.
  if (parent() && parent()->config().has_static_device_assignment() &&
      (parent()->config().static_device_assignment().computation_count() == 1 ||
       parent()->config().use_spmd_partitioning())) {
    return channel_dependency_group;
  }
  for (const auto& instruction : instructions_) {
    switch (instruction->opcode()) {
      case HloOpcode::kSend:
      case HloOpcode::kRecvDone:
      case HloOpcode::kAllReduce:
      case HloOpcode::kAllGather:
      case HloOpcode::kAllToAll:
      case HloOpcode::kReduceScatter: {
        // Instructions without a channel id are not grouped.
        auto channel_id = instruction->channel_id();
        if (channel_id) {
          channel_dependency_group[channel_id.value()].push_back(
              instruction.get());
        }
        break;
      }
      default:
        break;
    }
  }
  return channel_dependency_group;
}
505
HasOnlyTraceUsers(const HloInstruction * instruction)506 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
507 return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
508 return user->opcode() == HloOpcode::kTrace;
509 });
510 }
511
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
  auto channel_dependency_group = ComputeChannelDependencies();
  std::vector<HloInstruction*> post_order;
  post_order.reserve(instruction_count());
  std::vector<HloInstruction*> trace_instructions;
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  visited.reserve(instruction_count());
  // Run a DFS from every instruction whose users are all traces (i.e. the
  // effective roots). `visited` is shared across roots so each instruction is
  // emitted exactly once.
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kTrace) {
      // Trace instructions aren't handled by the DFS visitor. Add trace
      // instructions to the post order at the end (necessarily they have no
      // users).
      trace_instructions.push_back(instruction.get());
    } else if (HasOnlyTraceUsers(instruction.get())) {
      ComputeInstructionPostOrder(channel_dependency_group, &post_order,
                                  instruction.get(), &visited);
    }
  }
  post_order.insert(post_order.end(), trace_instructions.begin(),
                    trace_instructions.end());
  // Every instruction must appear exactly once in the result.
  CHECK_EQ(instructions_.size(), post_order.size())
      << "number of instructions does not match post order size";
  return post_order;
}
536
// Returns a post order of all computations transitively called from this one,
// excluding this computation itself.
std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
    const {
  absl::flat_hash_set<HloComputation*> visited;
  std::vector<HloComputation*> post_order;

  // To avoid special handling of this computation, cast away const of
  // 'this'. 'this' is immediately removed from the post order after
  // construction.
  //
  // TODO(b/78350259): This violates const-correctness, since while the original
  // computation is not returned, we still retrieve non-const computations from
  // a const one. Consider also avoiding const for HloComputation, or review XLA
  // for const-correctness of non-HloInstruction* types like this.
  ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
                              &post_order);

  // We don't want to include this computation in the post order.
  // Post order guarantees 'this' (the starting root) is last.
  CHECK_EQ(this, post_order.back());
  post_order.pop_back();

  return post_order;
}
559
// Prints the computation using a fresh post order of its instructions.
string HloComputation::ToString(const HloPrintOptions& options) const {
  return ToString(options, MakeInstructionPostOrder());
}
563
string HloComputation::ToString(
    const HloPrintOptions& options,
    absl::Span<const HloInstruction* const> instruction_order) const {
  // `instruction_order` must contain exactly this computation's instructions
  // (typically a post order).
  CHECK_EQ(instruction_order.size(), instruction_count());

  const string tab(2 * options.indent_amount(), ' ');

  std::ostringstream s;
  s << tab;

  if (!options.is_in_nested_computation()) {
    if (options.print_percent()) {
      s << "%";
    }
    if (options.print_ids()) {
      // Exclude entry computation's name because it includes and leads to
      // non-deterministic fingerprint.
      s << PrintName(name(), options.print_ids()) << " ";
    }
  }

  if (options.print_program_shape()) {
    s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
      << " ";
  }
  s << "{\n";

  // There are instructions which are required to be printed. Additionally, we
  // print some instructions before and after required ones. The resulting
  // output has the following format.
  //
  // computation {
  //   ...
  //   additional_instructions
  //   required_instructions
  //   additional_instructions
  //   ...
  //   additional_instructions
  //   required_instructions
  //   additional_instructions
  //   ...
  // }
  std::set<int> instructions_to_print;
  {
    // Find all the instructions that should be printed.
    auto add_instruction = [&instructions_to_print,
                           &instruction_order](int index) {
      // Out-of-range context indices are silently ignored.
      if (index < 0 || index >= instruction_order.size()) {
        return;
      }
      instructions_to_print.insert(index);
    };

    // Adds `index` plus the configured number of surrounding context
    // instructions on each side.
    auto add_instructions_arround = [&add_instruction, &options](int index) {
      for (int i = index - options.leading_and_trailing_instructions_number();
           i <= index + options.leading_and_trailing_instructions_number();
           ++i) {
        add_instruction(i);
      }
    };

    for (int i = 0; i < instruction_order.size(); ++i) {
      const HloInstruction* instruction = instruction_order[i];
      CHECK_EQ(this, instruction->parent());
      if (options.print_instruction(instruction)) {
        add_instructions_arround(i);
      }
    }
  }

  {
    // Print the instructions in this computation.
    HloPrintOptions new_options = options;
    new_options.set_indent_amount(options.indent_amount() + 1)
        .set_is_in_nested_computation(true);

    const string new_tab(2 * new_options.indent_amount(), ' ');

    CanonicalNameMap name_map;

    // print_prev tracks whether the previous index was printed, so each run
    // of elided instructions collapses into a single "..." line.
    bool print_prev = true;
    for (int index = 0; index < instruction_order.size(); ++index) {
      const HloInstruction* instruction = instruction_order[index];
      if (instructions_to_print.find(index) != instructions_to_print.end()) {
        s << new_options.format_instruction(
                 instruction,
                 instruction->ToStringWithCanonicalNameMap(new_options,
                                                           &name_map),
                 new_options.indent_amount(), instruction == root_instruction_)
          << "\n";
        print_prev = true;
      } else if (print_prev) {
        s << new_tab << "...\n";
        print_prev = false;
      }
    }
  }

  s << tab << "}";
  return s.str();
}
665
ToProto() const666 HloComputationProto HloComputation::ToProto() const {
667 HloComputationProto proto;
668 CHECK(unique_id_ != -1)
669 << "This computation does not have a valid id. Please make sure the "
670 "computation is inside a module before dumping it.";
671 proto.set_id(unique_id_);
672 proto.set_name(name_);
673 for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
674 HloInstructionProto instruction_proto = instruction->ToProto();
675 proto.add_instructions()->Swap(&instruction_proto);
676 }
677 proto.set_root_id(root_instruction()->unique_id());
678 *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
679 return proto;
680 }
681
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    const HloComputationProto& proto,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  // Deserialize each instruction, keyed by proto id so later instructions can
  // resolve references to earlier ones via `instruction_map`.
  absl::flat_hash_map<int64, HloInstruction*> instruction_map;
  absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  int64_t parameter_count = 0;
  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            instruction_proto, instruction_map, computation_map,
                            prohibit_empty_literal));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
    // Proto ids must be unique within a computation.
    TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
    instruction_map[instruction_proto.id()] = instruction.get();
    to_proto_id[instruction.get()] = instruction_proto.id();
    instructions.push_back(std::move(instruction));
  }

  TF_RET_CHECK(proto.root_id() != -1);
  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
  HloInstruction* root = instruction_map.at(proto.root_id());

  // Sort the instructions in the proto id's order.
  absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
                                 const std::unique_ptr<HloInstruction>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Validate that parameter numbers form the dense range
  // [0, parameter_count) with no duplicates.
  TF_RETURN_IF_ERROR([&]() -> Status {
    std::vector<bool> parameters_seen(parameter_count);
    int parameters_seen_count = 0;
    for (auto& instruction : instructions) {
      if (instruction->opcode() == HloOpcode::kParameter) {
        int64_t param_no = instruction->parameter_number();
        TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
            << "Invalid parameter number. Expected [0, " << parameter_count
            << "), got " << param_no;
        TF_RET_CHECK(!parameters_seen[param_no])
            << "Parameter number " << param_no
            << " already allocated in this computation";
        parameters_seen[param_no] = true;
        parameters_seen_count++;
      }
    }
    TF_RET_CHECK(parameters_seen_count == parameter_count)
        << "Not all parameters in range [0, " << parameter_count
        << ") were referenced";
    return Status::OK();
  }());

  auto computation = absl::WrapUnique(
      new HloComputation(proto.name(), parameter_count, &instructions, root,
                         /*fusion_instruction=*/nullptr));
  // Restore the serialized unique id rather than allocating a fresh one.
  computation->unique_id_ = proto.id();
  return std::move(computation);
}
743
FuseInstructionsInto(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction * fusion_instruction)744 void HloComputation::FuseInstructionsInto(
745 absl::Span<HloInstruction* const> instructions_to_fuse,
746 HloInstruction* fusion_instruction) {
747 CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
748 HloInstruction* root = instructions_to_fuse.front();
749 TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
750 if (root == root_instruction()) {
751 set_root_instruction(fusion_instruction);
752 }
753 TF_CHECK_OK(RemoveInstruction(root));
754 for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
755 HloInstruction* instruction = instructions_to_fuse[i];
756 fusion_instruction->FuseInstruction(instruction);
757 if (instruction->user_count() == 0) {
758 TF_CHECK_OK(RemoveInstruction(instruction));
759 }
760 }
761 }
762
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)763 HloInstruction* HloComputation::CreateFusionInstruction(
764 absl::Span<HloInstruction* const> instructions_to_fuse,
765 HloInstruction::FusionKind fusion_kind) {
766 HloInstruction* root = instructions_to_fuse.front();
767 HloInstruction* fusion_instruction = AddInstruction(
768 HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
769 FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
770 return fusion_instruction;
771 }
772
StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
    HloInstruction* instruction, ShapeIndex* index,
    const std::function<
        HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* computation)>& copy_leaf) {
  // Recursively copies `instruction` elementwise: tuples are decomposed via
  // get-tuple-element, each element copied, then reassembled; non-tuple
  // leaves are delegated to `copy_leaf`. `index` tracks the current position
  // within the overall shape and is restored before returning.
  if (instruction->shape().IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         i++) {
      HloInstruction* gte =
          AddInstruction(HloInstruction::CreateGetTupleElement(
              ShapeUtil::GetTupleElementShape(instruction->shape(), i),
              instruction, i));

      index->push_back(i);
      TF_ASSIGN_OR_RETURN(HloInstruction * element,
                          DeepCopyHelper(gte, index, copy_leaf));
      elements.push_back(element);
      index->pop_back();
    }
    return AddInstruction(HloInstruction::CreateTuple(elements));
  }
  if (instruction->shape().IsToken()) {
    // Tokens have no on-device representation and cannot be copied. Pass
    // through transparently.
    return instruction;
  }

  // Array shape.
  TF_RET_CHECK(instruction->shape().IsArray());
  return copy_leaf(instruction, *index, this);
}
805
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)806 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
807 HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
808 ShapeTree<HloInstruction*>* copies_added) {
809 if (instruction->parent() != this) {
810 return FailedPrecondition(
811 "Can't deep copy instruction %s: instruction is not in computation %s",
812 instruction->name(), name());
813 }
814 if (indices_to_copy != nullptr &&
815 !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
816 return FailedPrecondition(
817 "Can't deep copy instruction %s: given shape tree of indices to copy "
818 "has incompatible shapes: %s vs. %s",
819 instruction->name(), ShapeUtil::HumanString(instruction->shape()),
820 ShapeUtil::HumanString(indices_to_copy->shape()));
821 }
822
823 ShapeIndex index;
824 auto copy_leaf = [indices_to_copy, copies_added](
825 HloInstruction* leaf, const ShapeIndex& leaf_index,
826 HloComputation* computation) {
827 if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
828 HloInstruction* copy = computation->AddInstruction(
829 HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
830 if (copies_added != nullptr) {
831 *copies_added->mutable_element(leaf_index) = copy;
832 }
833 return copy;
834 }
835 // Elements which are not to be copied are passed through
836 // transparently.
837 return leaf;
838 };
839 return DeepCopyHelper(instruction, &index, copy_leaf);
840 }
841
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)842 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
843 HloInstruction* instruction,
844 const std::function<
845 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
846 HloComputation* computation)>& copy_leaf) {
847 if (instruction->parent() != this) {
848 return FailedPrecondition(
849 "Can't deep copy instruction %s: instruction is not in computation %s",
850 instruction->name(), name());
851 }
852 ShapeIndex index;
853 return DeepCopyHelper(instruction, &index, copy_leaf);
854 }
855
ComputeProgramShape(bool include_ids) const856 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
857 ProgramShape program_shape;
858
859 for (auto* param_instruction : param_instructions_) {
860 *program_shape.add_parameters() = param_instruction->shape();
861 *program_shape.add_parameter_names() =
862 PrintName(param_instruction->name(), include_ids);
863 }
864 *program_shape.mutable_result() = root_instruction_->shape();
865
866 return program_shape;
867 }
868
// Structural equality between two computations, implemented as an explicit
// worklist walk over (this-side, other-side) instruction pairs starting from
// the two roots. Each pair is compared ignoring operands; the operands are
// instead enqueued as their own pairs, so the walk covers the whole graph
// without deep recursion over instructions. Called subcomputations are still
// compared recursively (see TODO below).
bool HloComputation::EqualInternal(const HloComputation& other,
                                   bool is_layout_sensitive,
                                   bool ignore_channel_id_values) const {
  // A computation is trivially equal to itself.
  if (this == &other) {
    return true;
  }
  // Pairs already compared; prevents re-visiting shared subgraphs.
  absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
      visited;
  std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;

  worklist.push_back({root_instruction(), other.root_instruction()});

  while (!worklist.empty()) {
    auto pair = worklist.back();
    worklist.pop_back();

    if (visited.contains(pair)) {
      continue;
    }
    visited.emplace(pair);
    // TODO(b/123082518): Avoid recursively invoking Equal because it may
    // cause a stack overflow with deeply nested subcomputations.
    // Operand comparison is deferred: report "equal" here and let the
    // worklist compare each corresponding operand pair explicitly below.
    auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
      return true;
    };
    // Subcomputations are compared with the same sensitivity settings.
    auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
      return a->EqualInternal(*b, is_layout_sensitive,
                              ignore_channel_id_values);
    };
    bool identical_ignoring_operands =
        ignore_channel_id_values
            ? pair.first->IdenticalIgnoringChannelIdValues(
                  *pair.second, operands_eq, comp_eq, is_layout_sensitive)
            : pair.first->Identical(*pair.second, operands_eq, comp_eq,
                                    is_layout_sensitive);
    if (!identical_ignoring_operands) {
      return false;
    }
    // Enqueue corresponding operand pairs for comparison.
    for (size_t i = 0; i < pair.first->operands().size(); ++i) {
      worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
    }
  }
  return true;
}
913
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)914 Status HloComputation::ReplaceWithNewInstruction(
915 HloInstruction* old_instruction,
916 std::unique_ptr<HloInstruction> new_instruction) {
917 return ReplaceInstruction(old_instruction,
918 AddInstruction(std::move(new_instruction)));
919 }
920
ReplaceWithNewEntryComputationParameter(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)921 Status HloComputation::ReplaceWithNewEntryComputationParameter(
922 HloInstruction* old_instruction,
923 std::unique_ptr<HloInstruction> new_instruction) {
924 return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
925 std::move(new_instruction)));
926 }
927
// Replaces `old_instruction` with `new_instruction`: rewires all users to
// the new instruction, best-effort propagates metadata, frontend attributes
// and sharding from the old instruction when the new one doesn't carry its
// own, then removes the old instruction together with any operands that end
// up unused. The two instructions' shapes must be compatible.
Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                          HloInstruction* new_instruction) {
  // A replacement must produce the same shape as what it replaces.
  TF_RET_CHECK(
      ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
      << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
      << ShapeUtil::HumanString(new_instruction->shape());

  VLOG(10) << "transformed " << old_instruction->ToString() << " to "
           << new_instruction->ToString();
  // Try to add metadata for HLO instructions that are created to replace
  // existing HLO instructions (e.g. during optimizations). The assumption is
  // that the old instruction and the new instruction would perform the same
  // function, and that they would be correlated to the same TF op. This might
  // not always be correct since HLO optimizations can cross TF op boundaries.
  // But still this seems to be better than nothing.
  bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
                           !old_instruction->metadata().op_name().empty();
  bool overwrite_pass_id =
      new_instruction->metadata().op_name().empty() &&
      new_instruction->metadata().logical_creation_pass_id() == 0 &&
      old_instruction->metadata().logical_creation_pass_id() != 0;
  if (overwrite_op_name || overwrite_pass_id) {
    // Note: this copies the whole metadata proto, not just the op name /
    // pass id that triggered the copy.
    new_instruction->set_metadata(old_instruction->metadata());
  }
  // Frontend attributes are likewise inherited only when the replacement has
  // none of its own.
  if (new_instruction->frontend_attributes().map().empty()) {
    new_instruction->set_frontend_attributes(
        old_instruction->frontend_attributes());
  }

  // Like the metadata above, if the user didn't specify any sharding
  // information on the new instruction we should copy the old sharding
  // information (if any).
  if (!new_instruction->has_sharding()) {
    new_instruction->set_sharding(old_instruction->sharding_ptr());
  }

  // Rewire users first, then delete the now-unused old instruction (and any
  // of its operands that became dead as a result).
  TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
  return RemoveInstructionAndUnusedOperands(old_instruction);
}
967
CollectUnreachableRoots() const968 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
969 std::vector<HloInstruction*> unreachable_roots;
970 for (auto* instruction : instructions()) {
971 if (instruction->user_count() == 0 &&
972 instruction->control_successors().empty() &&
973 instruction != root_instruction()) {
974 unreachable_roots.push_back(instruction);
975 }
976 }
977 VLOG(3) << "Unreachable roots:"
978 << absl::StrJoin(unreachable_roots, "\n\t",
979 [](string* out, const HloInstruction* hlo) {
980 absl::StrAppend(out, hlo->ToString());
981 });
982 return unreachable_roots;
983 }
984
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const985 Status HloComputation::AcceptWithOperandOrder(
986 DfsHloVisitor* visitor,
987 const HloInstruction::CompareFunction& operand_order) const {
988 // Visit unreachable roots. Beware that the visitor might delete the currently
989 // visited root, which would invalidate iterators if the unreachable roots
990 // weren't computed ahead of time.
991 for (HloInstruction* root : CollectUnreachableRoots()) {
992 TF_RETURN_IF_ERROR(
993 root->AcceptWithOperandOrder(visitor, operand_order,
994 /*call_finish_visit=*/false));
995 }
996 // Visit the computation root instruction last.
997 return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
998 /*call_finish_visit=*/true);
999 }
1000
Clone(const string & suffix,HloCloneContext * context)1001 std::unique_ptr<HloComputation> HloComputation::Clone(
1002 const string& suffix, HloCloneContext* context) {
1003 return CloneWithReplacements(
1004 /*replacements=*/absl::flat_hash_map<const HloInstruction*,
1005 std::unique_ptr<HloInstruction>>(),
1006 /*extra_parameters=*/{}, context, suffix);
1007 }
1008
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)1009 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1010 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1011 HloCloneContext* context, const string& suffix) {
1012 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1013 replacements;
1014 replacements.emplace(std::move(r1));
1015 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1016 context, suffix);
1017 }
1018
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)1019 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1020 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1021 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1022 HloCloneContext* context, const string& suffix) {
1023 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1024 replacements;
1025 replacements.emplace(std::move(r1));
1026 replacements.emplace(std::move(r2));
1027 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1028 context, suffix);
1029 }
1030
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)1031 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1032 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1033 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1034 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
1035 HloCloneContext* context, const string& suffix) {
1036 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1037 replacements;
1038 replacements.emplace(std::move(r1));
1039 replacements.emplace(std::move(r2));
1040 replacements.emplace(std::move(r3));
1041 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1042 context, suffix);
1043 }
1044
// Clones this computation, applying `replacements`: mapping an instruction
// to a new one substitutes it in the clone; mapping it to null drops it (and
// everything reachable only through it). `extra_parameters` (which must all
// be kParameter) are cloned in ahead of everything else; `new_root` (default:
// this computation's root) selects the clone's root. The old->new mapping is
// recorded in `context` (an internally owned one is used if null).
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const string& suffix,
    const HloInstruction* new_root) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }
  if (new_root == nullptr) {
    new_root = root_instruction();
  }

  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present in
  // the new computation.
  auto replace = [&](const HloInstruction* instr) {
    auto it = replacements.find(instr);
    return it != replacements.end() ? it->second.get() : instr;
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that will
  // generate a postorder of plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so let's just do it
  // ourselves.
  std::vector<const HloInstruction*> postorder;
  absl::flat_hash_map<const HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<const HloInstruction*> dfs_stack;
    const HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      // Replaced by null: omitted from the clone entirely.
      continue;
    }
    dfs_stack.push_back(new_instr);

    // Iterative DFS with a two-state marker: kVisiting while the node's
    // operands are still on the stack, kVisited once it has been emitted.
    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        CHECK_EQ(it->second, kVisiting);
        // Second encounter: all operands emitted, so emit the node itself.
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        const HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Clone in postorder so every operand has already been cloned (and thus is
  // resolvable through `context`) by the time its user is cloned.
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, context);
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parameter_replicated_at_leaf_buffers().has_value()) {
      // Preserve per-leaf-buffer replication annotations on parameters.
      new_instr->set_parameter_replicated_at_leaf_buffers(
          instr->parameter_replicated_at_leaf_buffers().value());
    }
    instructions.push_back(std::move(new_instr));
  }
  Builder builder(name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(replace(new_root)));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }
  context->MapComputation(this, result.get());
  return result;
}
1162
// Replaces this computation's name with the name `name_uniquer` hands back
// for it, making the name unique among all names that uniquer has issued.
void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}
1166
GetInstructionWithName(absl::string_view name)1167 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1168 auto instructions_in_computation = instructions();
1169 auto it = absl::c_find_if(
1170 instructions_in_computation,
1171 [&](HloInstruction* instr) { return instr->name() == name; });
1172 return it == instructions_in_computation.end() ? nullptr : *it;
1173 }
1174
IsEntryComputation() const1175 bool HloComputation::IsEntryComputation() const {
1176 return parent()->entry_computation() == this;
1177 }
1178 } // namespace xla
1179